# Introduction to JAX Data Types and Arrays

Learn about the important data types and arrays in JAX.

## Overview

JAX is a Python library for high-performance machine learning. It offers:

- Automatic differentiation
- Vectorization
- JIT compilation

### Data types in JAX

The data types in JAX are similar to those in NumPy. For instance, here is how we can create `float` and `int` data in JAX.

```python
import jax.numpy as jnp

# A 32-bit floating-point value
x = jnp.float32(1.25844)
print("x :", x)

# A 32-bit integer value
y = jnp.int32(45.25844)
print("y :", y)
```

In the code above, we import the JAX version of NumPy and name it `jnp`. We define two JAX variables, `x` and `y`, of types `float32` and `int32`, respectively. Lastly, we print the values of both variables.
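As a minimal sketch, the snippet below inspects the `dtype` attribute of the same two values to confirm their types. It assumes default JAX settings. Note that the fractional part of `45.25844` is discarded when the value is converted to `int32`.

```python
import jax.numpy as jnp

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)  # the fractional part is dropped on conversion

# Each JAX value records its element type in the dtype attribute.
print(x.dtype)  # float32
print(y.dtype)  # int32

# The stored integer no longer carries the .25844 part.
print(int(y))   # 45
```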

When we check the type of the data, we will see that it’s a `DeviceArray` (named `jax.Array` in JAX 0.4 and later). In the code below, we can see the same type for both the `float32` and `int32` variables.

```python
import jax.numpy as jnp

# Both values report the same array type, regardless of dtype
x = jnp.float32(1.25844)
print("type of x: ", type(x))

y = jnp.int32(45.25844)
print("type of y: ", type(y))
```
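The exact class name printed by `type()` depends on the installed JAX version. As a hedged alternative, the sketch below checks the type with `isinstance()` instead. It assumes JAX 0.4 or later, where `jax.Array` is the public array type; on older releases this name may not be available.

```python
import jax
import jax.numpy as jnp

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)

# Assumption: JAX 0.4 or later, where jax.Array is the public array type.
x_is_array = isinstance(x, jax.Array)
y_is_array = isinstance(y, jax.Array)

print(x_is_array)  # True
print(y_is_array)  # True
```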

The `DeviceArray` in JAX is the equivalent of `numpy.ndarray` in NumPy, and `jax.numpy` provides an interface similar to NumPy’s. However, JAX also provides `jax.lax`, a low-level API that is more powerful and stricter. For example, with `jax.numpy`, we can add numbers that have mixed types, but `jax.lax` rejects mixed-type operands unless we convert them explicitly.

## Ways to create JAX arrays

We can create JAX arrays like we would in NumPy. For example, we can use:

- The `arange()` function
- The `linspace()` function
- Python lists
- The `zeros()` function
- The `ones()` function
- The `identity()` or `eye()` function

Let’s look at the outputs of the functions above:

```python
import jax.numpy as jnp

a = jnp.arange(10)
print("a : ", a)

b = jnp.linspace(0, 10, 30)
print("b :", b)

scores = [50,60,70,30,25,70]
scores_array = jnp.array(scores)
print("scores_array :", scores_array)

c = jnp.zeros(5)
print("c :", c)

d = jnp.ones(5)
print("d :", d)

e = jnp.eye(5)
print("e :", e)

f = jnp.identity(5)
print("f :", f)
```

Let’s understand the code above:

- **Line 3:** We call the `jnp.arange()` function, which generates a JAX array of `10` elements from 0 to 9.
- **Line 6:** We call the `jnp.linspace()` function, which creates a JAX array of `30` values evenly spaced between 0 and 10. By default, `linspace()` generates `50` values, but we can request any number of values in a given range.
- **Lines 9–10:** We define a Python list, `scores`, and use the `jnp.array()` function to convert `scores` into a JAX array.
- **Line 13:** We call the `jnp.zeros()` function to generate a JAX array of `5` zeros.
- **Line 16:** Similarly, we call the `jnp.ones()` function to generate a JAX array of `5` ones.
- **Line 19:** We create a $5\times5$ identity matrix by calling the `jnp.eye()` function. An identity matrix is a square matrix where the diagonal values are one and all other elements are zero.
- **Line 22:** Just like with the `jnp.eye()` function, we can also generate an identity matrix with the `jnp.identity()` function.
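Two of the points above can be checked directly. The following sketch confirms the default number of values from `linspace()` and that `eye()` and `identity()` produce the same matrix. The non-square `eye(3, 5)` call is an extra illustration that is not part of the original example.

```python
import jax.numpy as jnp

# With no count given, linspace() returns 50 evenly spaced values.
default_points = jnp.linspace(0, 10)
print(default_points.shape)  # (50,)

# eye() and identity() build the same square identity matrix.
e = jnp.eye(5)
f = jnp.identity(5)
same_matrix = bool(jnp.array_equal(e, f))
print(same_matrix)  # True

# Unlike identity(), eye() also accepts a column count for non-square shapes.
rect = jnp.eye(3, 5)
print(rect.shape)  # (3, 5)
```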