JAX Overview

Let's get started with an introduction to JAX and its uses.

Python and deep learning

Today, most programmers are at least somewhat familiar with deep learning, the role it plays in shaping the modern world, and some of its applications.

Some of the advances in deep learning can be attributed to the rapid rise of Python and of libraries such as NumPy, SciPy, and Keras, as well as more specialized frameworks such as PyTorch and TensorFlow.

What is JAX?

JAX (Just After eXecution) is a relatively new library for machine learning and deep learning. But why should we invest our time in learning yet another library?

Before delving deeper into JAX and its architecture, it helps to take a quick look at its features, which answer this “Why?”.

JAX features

At its core, JAX is a Just-In-Time (JIT) compiler, meaning it compiles code at run time rather than ahead of time. It focuses on getting the most FLOPS (floating-point operations per second) out of the hardware, generating optimized code while keeping the simplicity of pure Python. Some of its most important features are:
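To make the JIT idea concrete, here is a minimal sketch, assuming a recent JAX installation. The function selu and its rounded constants are illustrative only (a SELU-style activation chosen for this example, not taken from this course). jax.jit traces the Python function and compiles it with XLA; later calls with the same input shape and dtype reuse the compiled code.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: a SELU-style activation written with
# jax.numpy, the NumPy-like API (constants rounded for illustration).
def selu(x, alpha=1.67, scale=1.05):
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit wraps the function; compilation happens on the first call
# for each new input shape/dtype, and the result is cached.
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
y = selu_jit(x)  # first call: trace + compile, then run
y = selu_jit(x)  # same shape/dtype: reuses the compiled version
print(y)
```

The jitted and plain versions compute the same values; the difference is that the compiled one runs as a single optimized XLA program instead of dispatching each operation one by one.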

  • Just-in-Time (JIT) compilation.
  • Running NumPy code not only on CPUs but also on GPUs and TPUs.
  • Automatic differentiation (computing derivatives programmatically, with no manual calculation) of both NumPy and native Python code.
  • Automatic vectorization (batching computations automatically using a vectorizing map).
  • Expressing and composing transformations of numerical programs.
  • Advanced (pseudo) random number generation.
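  • More options for control flow, including structured primitives (such as lax.cond and lax.fori_loop) that work inside compiled functions.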
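Several of these features can be seen side by side in a short sketch, again assuming a recent JAX installation. The functions f, dot, and pos_or_square are hypothetical examples written for illustration. It shows jax.grad for automatic differentiation (and composing it with itself), jax.vmap for automatic vectorization, explicit PRNG keys from jax.random, and jax.lax.cond as a control-flow primitive that works under jit.

```python
import jax
import jax.numpy as jnp

# Automatic differentiation: grad(f) returns a new function computing f'(x).
def f(x):
    return x ** 3 + 2.0 * x

df = jax.grad(f)               # f'(x) = 3x^2 + 2
d2f = jax.grad(jax.grad(f))    # composing transformations: f''(x) = 6x
print(df(2.0))                 # 14.0
print(d2f(2.0))                # 12.0

# Automatic vectorization: vmap maps a per-example function over axis 0.
def dot(a, b):
    return jnp.dot(a, b)

batched_dot = jax.vmap(dot)
a = jnp.ones((4, 3))
b = jnp.arange(12.0).reshape(4, 3)
print(batched_dot(a, b))       # row-wise dot products: [3. 12. 21. 30.]

# Pseudorandom numbers: state is an explicit key that is split, not a hidden global seed.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))

# Control flow inside a compiled function: lax.cond picks a branch at run time.
@jax.jit
def pos_or_square(x):
    # identity for positive inputs, square otherwise
    return jax.lax.cond(x > 0, lambda v: v, lambda v: v ** 2, x)

print(pos_or_square(2.0), pos_or_square(-3.0))
```

Note that each transformation (grad, vmap, jit) takes a function and returns a function, which is what makes them composable.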

JAX ecosystem

JAX does not stop there, though. A whole ecosystem of exciting libraries is built on top of it, including:

  • Haiku is a neural network library that provides an object-oriented programming model.
  • RLax is a library for deep reinforcement learning.
  • Jraph, pronounced “giraffe”, is a library used for Graph Neural Networks (GNNs).
  • Optax provides a simple interface for applying gradient-based optimization methods efficiently.
  • Chex provides utilities for testing JAX code.