Array Operations
Learn about the common array-related operations of NumPy and JAX.
Array operations
Operations on JAX arrays are similar to operations on NumPy arrays. For example, we can use the max(), argmax(), and sum() functions the same way we do in NumPy.
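As a quick check of this claim, here is a minimal sketch that builds the same matrix with NumPy and with `jax.numpy` and compares the results of the three functions named above. It assumes both `numpy` and `jax` are installed; the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# Build the same 4x4 matrix with NumPy and with JAX.
np_matrix = np.arange(17, 33).reshape(4, 4)
jnp_matrix = jnp.arange(17, 33).reshape(4, 4)

# Call the identically named function from each library on its own array.
for name in ("max", "argmax", "sum"):
    np_result = getattr(np, name)(np_matrix)
    jnp_result = getattr(jnp, name)(jnp_matrix)
    # JAX returns a 0-d jax.Array, so convert it to a Python int before comparing.
    print(f"{name:>6}: numpy={np_result}, jax={int(jnp_result)}")
    assert np_result == int(jnp_result)
```

The JAX calls return `jax.Array` values rather than NumPy scalars, which is why the sketch converts them with `int()` before comparing.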
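```python
import jax.numpy as jnp
matrix = jnp.arange(17, 33)                # 1-D array of the values 17 to 32 (stop value 33 is excluded)
matrix = matrix.reshape(4, 4)              # reshape into a 4x4 matrix
print("Matrix :", matrix)
print("Maximum :", jnp.max(matrix))        # largest element
print("Argmax :", jnp.argmax(matrix))      # index of the largest element in the flattened array
print("Minimum :", jnp.min(matrix))        # smallest element
print("Argmin :", jnp.argmin(matrix))      # index of the smallest element in the flattened array
print("Sum :", jnp.sum(matrix))            # sum of all elements
print("Square root :", jnp.sqrt(matrix))   # element-wise square root
print("Transpose :", matrix.transpose())   # rows and columns swapped
```
Let’s review the code:
Lines 2–4: We create a JAX array, matrix, holding the values 17 to 32 and reshape it into a 4×4 matrix. Lastly, we use the ...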