# JAX Overview

Let's get started with an introduction to JAX and its uses.


## Python and deep learning

Today, most programmers are at least somewhat familiar with the role deep learning plays in shaping the modern world, and with some of its applications.

Some of the advancements in deep learning can be attributed to the rapid rise in Python's usage and of libraries like **NumPy**, **SciPy**, and **Keras**, as well as more specialized ones like **PyTorch** and **TensorFlow**.

## What is JAX?

JAX (**J**ust **A**fter e**X**ecution) is a relatively recent library for machine learning and deep learning. Why should we invest our time in learning a new library, though?

Before delving deeper into JAX and its architecture, it’s useful to have a quick overview of its features. We’ll find the answer to this “*Why?*” below.

## JAX features

JAX is focused on high-performance numerical computing. Its main features include:

- Just-in-time (**JIT**) compilation.
- Running NumPy code not only on CPUs but also on **GPUs** and **TPUs**.
- **Automatic differentiation** of both **NumPy** and **native Python** code, a technique for computing derivatives without any manual calculation.
- **Automatic vectorization**, a technique for automatically batching computations using a vectorized map (`vmap`).
- Expressing and composing **transformations** of numerical programs.
- Advanced (pseudo)random number generation.
- More options for control flow.
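As a minimal sketch of JIT compilation, `jax.jit` wraps an ordinary `jax.numpy` function, traces it, and compiles it with XLA. The `selu` function and its constants are illustrative choices, not part of JAX. `jax.devices()` lists the devices JAX found, so the same code runs on a GPU or TPU when one is available and falls back to the CPU otherwise.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Scaled exponential linear unit written with plain jax.numpy calls
    # (illustrative function; alpha and lmbda are arbitrary constants here)
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit traces selu once per input shape/dtype and compiles it with XLA
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
print(selu_jit(x))

# The devices JAX will run on (CPU, GPU, or TPU depending on the machine)
print(jax.devices())
```

The compiled and uncompiled versions compute the same values; the benefit of `jit` shows up on larger inputs and repeated calls, where the compiled XLA program avoids Python overhead.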
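Automatic differentiation is exposed through `jax.grad`, which turns a scalar-valued function into a function that returns its gradient. In this sketch the functions `f` and `g` are hypothetical examples; note that a plain Python `if` works under `grad` as long as the function is not also wrapped in `jit`, because the values being differentiated are still concrete.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Native Python control flow on the input value
    if x < 3.0:
        return 3.0 * x ** 2
    return -4.0 * x

df = jax.grad(f)
print(df(2.0))  # derivative of 3x^2 at x = 2 is 12
print(df(4.0))  # derivative of -4x is -4

def g(x):
    # Scalar-valued function of an array; grad returns an array of the same shape
    return jnp.sum(jnp.tanh(x) ** 2)

print(jax.grad(g)(jnp.array([0.0, 1.0])))
```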
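Because JAX features are function transformations, they compose. In this sketch the linear model `loss`, the weights, and the data are made up for illustration: `jax.grad` gives a per-example gradient, `jax.vmap` batches it over the leading axis of `x` and `y` while sharing `w` (`in_axes=(None, 0, 0)`), and `jax.jit` compiles the whole pipeline.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example (x, y)
    return (jnp.dot(w, x) - y) ** 2

# Gradient with respect to w for one example
per_example_grad = jax.grad(loss)

# vmap adds the batch dimension: w is shared, x and y are mapped over axis 0
batched_grads = jax.jit(jax.vmap(per_example_grad, in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

g = batched_grads(w, xs, ys)
print(g.shape)  # one gradient row per example
```

Writing the per-example function and letting `vmap` add the batch dimension avoids manual reshaping; the result matches a Python loop over the examples.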
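Under `jit`, a Python `if` cannot branch on a traced value, and Python loops are unrolled at trace time, so `jax.lax` provides structured control-flow primitives such as `lax.cond` and `lax.fori_loop`. A minimal sketch; the function names `square_if_positive` and `sum_of_squares` are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def square_if_positive(x):
    # lax.cond(pred, true_fun, false_fun, operand) selects a branch at run time,
    # which a Python `if` on a traced value cannot do inside jit
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: jnp.zeros_like(v), x)

def sum_of_squares(n):
    # lax.fori_loop(lower, upper, body, init) expresses the loop as one JAX
    # primitive, so it can also be used inside jit-compiled functions
    return lax.fori_loop(0, n, lambda i, acc: acc + i * i, 0)

print(square_if_positive(3.0))
print(square_if_positive(-2.0))
print(sum_of_squares(5))
```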

## JAX ecosystem

JAX does not stop there, though. A whole ecosystem of exciting libraries is built on top of it, including:

- **Haiku** is a neural network library providing object-oriented programming models.
- **RLax** is a library for **deep reinforcement learning**.
- **Jraph**, pronounced “giraffe”, is a library for Graph Neural Networks (GNNs).
- **Optax** provides an easy one-liner interface for using gradient-based optimization methods efficiently.
- **Chex** is used for testing purposes.