Trusted answers to developer questions

What is JAX?

Talha Irfan


Overview

JAX is a relatively recent machine/deep learning library developed at Google and used extensively by DeepMind.

Unlike TensorFlow, JAX is not an official Google product; it started as a research project. Its use is growing in the research community thanks to features such as composable function transformations, and its NumPy-like syntax means there is little new syntax to learn.

Features of JAX

At its core, JAX is a Just-In-Time (JIT) compiler built on XLA. It aims to extract as many FLOPS (floating-point operations per second) as possible from the hardware by generating optimized code, while keeping the simplicity of pure Python. Some of the salient features of JAX are:

  • Just-in-Time (JIT) compilation.
  • Runs NumPy code not only on CPU but also on GPU and TPU.
  • Automatic differentiation of NumPy and native Python code.
  • Automatic vectorization.
  • Express and compose transformations of numerical programs.
  • Advanced (pseudo) random number generation.
  • Structured control flow (such as jax.lax.cond and jax.lax.while_loop) that works inside compiled code.
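The random-number feature differs noticeably from NumPy: JAX has no hidden global random state, so every random call takes an explicit key. The following is a minimal sketch assuming a recent JAX install; the seed 0 and the shape (3,) are arbitrary choices for illustration.

```python
import jax

# JAX keeps no global random state: each random call receives an explicit key.
key = jax.random.PRNGKey(0)  # arbitrary seed chosen for this example

# Reusing the same key reproduces exactly the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# For fresh numbers, split the key and use each subkey only once.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(a)
print(b)
print(c)
```

Making the key explicit keeps random code reproducible and safe to run under transformations such as jit and vmap.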

JAX deserves a course of its own, but in this shot we will go through a few of its features.

import jax
import jax.numpy as jnp
import numpy as np
a = np.linspace(0.0,2.0,10)
print(a)
b = np.zeros((10,20))
print(b)
print("---And now JAX versions-----")
jnp_a = jnp.linspace(0.0,2.0,10)
jnp_b = jnp.zeros((10,20))
print(jnp_a)
print(jnp_b)
JAX has the same syntax as NumPy

Although JAX has the same syntax as NumPy, it differs in some aspects:

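  • Support for GPU/TPU (on a CPU-only machine, JAX may print a warning that no GPU/TPU was found and fall back to the CPU).
  • Different array types and default datatypes (JAX uses 32-bit floats by default, while NumPy uses 64-bit).
  • Some restrictions, such as JAX arrays being immutable.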

The examples below will further illustrate these differences.
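To check the first difference, you can ask JAX which devices it found. This minimal sketch only prints what is available; the output depends on your machine and on which JAX build is installed.

```python
import jax

# Every device JAX can use; a CPU-only machine reports a single CPU device.
devices = jax.devices()
print(devices)

# Name of the platform JAX uses by default, such as "cpu", "gpu" or "tpu".
print(jax.default_backend())
```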

import jax
import jax.numpy as jnp
import numpy as np
a = np.linspace(0.0,2.0,10)
b = np.zeros((10,20))
print(type(a))
print("---JAX array---")
jnp_a = jnp.linspace(0.0,2.0,10)
jnp_b = jnp.zeros((10,20))
print(type(jnp_a))
JAX and NumPy have different array types
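The other two differences, default datatypes and immutability, can be seen in a few lines. This is a short sketch assuming JAX's default settings (64-bit mode not enabled); the variable names are only illustrative.

```python
import numpy as np
import jax.numpy as jnp

np_a = np.linspace(0.0, 2.0, 10)
jnp_a = jnp.linspace(0.0, 2.0, 10)

# NumPy defaults to 64-bit floats; JAX defaults to 32-bit unless x64 mode is enabled.
print(np_a.dtype)
print(jnp_a.dtype)

# NumPy arrays can be modified in place.
np_a[0] = 10.0

# JAX arrays are immutable, so item assignment raises a TypeError.
immutable = False
try:
    jnp_a[0] = 10.0
except TypeError as err:
    immutable = True
    print("TypeError:", err)

# The functional alternative returns a new array and leaves jnp_a unchanged.
updated = jnp_a.at[0].set(10.0)
print(jnp_a[0], updated[0])
```

Immutability is what lets JAX treat functions as pure, which its transformations (jit, grad, vmap) rely on.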

JIT compilation

JAX allows us to perform JIT compilation. All you have to do is wrap the function with jit() or decorate it with @jit, as shown below.

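import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

# Plain Python function, executed eagerly
def Square(x):
    return x*x

# Same function, compiled with XLA the first time it is called
@jit
def JSquare(x):
    return x*x

# Equivalent to the decorator: wrap an existing function with jit()
JSquare2 = jit(Square)

print(Square(4.1))
print(JSquare(4.1))
print(JSquare2(4.1))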
We can use @jit decorator for JIT compilation

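The output is essentially the same (up to precision, since JAX computes in 32-bit floats by default). The difference shows up if you add a selection structure, such as an if statement on the value of x, to the jitted function: during compilation JAX traces the function with abstract tracer objects instead of concrete values, so Python control flow that depends on those values fails.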
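The tracer behaviour can be reproduced with a small ReLU-style function. This is a minimal sketch; the function names are hypothetical, and jnp.where is just one common way to express the branch so that it works under jit.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_python_if(x):
    # A Python `if` needs a concrete bool, but under jit `x` is an abstract tracer.
    if x > 0:
        return x
    return 0.0

@jax.jit
def relu_where(x):
    # jnp.where selects between values element-wise and works on tracers.
    return jnp.where(x > 0, x, 0.0)

jit_failed = False
try:
    relu_python_if(3.0)
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print("jit failed:", type(err).__name__)

print(relu_where(3.0))
print(relu_where(-2.0))
```

For branches that are too expensive to compute both ways, jax.lax.cond is the structured alternative to jnp.where.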

Autograd

JAX supports both automatic differentiation (grad) and automatic vectorization (vmap). The example below gives a glimpse of automatic differentiation.

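from jax import grad

# Y(a) = 3a^2 - a + 1
def Y(a):
    return 3*a*a - a + 1

dy = grad(Y)                # first derivative: 6a - 1, which is 11 at a = 2
dy2 = grad(grad(Y))         # second derivative: 6
dy3 = grad(grad(grad(Y)))   # third derivative: 0
a = 2.0
print(dy(a))
print(dy2(a))
print(dy3(a))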
We can use grad() for higher-order derivatives as well
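Automatic vectorization was mentioned but not shown. A minimal sketch, reusing Y from above, shows vmap and how transformations compose: grad differentiates, vmap maps the result over a batch, and jit compiles the whole pipeline. The input values in xs are arbitrary.

```python
import jax
import jax.numpy as jnp

def Y(a):
    return 3*a*a - a + 1

# grad(Y) expects a single scalar input.
dy = jax.grad(Y)

# vmap maps dy over the leading axis; jit compiles the combined function.
batched_dy = jax.jit(jax.vmap(dy))

xs = jnp.array([0.0, 1.0, 2.0, 3.0])
# dY/da = 6a - 1, so the expected values are -1, 5, 11 and 17.
print(batched_dy(xs))
```

Without vmap, you would need a Python loop over xs or a hand-written batched derivative.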

For more details, please check the relevant course.

Copyright ©2022 Educative, Inc. All rights reserved