
Challenge: Basics of JAX

Explore the basics of JAX programming by working through exercises that involve testing arrays for zeros, defining functions with JAX NumPy, and using Autograd to compute derivatives. This lesson helps you build practical skills in JAX syntax and differentiable programming necessary for advanced deep learning challenges.

Important points

Before solving the exercises, please keep the following in mind:

  • Do not change the name of the function below, because the evaluator uses it.
  • Do not print anything.
  • Do not import any libraries; they are pre-imported. We continue the convention of using jnp for JAX NumPy.
  • Variables are usually initialized with 1 so that the test case can run. Please override them with the intended expressions or values in your implementation.

Note: These points will also apply to all the subsequent challenges.
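The pre-imported environment is not shown on this page. If you want to try the exercises locally, a plausible equivalent setup is sketched below, assuming the standard import names (import jax, and import jax.numpy as jnp). Do not include these imports in your submission.

```python
# Assumed local setup: the challenge environment already provides these imports.
import jax
import jax.numpy as jnp

# jnp mirrors much of the NumPy API but produces JAX arrays.
x = jnp.arange(5)    # elements 0, 1, 2, 3, 4
total = jnp.sum(x)   # reductions work as in NumPy
```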

Exercise 1: Testing an array for zeros

Let’s test our familiarity with the basic JAX and NumPy syntax.

In this simple exercise, we are given a JAX array a. Determine whether it contains 0: the function should return True if a contains at least one zero, and False otherwise.

# a will be any JAX array
def TestZeroExists(a):
    # Replace the dummy condition below with your test for zeros.
    # It selects the False branch, so it should hold when a contains no zeros.
    if (1==1):  # Condition is initialized with a dummy 1==1 case. Please override it.
        return False
    else:
        return True
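For reference, here is a minimal sketch of one way to detect zeros, written as standalone functions. The names contains_zero and contains_zero_alt are hypothetical and are not what the evaluator calls; your submission must keep TestZeroExists. The sketch relies on jnp.any and jnp.all, which reduce over all axes by default and return a 0-d boolean JAX array.

```python
import jax.numpy as jnp

def contains_zero(a):
    """Return True if the JAX array `a` has at least one element equal to 0."""
    # a == 0 builds an elementwise boolean mask; jnp.any reduces it over all axes.
    return bool(jnp.any(a == 0))

def contains_zero_alt(a):
    """Same result via truthiness: jnp.all(a) is True only if every element is nonzero."""
    return not bool(jnp.all(a))

has_zero = contains_zero(jnp.array([1, 0, 3]))
no_zero = contains_zero(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
```

Converting the result with bool() works here because the function runs eagerly on concrete values. Inside a jax.jit-compiled function the array would be a tracer, and converting it to a Python bool would raise an error.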

Congratulations on completing this challenge. Now we will begin the real deal by practicing Autograd.
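As a brief preview, JAX exposes automatic differentiation through jax.grad, which takes a scalar-valued function and returns a new function that computes its derivative with respect to the first argument. The functions f and sum_of_squares below are hypothetical examples chosen for illustration, not the upcoming exercises.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2 and f''(x) = 6x
    return x ** 3 + 2.0 * x

df = jax.grad(f)    # first derivative
d2f = jax.grad(df)  # grad composes, giving the second derivative

slope = df(2.0)      # 3 * 2.0**2 + 2 = 14.0
curvature = d2f(2.0) # 6 * 2.0 = 12.0

def sum_of_squares(v):
    # Scalar output from a vector input; the gradient is 2 * v.
    return jnp.sum(v ** 2)

g = jax.grad(sum_of_squares)(jnp.array([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]
```

Note that jax.grad expects floating-point inputs, which is why the examples pass 2.0 rather than 2.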

...