# Awesome JAX Libraries

A curated list of awesome JAX libraries.


When it comes to JAX libraries, this course has only shown the tip of the iceberg.

Here is a curated list of awesome JAX libraries, taken from Awesome JAX.

## Libraries

- Neural Network Libraries
  - Flax: Centered on flexibility and clarity.
  - Haiku: Focused on simplicity, created by the authors of Sonnet at DeepMind.
  - Objax: Has an object-oriented design similar to PyTorch.
  - Elegy: A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.
  - Trax: “Batteries included” deep learning library focused on providing solutions for common workloads.
  - Jraph: Lightweight graph neural network library.
  - Neural Tangents: A high-level API for specifying neural networks of both finite and *infinite* width.
  - HuggingFace: Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
  - Equinox: Callable PyTrees and filtered JIT/grad transformations for building neural networks in JAX.

- NumPyro: Probabilistic programming based on the Pyro library.
- Chex: Utilities to write and test reliable JAX code.
- Optax: Gradient processing and optimization library.
- RLax: Library for implementing reinforcement learning agents.
- JAX, M.D.: Accelerated, differentiable molecular dynamics.
- Coax: Turn RL papers into code, the easy way.
- SymJAX: Symbolic CPU/GPU/TPU programming.
- mcx: Express & compile probabilistic programs for performant inference.
- Distrax: Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers: Construct differentiable convex optimization layers.
- TensorLy: Tensor learning made simple.
- NetKet: Machine learning toolbox for quantum physics.

### New libraries

This section contains well-made and useful libraries that have not necessarily been battle-tested by a large user base yet.

- Neural Network Libraries
  - FedJAX: Federated learning in JAX, built on Optax and Haiku.
  - Equivariant MLP: Construct equivariant neural network layers.
  - jax-resnet: Implementations and checkpoints for ResNet variants in Flax.

- jax-unirep: Library implementing the UniRep model for protein machine learning applications.
- jax-flows: Normalizing flows in JAX.
- sklearn-jax-kernels: `scikit-learn` kernel matrices using JAX.
- jax-cosmo: Differentiable cosmology library.
- efax: Exponential Families in JAX.
- mpi4jax: Combine MPI operations with your JAX code on CPUs and GPUs.
- imax: Image augmentations and transformations.
- FlaxVision: Flax version of TorchVision.
- Oryx: Probabilistic programming language based on program transformations.
- Optimal Transport Tools: Toolbox that bundles utilities to solve optimal transport problems.
- delta PV: A photovoltaic simulator with automatic differentiation.
- jaxlie: Lie theory library for rigid body transformations and optimization.
- BRAX: Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels: Pretrained models for JAX/Flax.
- CR.Sparse: XLA accelerated algorithms for sparse representations and compressive sensing.
- exojax: Automatically differentiable spectrum modeling of exoplanets and brown dwarfs, compatible with JAX.
- JAXopt: Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- PIX: An image processing library in JAX, for JAX.
- bayex: Bayesian Optimization powered by JAX.
- JaxDF: Framework for differentiable simulators with arbitrary discretizations.

Credits: GitHub, n2cholas/awesome-jax
