
JAX and NumPy

Explore how JAX's NumPy variant relates to standard NumPy: the syntax they share, and the core differences in array mutability, GPU and TPU utilization, and random number generation.

This lesson assumes some familiarity with NumPy. Let's start with how standard NumPy and JAX relate.

JAX NumPy

JAX has its own variant of NumPy, which we can import as:

import jax.numpy

One might worry that we have to relearn NumPy from scratch, but luckily jax.numpy closely mirrors the NumPy API, so the syntax is essentially the same. For ...
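To make the syntax claim concrete, here is a minimal sketch that writes the same computations against both libraries and compares the results. It assumes both `numpy` and `jax` are installed; the particular functions (`linspace`, `sin`, `sum`, `arange`, `reshape`, `mean`) are just illustrative picks from the shared API. Note that JAX uses 32-bit floats by default (unless 64-bit mode is enabled), so the results agree only up to floating-point precision.

```python
import numpy as np
import jax.numpy as jnp

# The same expression written against standard NumPy and JAX NumPy.
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

y_np = np.sum(np.sin(x_np) ** 2)
y_jnp = jnp.sum(jnp.sin(x_jnp) ** 2)

# Array methods such as reshape and mean carry over unchanged as well.
a_np = np.arange(6).reshape(2, 3)
a_jnp = jnp.arange(6).reshape(2, 3)
col_means_np = a_np.mean(axis=0)
col_means_jnp = a_jnp.mean(axis=0)

# Results match up to floating-point precision (JAX defaults to float32).
print(y_np, y_jnp)
print(col_means_np, col_means_jnp)
print(np.allclose(y_np, float(y_jnp)), np.allclose(col_means_np, np.asarray(col_means_jnp)))
```

Apart from the `np` / `jnp` prefix, the two versions are line-for-line identical; the differences JAX introduces show up in behavior (for example, how arrays are updated), not in everyday syntax.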