Solution: Optimizers
Explore how to load and preprocess the Fashion MNIST dataset, define a vanilla CNN using Flax's linen module, and train it with the AdaBelief and SGD optimizers. Understand batch training, loss calculation, and accuracy evaluation for improving model performance with JAX and Flax.
Loading dataset
We use the keras.datasets module from TensorFlow to load and visualize the Fashion MNIST dataset.
In the code above:
- Lines 1–5: We import the keras library from TensorFlow to load the dataset, the train_test_split method from sklearn.model_selection to split the dataset, and the JAX version of NumPy to perform numerical operations. We also import the numpy and matplotlib libraries for visualization.
- Lines 7–9: We load the Fashion MNIST dataset and combine the training and test datasets.
- Lines 11–12: We define train_size as 0.8 and split the data into training and test sets.
- Lines 14–16: We use a for loop to create a plot of 9 sample images to display in the output.
- Line 17: We save the image in the output folder so it appears in the output of the playground.
- Lines 19–23: We convert the training and test data to JAX arrays.
- Line 24: We reshape the dataset with the ...
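Since the playground code itself is not reproduced here, the steps above can be sketched roughly as follows. This is a minimal reconstruction, not the course's exact code: the synthetic arrays stand in for the downloaded Fashion MNIST data, and the sample count and random_state are placeholder choices.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for keras.datasets.fashion_mnist.load_data(), which
# returns (x_train, y_train), (x_test, y_test); here the two splits are
# assumed to have already been combined, as the walkthrough describes.
images = np.random.randint(0, 256, size=(100, 28, 28), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(100,), dtype=np.uint8)

# Re-split the combined data with train_size=0.8, as described above.
x_train, x_test, y_train, y_test = train_test_split(
    images, labels, train_size=0.8, random_state=0
)

# A 3x3 grid of sample images would be drawn with matplotlib here and
# saved into the output folder, e.g. plt.savefig("output/samples.png").

# Scale pixels to [0, 1]; in the course code the arrays are then
# converted to JAX arrays with jax.numpy.asarray before reshaping
# for the CNN's expected input shape.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
```

With 100 synthetic samples, the 0.8 train fraction leaves 80 training and 20 test examples, mirroring the 80/20 split the lesson applies to the full dataset.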