Introduction to Flax and Linen
While JAX has powerful features, writing deep learning applications with it directly can still be tricky. This isn't surprising, since JAX is intended to be a general-purpose numerical computing library.
The JAX ecosystem does offer some useful libraries for designing neural networks, though. We'll review them in this chapter and then apply what we've learned to build a project at the end.
Flax is a high-performance neural network library built on JAX, designed to keep model code flexible.
The main packages in Flax are:
- Neural networks
flax.linen provides the classes needed to build neural networks. Because it covers a wide range of functionality, we'll restrict ourselves to the most relevant parts here.