
Challenge: Basics of JAX

Explore the basics of JAX programming through exercises that test arrays for zeros, define functions with JAX NumPy (jnp), and compute derivatives with Autograd-style automatic differentiation (jax.grad). The lesson builds the practical JAX syntax and differentiable-programming skills needed for the more advanced deep learning challenges that follow.
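As a quick illustration of two of these building blocks, here is a minimal sketch (not the lesson's reference solution) that defines a scalar function with jnp and differentiates it with jax.grad. The function f(x) = x² + 3x is a hypothetical example chosen only because its derivative, 2x + 3, is easy to check by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function written with JAX NumPy: f(x) = x^2 + 3x
def f(x):
    return jnp.square(x) + 3.0 * x

# jax.grad(f) returns a new function that computes df/dx = 2x + 3
df = jax.grad(f)

# jax.grad can be composed: d2f/dx2 = 2 for every x
d2f = jax.grad(jax.grad(f))

slope = float(df(2.0))        # 2*2 + 3 = 7.0
curvature = float(d2f(2.0))   # 2.0
```

Note that jax.grad expects a function with a scalar output and floating-point inputs, which is why the example passes 2.0 rather than 2.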

Important points

Before solving the exercises, please keep the following in mind:

  • Do not change the name of the function below; the evaluator calls it by that name.
  • Do not print anything.
  • Do not import any libraries; they are pre-imported. We’ll continue using the jnp alias for JAX NumPy.
  • Variables are usually initialized with a placeholder value (11) so that the test case can run. Override them with the intended expressions or values in your implementation.

Note: These points also apply to all subsequent challenges.

Exercise 1: Testing an array for zeros

Let’s test our familiarity with the basic JAX and NumPy syntax.
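For reference, here is a minimal sketch of the relevant syntax, using a hypothetical array. It shows elementwise comparison combined with reductions such as jnp.any, jnp.all and jnp.sum; the exact result format the evaluator expects is not shown in this excerpt, so treat this as an illustration of the tools rather than the expected answer.

```python
import jax.numpy as jnp

# Hypothetical input array for illustration
a = jnp.array([3, 0, 5, 1])

# a == 0 produces a boolean array; reductions collapse it to a single answer
has_zero = bool(jnp.any(a == 0))     # True: at least one element equals 0
all_nonzero = bool(jnp.all(a != 0))  # False: not every element is nonzero
num_zeros = int(jnp.sum(a == 0))     # 1: True values are summed as 1
```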

In this simple exercise, we are given a JAX array a. Identify whether it contains ...