
Just-in-Time Compilation

Explore how just-in-time compilation (JIT) in JAX accelerates deep learning computations by compiling Python functions at runtime. Understand asynchronous dispatch, pure functions, and tracer-based intermediate representations. Learn to optimize your functions for faster execution, especially on GPUs and TPUs, while handling static arguments and JAX control flow.

How fast is JAX?

JAX uses asynchronous dispatch, meaning it does not wait for a computation to complete before returning control to the Python program. Therefore, when we run an operation, JAX returns a future. JAX forces Python to wait for the computation only when we need the actual value, for example to print the output or convert the result to a NumPy array.

Therefore, if we want to measure the execution time of a program, we have to call block_until_ready() on the result (or convert it to a NumPy array) so that the timing waits for the computation to complete. Generally speaking, NumPy will outperform JAX on the CPU, but JAX will outperform NumPy on accelerators and when using jitted functions.
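For instance, here is a minimal timing sketch (hypothetical, not part of the lesson's code; the array shape and the use of jax.random are our own choices). The matrix product returns almost immediately because of asynchronous dispatch, while block_until_ready() waits for the computation to actually finish.

import time
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2000, 2000))

start = time.time()
result = jnp.dot(a, a)                 # returns almost immediately (asynchronous dispatch)
dispatch_time = time.time() - start

start = time.time()
result.block_until_ready()             # blocks until the computation actually finishes
compute_time = time.time() - start

print("dispatch: %.6f s, compute: %.6f s" % (dispatch_time, compute_time))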

Using jit() to speed up functions

The jit() method performs just-in-time (JIT) compilation with XLA. JIT compilation, also called dynamic translation, compiles code at runtime, which speeds up its execution. The jax.jit() method expects a pure function; any side effects in the function will only be executed once, during tracing. Let's create a pure function and time its execution without jit().
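Before that, a quick illustration of the side-effect behavior (a hypothetical sketch; the function name impure_fn and the print() side effect are our own choices, not the lesson's code): the print runs only while JAX traces the function on the first call and is skipped when the cached, compiled version is reused.

import jax
import jax.numpy as jnp

@jax.jit
def impure_fn(x):
    print("tracing...")    # side effect: executes only during tracing
    return x * 2.0

impure_fn(jnp.ones(3))     # prints "tracing..." (first call triggers tracing and compilation)
impure_fn(jnp.ones(3))     # no print; the cached compiled function is reused

Now, back to the pure function we want to time without jit():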

import time
import jax.numpy as jnp

def test_fn(sample_rate=3000, frequency=3):
    x = jnp.arange(sample_rate)                                # sample indices
    y = jnp.sin(2 * jnp.pi * frequency * (x / sample_rate))    # sine wave over the samples
    return jnp.dot(x, y)

start_time = time.time()
x = test_fn()
print("--- %s seconds ---" % (time.time() - start_time))

In the code above:

  • Lines 4–7: We define a function, ...