
Machine Learning with JAX

Learn about various machine learning functionalities available in the JAX library.

Taking derivatives with grad()

Computing derivatives in JAX is done using jax.grad(), which takes a function and returns a new function that evaluates its gradient.

Python 3.8
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))

In the code above:

  • Lines 1–2: We import jax and jax.numpy, which the code below relies on.

  • Lines 4–6: We define the sum_logistic() function and apply the @jax.jit decorator to it.

  • Line 8: We generate a JAX array of values from zero to five and store it in x_small.

  • Line 9: We use the jax.grad() function to calculate the derivative of the sum_logistic() function with respect to its input. We store the resulting derivative function in derivative_fn.

  • Lines 10–11: We print the original JAX array, x_small, and its derivative computed with derivative_fn.
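Since the derivative of the logistic function s(x) = 1 / (1 + exp(-x)) is s(x) * (1 - s(x)), we can sanity-check the result of jax.grad() against this closed form. The following is a minimal sketch; the sigmoid() helper is our own:

Python 3.8
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-x))

x_small = jnp.arange(6.)
# d/dx sum(sigmoid(x)) = sigmoid(x) * (1 - sigmoid(x)), elementwise
expected = sigmoid(x_small) * (1.0 - sigmoid(x_small))
computed = jax.grad(lambda x: jnp.sum(sigmoid(x)))(x_small)
print(jnp.allclose(expected, computed))  # True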

The grad() function has a has_aux argument that allows the differentiated function to return auxiliary data alongside the value being differentiated. For example, when building machine learning models, we can use it to return both the loss and other quantities, such as predictions, while computing gradients (see the sketch after the walkthrough below).

Python 3.8
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))

In the code above:

  • Line 6: The sum_logistic() function now returns a tuple: the value to differentiate and the auxiliary data, (x + 1).
  • Line 9: We pass True to the has_aux argument to tell jax.grad() that sum_logistic() returns auxiliary data alongside its primary output.
  • Line 11: We print the derivative of x_small using derivative_fn. We can see the auxiliary data along with the derivative results in the output.
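As a sketch of how this pattern is used in model training, the hypothetical loss function below returns its predictions as auxiliary data; the model, parameters, and data here are assumed only for illustration:

Python 3.8
import jax
import jax.numpy as jnp

# Hypothetical linear model; params is a (weights, bias) tuple.
def loss_fn(params, x, y):
    w, b = params
    preds = jnp.dot(x, w) + b
    loss = jnp.mean((preds - y) ** 2)
    return loss, preds  # auxiliary data: the predictions

params = (jnp.ones(3), jnp.array(0.5))
x = jnp.ones((4, 3))
y = jnp.zeros(4)
grads, preds = jax.grad(loss_fn, has_aux=True)(params, x, y)

jax.value_and_grad() accepts the same has_aux flag when both the loss value and the gradients are needed.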

We can perform more advanced automatic differentiation using jax.jvp() for forward-mode Jacobian-vector products and jax.vjp() for reverse-mode vector-Jacobian products.
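Here is a minimal sketch of both transforms on a simple scalar function; the function f() is assumed only for illustration:

Python 3.8
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.array(1.5)

# Forward mode: evaluates f(x) and the Jacobian-vector product J @ v.
primal_out, tangent_out = jax.jvp(f, (x,), (jnp.array(1.0),))

# Reverse mode: evaluates f(x) and returns a function that computes u @ J.
primal_out, vjp_fn = jax.vjp(f, x)
(cotangent_out,) = vjp_fn(jnp.array(1.0))

print(tangent_out, cotangent_out)  # equal for a scalar function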

Auto-vectorization with vmap

The vmap() (vectorizing map) transformation allows us to write a function that operates on a single data point, and vmap() then maps it over a batch of data. Without vmap(), we would have to loop over the batch while applying the function; combining jit with explicit Python loops is cumbersome and often slower (see the comparison sketch at the end of this section).

Python 3.8
import time
import jax
import jax.numpy as jnp

seed = 98
key = jax.random.PRNGKey(seed)
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

@jax.jit
def vmap_batched_apply_matrix(v_batched):
    return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
start_time = time.time()
print(vmap_batched_apply_matrix(batched_x).block_until_ready())
print("--- Execution time: %s seconds ---" % (time.time() - start_time))

In the code above:

  • Lines 5–8: We seed JAX's random number generator, then generate a random matrix, mat, with dimensions of 150×100, and a batch of ten random 100-dimensional vectors, batched_x.
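For comparison, a naively batched version that loops over the batch might look like the following sketch, reusing apply_matrix() and batched_x from the code above; the naively_batched_apply_matrix() name is our own:

Python 3.8
@jax.jit
def naively_batched_apply_matrix(v_batched):
    # Apply the single-example function to each row, then stack the results.
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched with a loop')
start_time = time.time()
print(naively_batched_apply_matrix(batched_x).block_until_ready())
print("--- Execution time: %s seconds ---" % (time.time() - start_time))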