Project: A GAN Using the JAX Ecosystem
Let’s conclude the course with a working example. In this project, you will use Flax to build a Generative Adversarial Network (GAN) and train it on the classic MNIST dataset of handwritten digits.
This example will help you practice several JAX features, including PRNG key handling, pytrees, and JIT compilation, as well as Optax for optimization. You will also use Flax’s TrainState class, which we didn’t get a chance to practice during the course.