Introduction to JAX Data Types and Arrays

Learn about the important datatypes and arrays in JAX.


JAX is a Python library for high-performance machine learning. It relies on XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra that is also used to accelerate TensorFlow models, and on just-in-time (JIT) compilation. Its API is similar to NumPy's, with a few differences. JAX ships with features designed to speed up machine learning research, including:

  • Automatic differentiation

  • Vectorization

  • JIT compilation
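The three features listed above can be illustrated in a few lines. This is a minimal sketch assuming a recent JAX installation; the function f is a hypothetical example chosen for illustration, not part of the lesson's own code.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: f(x) = x**2 + 3x
def f(x):
    return x ** 2 + 3.0 * x

# Automatic differentiation: jax.grad returns a new function computing df/dx = 2x + 3
df = jax.grad(f)
print(df(2.0))  # 7.0

# Vectorization: jax.vmap maps f over the leading axis of its input array
batched_f = jax.vmap(f)
print(batched_f(jnp.arange(3.0)))  # values 0, 4, 10

# JIT compilation: jax.jit compiles f with XLA on the first call and reuses the result
fast_f = jax.jit(f)
print(fast_f(2.0))  # 10.0
```

The transformations compose, so a call such as jax.jit(jax.grad(f)) is also valid.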


Data types in JAX

The data types in NumPy are similar to those in JAX arrays. For instance, here is how we can create float and int data in JAX.

import jax.numpy as jnp
x = jnp.float32(1.25844)
print("x :", x)
y = jnp.int32(45.25844)
print("y :", y)

In the code above, we import the JAX version of NumPy and name it jnp. We define two JAX variables, x and y, of types float32 and int32, respectively. Lastly, we print the values of both variables.
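To confirm which data type each value received, we can read its dtype attribute. The sketch below reuses x and y from the example above. Note that jnp.float32() and jnp.int32() return 0-dimensional JAX arrays rather than Python scalars, and that casting 45.25844 to int32 drops the fractional part.

```python
import jax.numpy as jnp

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)  # the fractional part is dropped when casting to int32

# Every JAX array carries a NumPy-style dtype
print("x dtype:", x.dtype)  # float32
print("y dtype:", y.dtype)  # int32
print("y value:", y)        # 45

# Both values are 0-dimensional arrays
print("x shape:", x.shape)  # ()
```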

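When we check the type of these values, older JAX releases report DeviceArray. Starting with JAX 0.4.1, DeviceArray was replaced by the unified jax.Array type, so recent versions print an implementation class (such as ArrayImpl) instead. In either case, the code below shows the same type for both the float32 and the int32 variable.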

import jax.numpy as jnp
x = jnp.float32(1.25844)
print("type of x:", type(x))
y = jnp.int32(45.25844)
print("type of y:", type(y))
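Because the class name that type() prints differs between JAX versions, a more portable check is isinstance() against the public jax.Array type. This sketch assumes JAX 0.4.1 or later, where jax.Array is available.

```python
import jax
import jax.numpy as jnp

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)

# jax.Array is the public array type; checking against it avoids relying on
# the name of the concrete class reported by type()
print(isinstance(x, jax.Array))  # True
print(isinstance(y, jax.Array))  # True
```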

The JAX array type (DeviceArray in older releases, jax.Array in current ones) is the equivalent of numpy.ndarray in NumPy, and jax.numpy provides an interface similar to NumPy's. However, JAX also provides jax.lax, a lower-level API that is stricter and often more powerful. For example, jax.numpy lets us add numbers of mixed types, but jax.lax does not allow this.

Ways to create JAX arrays

We can create JAX arrays like we would in NumPy. For example, we can use:

  • The arange() function
  • The linspace() function
  • Python lists
  • The zeros() function
  • The ones() function
  • The identity() or eye() function

Let’s look at the outputs of the functions above:

import jax.numpy as jnp
a = jnp.arange(10)
print("a : ", a)
b = jnp.linspace(0, 10, 30)
print("b :", b)
scores = [50,60,70,30,25,70]
scores_array = jnp.array(scores)
print("scores_array :", scores_array)
c = jnp.zeros(5)
print("c :", c)
d = jnp.ones(5)
print("d :", d)
e = jnp.eye(5)
print("e :", e)
f = jnp.identity(5)
print("f :", f)

Let’s understand the code above:

  • Line 2: We call the jnp.arange() method, which generates a JAX array of 10 elements, from 0 to 9.

  • Line 4: We call the jnp.linspace() method, which creates a JAX array of 30 evenly spaced values between 0 and 10, including both endpoints. By default, linspace() generates 50 values; passing a third argument (num) produces any other number of values in the given range.

  • Lines 6–7: We define a Python list, scores, and use the jnp.array() method to convert it into a JAX array.

  • Line 9: We call the jnp.zeros() method to generate a JAX array of five zeros.

  • Line 11: Similarly, we call the jnp.ones() method to generate a JAX array of five ones.

  • Line 13: We create a 5×5 identity matrix (a square matrix whose diagonal values are one and all other elements are zero) by calling the jnp.eye() method.

  • Line 15: Just like the jnp.eye() method, the jnp.identity() method also generates an identity matrix.
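We can check several of the claims above by inspecting shapes and dtypes instead of reading printed values. This is a sketch assuming JAX's default configuration, in which 64-bit types are disabled, so jnp.zeros() returns float32 values.

```python
import jax.numpy as jnp

a = jnp.arange(10)
b = jnp.linspace(0, 10, 30)
default_b = jnp.linspace(0, 10)  # num defaults to 50
e = jnp.eye(5)
f = jnp.identity(5)

# Number of elements produced by each call
print(a.shape, b.shape, default_b.shape)  # (10,) (30,) (50,)

# linspace includes both endpoints by default (endpoint=True)
print(b[0], b[-1])  # first and last values: 0 and 10

# eye(5) and identity(5) build the same 5x5 matrix
print(jnp.array_equal(e, f))  # True

# Default floating-point dtype when 64-bit mode is not enabled
print(jnp.zeros(5).dtype)  # float32
```

jnp.eye() is the more general of the two: it also accepts a different number of columns and a diagonal offset, while jnp.identity() always builds a square identity matrix.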