Machine Learning with JAX
Learn about various machine learning functionalities available in the JAX library.
Taking derivatives with grad()
Computing derivatives in JAX is done using `jax.grad()`.
```python
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```
In the code above:
- Lines 1–3: We apply the `@jax.jit` decorator to the `sum_logistic()` function.
- Line 5: We generate a JAX array of values from zero to five and store it in `x_small`.
- Line 6: We use the `jax.grad()` function to calculate the derivative of the `sum_logistic()` function with respect to its input. We store the resulting derivative function in `derivative_fn`.
- Lines 7–8: We print the original JAX array, `x_small`, and its derivative using `derivative_fn`.
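Because `sum_logistic()` sums element-wise sigmoids, its gradient is the element-wise sigmoid derivative, sigmoid(x) * (1 - sigmoid(x)). As a quick sanity check (a minimal sketch, assuming `jax` and `jax.numpy` are imported as usual), we can compare the result of `grad()` with this closed-form expression:

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)

# Gradient from JAX's automatic differentiation.
autodiff_grad = jax.grad(sum_logistic)(x_small)

# Closed-form derivative of the sigmoid: sigmoid(x) * (1 - sigmoid(x)).
sigmoid = 1.0 / (1.0 + jnp.exp(-x_small))
analytic_grad = sigmoid * (1.0 - sigmoid)

print(jnp.allclose(autodiff_grad, analytic_grad))  # Expected: True
```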
The `grad` function has a `has_aux` argument that allows us to return auxiliary data. For example, when building machine learning models, we can use it to return the loss and auxiliary values alongside the gradients.
```python
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```
In the code above:
- Line 6: We pass `True` to the `has_aux` argument to tell `jax.grad()` that the `sum_logistic()` function returns auxiliary data along with its main output.
- Line 8: We print the derivative of `x_small` using `derivative_fn`. We can see the auxiliary data along with the derivative results in the output.
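In a training loop, this pattern is often combined with `jax.value_and_grad()` so that a loss function can return extra values, such as the model's predictions, alongside the loss and gradients. The sketch below is only an illustration; the toy linear model and the names `loss_fn`, `params`, `inputs`, and `targets` are made up for this example:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, inputs, targets):
    # A toy linear model: predictions = inputs @ w + b.
    preds = jnp.dot(inputs, params["w"]) + params["b"]
    loss = jnp.mean((preds - targets) ** 2)
    # Return the loss plus auxiliary data (the predictions).
    return loss, preds

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
inputs = jnp.ones((4, 3))
targets = jnp.zeros((4,))

# has_aux=True tells JAX that loss_fn returns (loss, aux).
(loss, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, inputs, targets)
print(loss)
print(grads["w"].shape, grads["b"].shape)
```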
We can perform advanced automatic differentiation using `jax.vjp()` and `jax.jvp()`.
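As a rough, minimal sketch (redefining `sum_logistic()` from above so the snippet is self-contained), `jax.jvp()` computes a Jacobian-vector product in forward mode, while `jax.vjp()` returns a function for vector-Jacobian products in reverse mode:

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
tangent = jnp.ones_like(x_small)

# Forward mode: evaluate the function and the Jacobian-vector product.
primal_out, jvp_out = jax.jvp(sum_logistic, (x_small,), (tangent,))
print(primal_out, jvp_out)

# Reverse mode: evaluate the function and get a function for vector-Jacobian products.
primal_out, vjp_fn = jax.vjp(sum_logistic, x_small)
(vjp_out,) = vjp_fn(jnp.ones_like(primal_out))  # For a scalar output, this reproduces the gradient.
print(primal_out, vjp_out)
```

Forward mode is typically efficient for functions with few inputs, while reverse mode is efficient for functions with few outputs, such as a scalar loss.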
Auto-vectorization with vmap
The `vmap` (vectorizing map) transformation allows us to write a function that operates on a single data point, and `vmap` will map it over a batch of data. Without `vmap`, the alternative would be to loop over the batch while applying the function, as sketched below. Combining `jit` with `for` loops is somewhat complicated and may be slower.
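Here is a minimal, self-contained sketch of that loop-based approach (the data shapes and the helper name `naively_batched_apply_matrix` are illustrative, not from the lesson); it applies `apply_matrix()` to one row at a time and stacks the results, which is exactly what `vmap` automates:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_mat, key_batch = jax.random.split(key)
mat = jax.random.normal(key_mat, (150, 100))
batched_x = jax.random.normal(key_batch, (10, 100))

def apply_matrix(v):
    # Works on a single vector of shape (100,).
    return jnp.dot(mat, v)

# Loop-based batching: apply the function row by row and stack the results.
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

# vmap-based batching: map apply_matrix over the leading axis automatically.
batched = jax.vmap(apply_matrix)(batched_x)
print(jnp.allclose(naively_batched_apply_matrix(batched_x), batched))  # Expected: True
```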
```python
seed = 98
key = jax.random.PRNGKey(seed)
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

@jax.jit
def vmap_batched_apply_matrix(v_batched):
    return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
start_time = time.time()
print(vmap_batched_apply_matrix(batched_x).block_until_ready())
print("--- Execution time: %s seconds ---" % (time.time() - start_time))
```
In the code above:
- Lines 1–4: We generate random matrices, `mat`, with dimensions of...