Solution: Basics of JAX
Explore the basics of JAX by learning how to create and manipulate arrays, generate reproducible random numbers with keys, update arrays immutably, interpret jaxpr representations, and compute derivatives with JAX's automatic differentiation. This lesson covers essential JAX features used to build efficient deep learning models.
Let’s go through each solution in detail.
Solution 1: JAX arrays
We use the arange function to create a JAX array with values ranging from 0 to 66.
Let’s review the code:
- Line 1: We pass 67 to the arange() function to generate a JAX array from 0 to 66. The stop value is exclusive, so 67 itself is not included.
- Line 2: We print the array.
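The original solution code is not reproduced here. The following is a minimal sketch of the same idea, assuming arange() is called through jax.numpy imported as jnp. Because it adds an import line, its line numbers do not match the walkthrough above.

```python
import jax.numpy as jnp

# arange() follows NumPy semantics: start defaults to 0 and stop is exclusive,
# so passing 67 yields the 67 integers 0, 1, ..., 66
arr = jnp.arange(67)

# Print the resulting JAX array
print(arr)
```

The result is a JAX array rather than a NumPy array, so it can be passed directly to other JAX transformations.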
Solution 2: Random numbers
We know that JAX implements random number generation using a random state. This random state is referred to as a key. Using the same key will always generate the same output. We can split this key and generate different ...
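A minimal sketch of key-based random number generation, assuming the standard jax.random module. The seed 0 and the shape (3,) are arbitrary illustrative choices, not values taken from the exercise.

```python
import jax

# Create a key from an integer seed; the key is JAX's explicit random state
key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same samples
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Split the key to obtain a fresh subkey for new, different samples
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(a)
print(b)
print(c)
```

The usual pattern is to consume subkey for sampling and keep key only for further splits, so no key is used twice by accident.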