A curated list of awesome JAX libraries.

In this course we have only seen the tip of the iceberg when it comes to JAX libraries.
Here is a curated list of awesome JAX libraries, taken from Awesome JAX.


  • Neural Network Libraries
    • Flax: Centered on flexibility and clarity.
    • Haiku: Focused on simplicity, created by the authors of Sonnet at DeepMind.
    • Objax: Has an object-oriented design similar to PyTorch.
    • Elegy: A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.
    • Trax: “Batteries included” deep learning library focused on providing solutions for common workloads.
    • Jraph: Lightweight graph neural network library.
    • Neural Tangents: A high-level API for specifying neural networks of both finite and infinite width.
    • HuggingFace: Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    • Equinox: Callable PyTrees and filtered JIT/grad transformations, which together give you neural networks in JAX.
  • NumPyro: Probabilistic programming based on the Pyro library.
  • Chex: Utilities to write and test reliable JAX code.
  • Optax: Gradient processing and optimization library.
  • RLax: Library for implementing reinforcement learning agents.
  • JAX, M.D.: Accelerated, differentiable molecular dynamics.
  • Coax: Turn RL papers into code, the easy way.
  • SymJAX: Symbolic CPU/GPU/TPU programming.
  • mcx: Express & compile probabilistic programs for performant inference.
  • Distrax: Lightweight reimplementation of a subset of TensorFlow Probability, containing probability distributions and bijectors.
  • cvxpylayers: Construct differentiable convex optimization layers.
  • TensorLy: Tensor learning made simple.
  • NetKet: Machine learning toolbox for quantum physics.

New libraries

This section contains well-made and useful libraries that have not necessarily been battle-tested by a large user base yet.

  • Neural Network Libraries
    • FedJAX: Federated learning in JAX, built on Optax and Haiku.
    • Equivariant MLP: Construct equivariant neural network layers.
    • jax-resnet: Implementations and checkpoints for ResNet variants in Flax.
  • jax-unirep: Library implementing the UniRep model for protein machine learning applications.
  • jax-flows: Normalizing flows in JAX.
  • sklearn-jax-kernels: scikit-learn kernel matrices using JAX.
  • jax-cosmo: Differentiable cosmology library.
  • efax: Exponential Families in JAX.
  • mpi4jax: Combine MPI operations with your JAX code on CPUs and GPUs.
  • imax: Image augmentations and transformations.
  • FlaxVision: Flax version of TorchVision.
  • Oryx: Probabilistic programming language based on program transformations.
  • Optimal Transport Tools: Toolbox that bundles utilities to solve optimal transport problems.
  • delta PV: A photovoltaic simulator with automatic differentiation.
  • jaxlie: Lie theory library for rigid body transformations and optimization.
  • BRAX: Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
  • flaxmodels: Pretrained models for JAX/Flax.
  • CR.Sparse: XLA accelerated algorithms for sparse representations and compressive sensing.
  • exojax: Automatically differentiable spectrum modeling of exoplanets and brown dwarfs, compatible with JAX.
  • JAXopt: Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
  • PIX: Image processing library in JAX, for JAX.
  • bayex: Bayesian Optimization powered by JAX.
  • JaxDF: Framework for differentiable simulators with arbitrary discretizations.

[Credits: GitHub: n2cholas/awesome-jax]

