
Solution: Basics of JAX

Explore the basics of JAX: creating and manipulating arrays, generating reproducible random numbers with keys, updating arrays immutably, interpreting jaxpr representations, and computing derivatives with JAX's automatic differentiation. This lesson covers JAX features that are essential for building efficient deep learning models.

Let’s go through each solution in detail.

Solution 1: JAX arrays

We use the jnp.arange() function to create a JAX array containing the integers from 0 to 66.

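import jax.numpy as jnp
array = jnp.arange(67)  # integers 0, 1, ..., 66
print('array=', array)

Let’s review the code:

  • Line 1: We import jax.numpy under its conventional alias, jnp.
  • Line 2: We pass 67 to the jnp.arange() function to generate a JAX array from 0 to 66. As in NumPy, the stop value itself is excluded.
  • Line 3: We print the array.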
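To confirm that the stop value passed to arange() is exclusive, we can inspect the array's shape and its first and last elements. This is a small sketch that goes beyond the original solution; it relies only on standard jax.numpy indexing and the shape attribute.

```python
import jax.numpy as jnp

# jnp.arange(stop) starts at 0 and stops just before `stop`, like NumPy.
array = jnp.arange(67)

# 67 elements: 0 through 66 inclusive.
print(array.shape)          # (67,)
print(array[0], array[-1])  # first and last values: 0 66
```

Because the array holds 67 elements starting at 0, its last element is 66, which matches the range described above.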

Solution 2: Random numbers

JAX implements random number generation with an explicit random state, which is referred to as a key. Using the same key always generates the same output. We can split this key and generate different ...
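The key-based behavior described above can be sketched as follows. This is a minimal illustration, not the lesson's own solution: the seed 0 and the sample shape (3,) are arbitrary choices made for the example. It uses jax.random.PRNGKey to create a key, jax.random.normal to draw samples, and jax.random.split to derive a new key.

```python
import jax

# Create a key from an integer seed (0 is an arbitrary choice here).
key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(a)
print(b)

# Splitting yields new keys: keep `key` for future splits and
# consume `subkey` to draw a fresh set of numbers.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print(c)
```

Here a and b are identical because they come from the same key, while c, drawn from a split subkey, differs. The common pattern is to never reuse a key once it has been consumed and to split whenever new randomness is needed.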