Awesome JAX Libraries

A curated list of awesome JAX libraries.


This course has only scratched the surface of the JAX library ecosystem.
Below is a curated list of notable JAX libraries, taken from Awesome JAX.


  • Neural Network Libraries
    • Flax: Centered on flexibility and clarity.
    • Haiku: Focused on simplicity, created by the authors of Sonnet at DeepMind.
    • Objax: Has an object-oriented design similar to PyTorch.
    • Elegy: A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.
    • Trax: “Batteries included” deep learning library focused on providing solutions for common workloads.
    • Jraph: Lightweight graph neural network library.
    • Neural Tangents: A high-level API for specifying neural networks of both finite and infinite width.
    • HuggingFace: Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    • Equinox: Neural networks in JAX, built from callable PyTrees and filtered JIT/grad transformations.
  • NumPyro: Probabilistic programming based on the Pyro library.
  • Chex: Utilities to write and test reliable JAX code.
  • Optax: Gradient processing and optimization library.
  • RLax: Library for implementing reinforcement learning agents.
  • JAX, M.D.: Accelerated, differentiable molecular dynamics.
  • Coax: Turn RL papers into code, the easy way.
  • SymJAX: Symbolic CPU/GPU/TPU programming.
  • mcx: Express & compile probabilistic programs for performant inference.
  • Distrax: Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
  • cvxpylayers: Construct differentiable convex optimization layers.
  • TensorLy: Tensor learning made simple.
  • NetKet: Machine learning toolbox for quantum physics.

New libraries

This section contains well-made and useful libraries that have not necessarily been battle-tested by a large user base yet.

  • Neural Network Libraries
    • FedJAX: Federated learning in JAX, built on Optax and Haiku.
    • Equivariant MLP: Construct equivariant neural network layers.
    • jax-resnet: Implementations and checkpoints for ResNet variants in Flax.
  • jax-unirep: Library implementing the UniRep model for protein machine learning applications.
  • jax-flows: Normalizing flows in JAX.
  • sklearn-jax-kernels: scikit-learn kernel matrices using JAX.
  • jax-cosmo: Differentiable cosmology library.
  • efax: Exponential Families in JAX.
  • mpi4jax: Combine MPI operations with your JAX code on CPUs and GPUs.
  • imax: Image augmentations and transformations.
  • FlaxVision: Flax version of TorchVision.
  • Oryx: Probabilistic programming language based on program transformations.
  • Optimal Transport Tools: Toolbox that bundles utilities to solve optimal transport problems.
  • delta PV: A photovoltaic simulator with automatic differentiation.
  • jaxlie: Lie theory library for rigid body transformations and optimization.
  • BRAX: Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
  • flaxmodels: Pretrained models for JAX/Flax.
  • CR.Sparse: XLA accelerated algorithms for sparse representations and compressive sensing.
  • exojax: Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
  • JAXopt: Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
  • PIX: PIX is an image processing library in JAX, for JAX.
  • bayex: Bayesian Optimization powered by JAX.
  • JaxDF: Framework for differentiable simulators with arbitrary discretizations.

[Credits: Github:n2cholas/awesome-jax]
