What is JAX?
Overview
JAX is a Python library for machine learning and high-performance numerical computing, developed by Google and used heavily at DeepMind.
Unlike TensorFlow, JAX is not an official Google product; it is primarily used for research. Its use in the research community keeps growing thanks to features such as composable function transformations for compilation, differentiation, and vectorization. And because its API closely mirrors NumPy, there is little new syntax to learn.
Features of JAX
At its core, JAX is a just-in-time (JIT) compiler for numerical Python code. Its main features are:
- Just-in-Time (JIT) compilation.
- Runs NumPy-style code on CPUs, GPUs, and TPUs.
- Automatic differentiation of NumPy and native Python code.
- Automatic vectorization.
- Express and compose transformations of numerical programs.
- Explicit, reproducible pseudo-random number generation.
- Structured control-flow primitives (such as lax.cond and lax.fori_loop) that work inside compiled code.
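Two of the features listed above, random number generation and control flow, look different from plain NumPy. The following is a minimal sketch of both, using the real jax.random.PRNGKey, jax.random.split, jax.random.normal, lax.cond, and lax.fori_loop APIs; the helper name safe_reciprocal is hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Randomness in JAX is explicit: every random call takes a PRNG key.
key = jax.random.PRNGKey(0)
# Split the key instead of reusing it, so each draw gets fresh randomness.
key, subkey = jax.random.split(key)
samples = jax.random.normal(subkey, shape=(3,))

# Reusing the same key reproduces exactly the same numbers.
same = jax.random.normal(subkey, shape=(3,))
print(jnp.array_equal(samples, same))

# Hypothetical helper: lax.cond chooses a branch and still works under jit.
@jax.jit
def safe_reciprocal(x):
    # Both branches must return the same shape and dtype.
    return lax.cond(x != 0.0, lambda v: 1.0 / v, lambda v: v * 0.0, x)

print(safe_reciprocal(4.0), safe_reciprocal(0.0))

# lax.fori_loop runs a compiled loop; here it sums the integers 0..9.
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, jnp.int32(0))
print(total)
```

Passing keys around explicitly is more verbose than NumPy's global random state, but it makes results reproducible and keeps functions pure, which is what JAX's transformations rely on.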
JAX deserves a course of its own, so in this shot we will only go through a few of its features.
```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy versions
a = np.linspace(0.0, 2.0, 10)
print(a)
b = np.zeros((10, 20))
print(b)

print("---And now JAX versions-----")
# The same calls through jax.numpy
jnp_a = jnp.linspace(0.0, 2.0, 10)
jnp_b = jnp.zeros((10, 20))
print(jnp_a)
print(jnp_b)
```
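Although jax.numpy uses the same syntax as NumPy, it differs in some respects:
- Support for GPUs and TPUs (which is why JAX may print a warning when it finds no GPU/TPU and falls back to the CPU).
- Different array types and default datatypes (for example, 32-bit floats instead of NumPy's 64-bit default).
- Some restrictions, such as arrays being immutable.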
The examples below will further illustrate these differences.
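```python
import jax
import jax.numpy as jnp
import numpy as np

a = np.linspace(0.0, 2.0, 10)
b = np.zeros((10, 20))
print(type(a))  # <class 'numpy.ndarray'>

print("---JAX array---")
jnp_a = jnp.linspace(0.0, 2.0, 10)
jnp_b = jnp.zeros((10, 20))
print(type(jnp_a))  # a JAX array class; the exact name depends on the JAX version
```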
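The datatype, restriction, and device differences mentioned earlier can also be checked directly. The sketch below compares default dtypes, shows that in-place assignment fails on a JAX array (one such restriction is immutability), uses the .at[...].set(...) update syntax instead, and lists the devices JAX can see. It assumes a default JAX setup with 64-bit mode (the jax_enable_x64 flag) turned off.

```python
import numpy as np
import jax
import jax.numpy as jnp

np_a = np.linspace(0.0, 2.0, 10)
jnp_a = jnp.linspace(0.0, 2.0, 10)

# Datatypes: NumPy defaults to float64; JAX defaults to float32
# unless 64-bit mode has been enabled.
print(np_a.dtype, jnp_a.dtype)

# Restriction: NumPy arrays can be modified in place...
np_a[0] = 10.0

# ...but JAX arrays are immutable, so item assignment raises a TypeError.
immutable = False
try:
    jnp_a[0] = 10.0
except TypeError as err:
    immutable = True
    print("TypeError:", err)

# The functional alternative returns an updated copy and leaves jnp_a unchanged.
jnp_b = jnp_a.at[0].set(10.0)
print(jnp_a[0], jnp_b[0])

# Devices JAX can use; a CPU-only machine lists only CPU devices.
print(jax.devices())
```

Immutability is what allows JAX to trace and transform functions safely: an update produces a new array rather than silently changing one that another part of the program still references.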
JIT compilation
JAX can JIT-compile functions. All you have to do is wrap the function with jit() or decorate it with @jit, as shown below.
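```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

# Plain Python function
def Square(x):
    return x * x

# The same function, compiled by JAX on its first call
@jit
def JSquare(x):
    return x * x

print(Square(4.1))
print(JSquare(4.1))
```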
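The example above uses the decorator form. The other form mentioned in the text, calling jit() on an existing function, might look like the following sketch; the names square and jit_square are illustrative. Because JAX dispatches work asynchronously, block_until_ready() is used to wait for the result, which matters if you time the call.

```python
import jax.numpy as jnp
from jax import jit

def square(x):
    return x * x

# Equivalent to decorating square with @jit: jit() returns a compiled wrapper.
jit_square = jit(square)

x = jnp.arange(5.0)
print(square(x))  # executed operation by operation

# The first call traces and compiles; later calls with the same
# shape and dtype reuse the compiled code.
result = jit_square(x).block_until_ready()
print(result)
```

The wrapper form is convenient when you want to keep both the original and the compiled version of a function, for example to compare their results or timings.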
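The output is the same. The difference shows up if you add a selection structure, such as an if statement on the value of x, to the code above: inside a jitted function, x is replaced by an abstract tracer object that has no concrete value yet, so Python branching on it raises an error.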
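To see this behavior, the sketch below jits a function that branches on its input with a Python if, catches the resulting jax.errors.ConcretizationTypeError (newer JAX versions raise a more specific subclass of it), and then shows one common workaround, jnp.where, which selects values without Python-level branching. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def relu_python_if(x):
    # Under jit, x is a tracer, so "x > 0" has no concrete True/False value.
    if x > 0:
        return x
    return 0.0

raised = False
try:
    relu_python_if(2.0)
except jax.errors.ConcretizationTypeError as err:
    raised = True
    print(type(err).__name__)

@jit
def relu_where(x):
    # jnp.where picks element-wise between x and 0.0 inside compiled code.
    return jnp.where(x > 0, x, 0.0)

print(relu_where(jnp.array([-1.0, 2.0])))
```

Without jit, relu_python_if works normally, because x is then a concrete value. Alternatives include lax.cond, or marking an argument as static with jit's static_argnums parameter when its value is known at compile time.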
Autograd
JAX supports both automatic differentiation (Autograd) and automatic vectorization. The example below gives a glimpse of automatic differentiation.
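```python
from jax import grad

def Y(a):
    return 3*a*a - a + 1

dy = grad(Y)               # first derivative: 6a - 1
dy2 = grad(grad(Y))        # second derivative: 6
dy3 = grad(grad(grad(Y)))  # third derivative: 0

a = 2.0
print(dy(a))   # 6*2 - 1 = 11
print(dy2(a))  # 6
print(dy3(a))  # 0
```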
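Automatic vectorization, the other feature mentioned above, is provided by jax.vmap. grad(Y) only accepts a scalar input (grad requires a scalar output), so the sketch below wraps it with vmap to evaluate the derivative at several points in one call. The name batched_dy is illustrative.

```python
import jax.numpy as jnp
from jax import grad, vmap

def Y(a):
    return 3*a*a - a + 1

dy = grad(Y)  # dY/da = 6a - 1, defined for a scalar a

# vmap maps dy over the leading axis of its input array.
batched_dy = vmap(dy)

points = jnp.array([0.0, 1.0, 2.0])
print(batched_dy(points))  # 6a - 1 at each point: -1, 5, 11
```

Calling dy(points) directly would fail, because Y(points) returns an array rather than a scalar; vmap removes the need for a manual Python loop over the points.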
For more details, please check the relevant course.