
Talha Irfan

Grokking Modern System Design Interview for Engineers & Managers

Ace your System Design Interview and take your career to the next level. Learn to handle the design of applications like Netflix, Quora, Facebook, Uber, and many more in a 45-min interview. Learn the RESHADED framework for architecting web-scale applications by determining requirements, constraints, and assumptions before diving into a step-by-step design process.

**JAX** is a relatively recent machine/deep learning library developed by Google and widely used at DeepMind.

Unlike TensorFlow, JAX is not an official Google product; it is used mainly for research. It is gaining popularity in the research community thanks to features such as composable function transformations and GPU/TPU acceleration. Because its API closely follows NumPy's, there is also little new syntax to learn.

At its core, JAX is a *Just-In-Time (JIT)* compiler for numerical Python code. Its main features are:

- Just-in-Time (JIT) compilation.
- Running NumPy-style code not only on CPUs but also on GPUs and TPUs.
- Automatic differentiation of NumPy and native Python code.
- Automatic vectorization.
- Expressing and composing transformations of numerical programs.
- Advanced pseudo-random number generation.
- Structured control-flow primitives (such as `jax.lax.cond` and `jax.lax.scan`) that work inside compiled code.
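JAX handles randomness differently from NumPy. There is no hidden global generator state; every sampling call takes an explicit key. The sketch below uses the real functions `jax.random.PRNGKey`, `jax.random.split` and `jax.random.normal`. The seed `0` and the shape `(3,)` are arbitrary choices for illustration.

```python
import jax

key = jax.random.PRNGKey(0)          # explicit seed instead of global state
key, subkey = jax.random.split(key)  # derive a fresh key for this draw

sample = jax.random.normal(subkey, shape=(3,))
same = jax.random.normal(subkey, shape=(3,))  # same key -> same numbers

key, subkey2 = jax.random.split(key)
different = jax.random.normal(subkey2, shape=(3,))  # new key -> new numbers

print(sample)
print(same)
print(different)
```

Reusing a key always reproduces the same numbers. This makes results reproducible, but it also means you must split a new key for every independent draw.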

JAX deserves a course of its own, so this shot only walks through a few of its features.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = np.linspace(0.0, 2.0, 10)
print(a)
b = np.zeros((10, 20))
print(b)

print("---And now JAX versions-----")
jnp_a = jnp.linspace(0.0, 2.0, 10)
jnp_b = jnp.zeros((10, 20))
print(jnp_a)
print(jnp_b)
```

JAX has the same syntax as NumPy

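Although `jax.numpy` closely mirrors NumPy's syntax, JAX differs from NumPy in a few ways:

- Support for GPUs and TPUs (which is why running the code above on a CPU-only machine may print a warning that no GPU/TPU was found).
- Different array types and default datatypes (JAX uses 32-bit floats by default, while NumPy uses 64-bit).
- Some restrictions, most notably that JAX arrays are immutable.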

The examples below will further illustrate these differences.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = np.linspace(0.0, 2.0, 10)
b = np.zeros((10, 20))
print(type(a))

print("---JAX array---")
jnp_a = jnp.linspace(0.0, 2.0, 10)
jnp_b = jnp.zeros((10, 20))
print(type(jnp_a))
```

JAX and NumPy have different array types
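Two of the differences listed above, datatypes and restrictions, can be observed directly. The sketch below assumes JAX's default configuration, with 64-bit mode disabled. It uses `.at[...].set(...)`, JAX's functional replacement for in-place assignment.

```python
import numpy as np
import jax.numpy as jnp

np_a = np.linspace(0.0, 2.0, 10)
jnp_a = jnp.linspace(0.0, 2.0, 10)

# Default floating-point precision (assuming JAX's 64-bit mode is off)
print(np_a.dtype)   # float64
print(jnp_a.dtype)  # float32

# NumPy arrays can be modified in place
np_a[0] = 100.0

# JAX arrays are immutable: in-place assignment raises a TypeError
inplace_failed = False
try:
    jnp_a[0] = 100.0
except TypeError:
    inplace_failed = True

# The functional .at[] API returns an updated copy instead
jnp_b = jnp_a.at[0].set(100.0)
print(jnp_a[0], jnp_b[0])  # original unchanged, copy updated
```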

JAX allows us to perform JIT compilation. All you have to do is pass the function to `jit()` or decorate it with `@jit`, as shown below.

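```python
from jax import jit

def Square(x):
    return x * x

# Decorator form: JSquare is traced and compiled the first time it is called
@jit
def JSquare(x):
    return x * x

# Wrapper form: equivalent to decorating Square with @jit
JSquare2 = jit(Square)

print(Square(4.1))
print(JSquare(4.1))
print(JSquare2(4.1))
```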

We can use the @jit decorator for JIT compilation

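The outputs agree up to floating-point precision, because JAX computes in 32-bit floats by default. You will notice a difference, however, if you add a selection structure such as an `if` statement that depends on the value of `x`. Under `jit`, JAX traces the function with abstract *tracer* objects that stand in for the actual inputs. Python therefore cannot decide which branch to take, and JAX raises an error.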
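The sketch below makes the selection-structure case concrete. The function `scale` is a hypothetical example, not part of the original shot. `jnp.where` and `jax.lax.cond` are real JAX APIs that keep the branch inside the traced computation. Recent JAX versions raise `jax.errors.TracerBoolConversionError`, which subclasses `jax.errors.ConcretizationTypeError`, so the sketch catches the latter. The exact class may vary across versions.

```python
import jax
import jax.numpy as jnp
from jax import jit

# Hypothetical example: double positive inputs, halve the rest.
def scale(x):
    # A Python `if` needs a concrete True/False, so this only works eagerly.
    if x > 0:
        return x * 2.0
    return x / 2.0

print(scale(3.0))  # runs eagerly with a plain Python float

jit_failed = False
try:
    jit(scale)(3.0)  # x is a tracer here, so `x > 0` has no concrete value
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print("jit failed:", type(err).__name__)

# Workaround 1: compute both branches and select element-wise.
@jit
def scale_where(x):
    return jnp.where(x > 0, x * 2.0, x / 2.0)

# Workaround 2: let JAX pick the branch at run time from a traced predicate.
@jit
def scale_cond(x):
    return jax.lax.cond(x > 0, lambda v: v * 2.0, lambda v: v / 2.0, x)

print(scale_where(3.0), scale_where(-4.0))
print(scale_cond(3.0), scale_cond(-4.0))
```

The two workarounds trade off differently. `jnp.where` always evaluates both branches, which is cheap here. `lax.cond` evaluates only the selected branch, although under `vmap` it may be lowered to a select that computes both.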

JAX supports both automatic differentiation (`grad`) and automatic vectorization (`vmap`). The example below gives a glimpse of automatic differentiation.

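```python
from jax import grad

def Y(a):
    return 3 * a * a - a + 1

dy = grad(Y)               # first derivative:  6a - 1
dy2 = grad(grad(Y))        # second derivative: 6
dy3 = grad(grad(grad(Y)))  # third derivative:  0

a = 2.0  # grad needs a floating-point input
print(dy(a))
print(dy2(a))
print(dy3(a))
```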

We can use grad() for higher-order derivatives as well
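JAX's other headline transformation, automatic vectorization, can be sketched with `jax.vmap`. The block redefines `Y` from the autograd example so that it runs on its own. `grad(Y)` only accepts a single scalar, and `vmap` turns it into a function over a batch. Wrapping the result in `jit` shows how transformations compose.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def Y(a):
    return 3 * a * a - a + 1

dy = grad(Y)                          # accepts a single scalar
batched_dy = vmap(dy)                 # maps dy over the leading axis of its input
fast_batched_dy = jit(vmap(grad(Y)))  # transformations compose freely

xs = jnp.array([0.0, 1.0, 2.0])
print(batched_dy(xs))       # 6a - 1 at each point: -1, 5, 11
print(fast_batched_dy(xs))
```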

For more details, please check the relevant course.


Copyright ©2022 Educative, Inc. All rights reserved
