Matrices

We'll learn how to perform various matrix operations using JAX.


Matrices are at the core of almost any data application. Even a vector can be treated as a matrix with a single row (a row vector) or a single column (a column vector).
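As a small illustration (a sketch using `reshape`, one of several ways to do this in JAX), the same 5-element vector can be viewed as a 1 × 5 row matrix or a 5 × 1 column matrix:

```python
import jax.numpy as jnp

v = jnp.arange(5)        # 1-D vector with shape (5,)
row = v.reshape(1, 5)    # row matrix: 1 row, 5 columns
col = v.reshape(5, 1)    # column matrix: 5 rows, 1 column

print(v.shape)    # (5,)
print(row.shape)  # (1, 5)
print(col.shape)  # (5, 1)
```

All three arrays hold the same five values; only the shape differs.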

Note: Following the usual linear algebra convention, we use capital letters for matrices in our code as well.

```python
import jax.numpy as jnp

a = jnp.arange(5)
b = 2*jnp.arange(5)
c = -1*jnp.arange(5)
A = jnp.array((a,b,c))  # Stack the three vectors as the rows of a 3 x 5 matrix
print(A)
print(A.shape)
```
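`jnp.array((a, b, c))` places each vector in its own row. A minimal sketch of an equivalent construction with `jnp.stack`, where the `axis` argument decides whether the vectors become rows or columns:

```python
import jax.numpy as jnp

a = jnp.arange(5)
b = 2 * jnp.arange(5)
c = -1 * jnp.arange(5)

A = jnp.array((a, b, c))          # vectors become rows -> shape (3, 5)
S = jnp.stack((a, b, c), axis=0)  # stack along a new first axis: same as A
T = jnp.stack((a, b, c), axis=1)  # vectors become columns -> shape (5, 3)

print(jnp.array_equal(A, S))    # True
print(T.shape)                  # (5, 3)
print(jnp.array_equal(T, A.T))  # True: stacking as columns gives the transpose
```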

Slicing

We can make submatrices by using the slicing (:) notation.

Note: Python uses 0-based indexing, as does JAX.
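Because indexing starts at 0, `A[0]` is the first row and `A[:, 0]` the first column. A short sketch of row, column, and single-element indexing on the matrix built earlier:

```python
import jax.numpy as jnp

A = jnp.array((jnp.arange(5), 2 * jnp.arange(5), -1 * jnp.arange(5)))

first_row = A[0]      # row 0: values 0, 1, 2, 3, 4
second_col = A[:, 1]  # column 1 of every row: values 1, 2, -1
last_row = A[-1]      # negative indices count from the end: values 0, -1, -2, -3, -4
element = A[1, 4]     # row 1, column 4: value 8
```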

For example, the submatrix containing the first 2 rows and the first 3 columns of the matrix above is:

```python
B = A[0:2, 0:3]
```
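A self-contained version of this slice, with its resulting shape, follows. The second slice is an extra illustration of the optional step value in `start:stop:step`:

```python
import jax.numpy as jnp

A = jnp.array((jnp.arange(5), 2 * jnp.arange(5), -1 * jnp.arange(5)))

B = A[0:2, 0:3]  # rows 0-1 and columns 0-2; the stop index is excluded
print(B.shape)   # (2, 3)
# B holds the values [[0, 1, 2], [0, 2, 4]]

# Omitted bounds default to the start/end of the axis; a third number sets the step.
C = A[:, ::2]    # all rows, every other column (columns 0, 2, 4)
print(C.shape)   # (3, 3)
```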

Reshaping functions

Reshaping functions are commonly used in several applications, especially computer vision.
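As one computer vision example, a batch of images is often flattened so that each image becomes a single row. The sketch below assumes a hypothetical batch of 4 grayscale 28 × 28 images filled with zeros; the sizes are illustrative only:

```python
import jax.numpy as jnp

# Hypothetical batch: 4 grayscale images of 28 x 28 pixels (all zeros for illustration).
images = jnp.zeros((4, 28, 28))

# Flatten each image into one row of 28 * 28 = 784 pixels -> a 4 x 784 matrix.
flat = images.reshape(4, 28 * 28)
print(flat.shape)      # (4, 784)

# Reshaping back recovers the original per-image layout.
restored = flat.reshape(4, 28, 28)
print(restored.shape)  # (4, 28, 28)
```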

The rule behind any reshaping function is simple: the total number of elements must stay the same. If the input and output matrices have dimensions $m \times n$ and $j \times k$ respectively, then:

$$m \times n = j \times k$$

...
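A minimal sketch of this rule on the 3 × 5 matrix `A` (15 elements). Any target shape whose dimensions multiply to 15 works, and `-1` asks JAX to infer one dimension. The exact exception type for an invalid shape is an assumption here, so the example catches both `TypeError` and `ValueError`:

```python
import jax.numpy as jnp

A = jnp.array((jnp.arange(5), 2 * jnp.arange(5), -1 * jnp.arange(5)))  # 3 x 5 = 15 elements

print(A.reshape(5, 3).shape)   # (5, 3): 5 * 3 = 15
print(A.reshape(15).shape)     # (15,): a flat vector
print(A.reshape(-1, 1).shape)  # (15, 1): -1 lets JAX infer the missing dimension

try:
    A.reshape(4, 4)  # 4 * 4 = 16, not 15, so this reshape is rejected
except (TypeError, ValueError) as err:
    print("Cannot reshape:", err)
```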