Introduction to Flax and Linen

While JAX has powerful features, writing deep learning applications directly in it can still be tricky. This isn't surprising, since JAX is intended to be a general-purpose numerical computing library rather than a neural network framework.

The JAX ecosystem does offer some useful libraries for designing neural networks, though. We'll review them in this chapter and then put that understanding to use in the project at the end of the chapter.

Flax

Flax is a high-performance neural network library built on top of JAX. It aims to give us flexibility in how we design models while still writing JAX code.

The main packages in Flax are:

  • Neural networks
  • Utilities

Neural networks

The flax.linen package provides the classes we need to build neural networks. Because it covers a wide range of functionality, we'll restrict ourselves here to the most relevant parts.
