Challenge: Basics of JAX
Explore the basics of JAX programming by working through exercises that involve testing arrays for zeros, defining functions with JAX NumPy, and using Autograd to compute derivatives. This lesson helps you build practical skills in JAX syntax and differentiable programming necessary for advanced deep learning challenges.
Important points
Before solving the exercises, please note the following:
- Do not change the name of the provided function, because the evaluator uses it.
- Do not print anything.
- Do not import any libraries; they are pre-imported. We'll continue the convention of using `jnp` for JAX NumPy.
- Variables are usually initialized with placeholder values to help run the test case. Please override them with the intended expressions or values in your implementation.
Note: These points will also apply to all the subsequent challenges.
Exercise 1: Testing an array for zeros
Let’s test our familiarity with basic JAX and NumPy syntax.
In this simple exercise, we are given a JAX array `a`. Determine whether it contains zero: if `a` contains at least one zero, the function should return `True`; otherwise, it should return `False`.
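One possible approach, sketched below, compares every element with zero and then reduces the resulting boolean array with `jnp.any`. The function name `contains_zero` is hypothetical; the evaluator's actual function name is not shown in this section, so keep whichever name the exercise provides.

```python
import jax.numpy as jnp

def contains_zero(a):
    # Hypothetical name: use the function name given in the exercise template.
    # `a == 0` produces a boolean array marking each zero entry (works for any shape).
    # `jnp.any` collapses it to a single boolean: True if at least one entry is zero.
    return jnp.any(a == 0)
```

For example, `contains_zero(jnp.array([1, 2, 0]))` evaluates to a JAX boolean scalar equal to `True`, while `contains_zero(jnp.array([1, 2, 3]))` gives `False`. The result is a 0-d JAX array rather than a Python `bool`; if an evaluator strictly expects a Python `bool`, wrapping the result in `bool(...)` converts it (outside of `jax.jit`).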
Congratulations on completing this exercise. Now we'll move on to the real deal: practicing Autograd.
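As a brief preview, here is a minimal sketch of computing a derivative with JAX's automatic differentiation through `jax.grad`. The function `f` below is a hypothetical example chosen for illustration, not one of the course exercises.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function written with JAX NumPy: f(x) = x**3 + sin(x)
def f(x):
    return x ** 3 + jnp.sin(x)

# jax.grad(f) returns a new function that evaluates df/dx = 3x**2 + cos(x).
# The differentiated function must return a scalar, and x must be a float
# (passing an integer such as 2 raises a TypeError).
df = jax.grad(f)

slope_at_2 = df(2.0)  # approximately 3 * 2.0**2 + cos(2.0)
```

Because `jax.grad` returns an ordinary Python function, it can be composed, for example `jax.grad(jax.grad(f))` for the second derivative.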
...