Just-in-Time Compilation

Learn about ways to accelerate functions in JAX.

How fast is JAX?

JAX uses asynchronous dispatch: it does not wait for a computation to complete before giving control back to the Python program. When we call a JAX operation, it returns immediately with an array whose value may still be in the process of being computed, which effectively acts as a future. Python only waits for the computation to finish when it actually needs the value, for example when we print the output or convert the result to a NumPy array.
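A minimal sketch of asynchronous dispatch, assuming JAX and NumPy are installed. The array size and the element-wise function are arbitrary choices for illustration. The call to `jnp.sin` returns right away, and the program only blocks at `block_until_ready()` or at the NumPy conversion.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(1_000_000, dtype=jnp.float32)

# Dispatched asynchronously: control returns to Python as soon as the
# work is enqueued, so y may not be computed yet at this point.
y = jnp.sin(x) * 2.0

# block_until_ready() waits for the computation to finish and returns
# the same array.
y = y.block_until_ready()

# Converting to a NumPy array also forces Python to wait for the result.
y_np = np.asarray(y)
print(y_np[:3])
```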

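Therefore, to measure how long a program takes, we have to wait for the execution to complete, either by calling block_until_ready() on the result or by converting the result to a NumPy array. Otherwise, we would only measure the time it takes to dispatch the work, not to perform it. Generally speaking, NumPy often outperforms JAX on the CPU, particularly for small operations that are not jitted, but JAX tends to outperform NumPy on accelerators such as GPUs and TPUs and when functions are compiled with jax.jit.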
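The following is a rough timing sketch, not a rigorous benchmark. It compares a NumPy function, the same function written with `jax.numpy`, and its `jax.jit`-compiled version. The names `selu`, `selu_np`, `_wait`, and `time_it`, the constants, and the input size are hypothetical choices for illustration. The helper waits on JAX results with `block_until_ready()` so that the computation, and not just the dispatch, is timed. It also runs one warm-up call so that compilation time is excluded from the jitted measurement.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def selu(x, alpha=1.67, lam=1.05):
    # Scaled exponential linear unit written with jax.numpy.
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


def selu_np(x, alpha=1.67, lam=1.05):
    # The same function written with plain NumPy, for comparison.
    return lam * np.where(x > 0, x, alpha * np.exp(x) - alpha)


def _wait(out):
    # JAX arrays may still be computing; NumPy arrays are already finished.
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return out


def time_it(fn, x, repeats=10):
    """Average wall-clock seconds per call of fn(x)."""
    _wait(fn(x))  # warm-up call: triggers compilation for jitted functions
    start = time.perf_counter()
    for _ in range(repeats):
        # Without waiting, only the dispatch time would be measured.
        _wait(fn(x))
    return (time.perf_counter() - start) / repeats


x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
x_np = np.asarray(x)
selu_jit = jax.jit(selu)

t_numpy = time_it(selu_np, x_np)
t_eager = time_it(selu, x)
t_jit = time_it(selu_jit, x)
print(f"NumPy:       {t_numpy * 1e3:.3f} ms")
print(f"JAX (eager): {t_eager * 1e3:.3f} ms")
print(f"JAX (jit):   {t_jit * 1e3:.3f} ms")
```

The exact numbers depend on the hardware and library versions. The jitted version can fuse the element-wise operations into a single compiled kernel, whereas the un-jitted JAX version dispatches each operation separately.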
