# Awesome JAX-based Models

This appendix extends the preceding one by adding awesome JAX-based models.

## We'll cover the following

There are also ready-made models (and projects) developed in JAX. This list can be helpful for anyone doing research or data science work.

## Models and projects

### JAX

- Fourier Feature Networks: Official implementation of *Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains*.
- kalman-jax: Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
- GPJax: Gaussian processes in JAX.
- jaxns: Nested sampling in JAX.
- Amortized Bayesian Optimization: Code related to *Amortized Bayesian Optimization over Discrete Spaces*.
- Accurate Quantized Training: Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.
- BNN-HMC: Implementation for the paper *What Are Bayesian Neural Network Posteriors Really Like?*.
- JAX-DFT: One-dimensional density functional theory (DFT) in JAX, with an implementation of *Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics*.
- Robust Loss: Reference code for the paper *A General and Adaptive Robust Loss Function*.
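The projects above all build on JAX's core function transformations. As a quick orientation, here is a minimal sketch of the `jax.grad` and `jax.jit` composition these libraries share (plain JAX only; the `loss` function and its data are illustrative, not taken from any project listed):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of a simple linear model (illustrative)."""
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Differentiate with respect to the first argument, then JIT-compile with XLA.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.array([1.0, -2.0])
x = jnp.array([[1.0, 0.5], [0.0, 2.0]])
y = jnp.array([0.5, -1.0])
g = grad_fn(w, x, y)  # gradient with the same shape as w
```

Because `grad` and `jit` are ordinary higher-order functions, they compose freely, which is exactly what lets the libraries above layer samplers, optimizers, and models on top of the same core.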

### Flax

- Performer: Flax implementation of the Performer (linear transformer via FAVOR+) architecture.
- JaxNeRF: Implementation of *NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis* with multi-device GPU/TPU support.
- Big Transfer (BiT): Implementation of *Big Transfer (BiT): General Visual Representation Learning*.
- JAX RL: Implementations of reinforcement learning algorithms.
- gMLP: Implementation of *Pay Attention to MLPs*.
- MLP Mixer: Minimal implementation of *MLP-Mixer: An all-MLP Architecture for Vision*.
- Distributed Shampoo: Implementation of *Second Order Optimization Made Practical*.
- NesT: Official implementation of *Aggregating Nested Transformers*.
- XMC-GAN: Official implementation of *Cross-Modal Contrastive Learning for Text-to-Image Generation*.
- FNet: Official implementation of *FNet: Mixing Tokens with Fourier Transforms*.
- GFSA: Official implementation of *Learning Graph Structure With A Finite-State Automaton Layer*.
- IPA-GNN: Official implementation of *Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks*.
- Flax Models: Collection of models and methods implemented in Flax.
- Protein LM: Implements BERT and autoregressive models for proteins, as described in *Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences* and *ProGen: Language Modeling for Protein Generation*.
- Slot Attention: Reference implementation for *Differentiable Patch Selection for Image Recognition*.
- Vision Transformer: Official implementation of *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*.
- FID computation: Port of mseitzer/pytorch-fid to Flax.

### Haiku

- AlphaFold: Implementation of the inference pipeline of AlphaFold v2.0, presented in *Highly accurate protein structure prediction with AlphaFold*.
- Adversarial Robustness: Reference code for *Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples* and *Fixing Data Augmentation to Improve Adversarial Robustness*.
- Bootstrap Your Own Latent: Implementation for the paper *Bootstrap your own latent: A new approach to self-supervised Learning*.
- Gated Linear Networks: GLNs are a family of backpropagation-free neural networks.
- Glassy Dynamics: Open-source implementation of the paper *Unveiling the predictive power of static structure in glassy systems*.
- MMV: Code for the models in *Self-Supervised MultiModal Versatile Networks*.
- Normalizer-Free Networks: Official Haiku implementation of *NFNets*.
- NuX: Normalizing flows with JAX.
- OGB-LSC: DeepMind's entry to the PCQM4M-LSC (quantum chemistry) and MAG240M-LSC (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).
- Persistent Evolution Strategies: Code used for the paper *Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies*.
- WikiGraphs: Baseline code to reproduce results in *WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset*.

### Trax

- Reformer: Implementation of the Reformer (efficient transformer) architecture.

[Credits: GitHub: n2cholas/awesome-jax]
