Challenge: Distributed Training with JAX and Flax

In this challenge, we will perform distributed training with JAX and Flax. All the necessary libraries have already been imported for you.

Challenge 1: Load the dataset

In the /usr/local/notebooks directory, there is a zipped folder containing images from two classes: cars and bikes. There are twelve images per class, for a total of 24 images. Load the dataset from the image paths and labels. Each image is named after its class followed by a serial number, e.g., bike.0.jpg or car.4.jpg. Then define a Dataset class and use it with a DataLoader to create the training and validation sets.
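The steps above can be sketched in plain Python. This is a minimal sketch under several assumptions, not the challenge's reference solution. It assumes the zip has already been extracted to a directory of .jpg files. It also assumes the DataLoader meant here is PyTorch's torch.utils.data.DataLoader. The names CLASS_TO_LABEL, parse_label, list_image_paths, ImageDataset, and stratified_split are hypothetical, as is the bike=0 / car=1 mapping. The class needs no imports from torch because it only implements __len__ and __getitem__, which is the map-style protocol a DataLoader consumes. The split is stratified, so with 12 images per class and val_fraction=0.25, each class contributes 3 validation images.

```python
import os
import random
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical label mapping: the class name is the filename prefix before the first dot.
CLASS_TO_LABEL: Dict[str, int] = {"bike": 0, "car": 1}


def parse_label(path: str) -> int:
    """Map e.g. '.../bike.0.jpg' -> 0 and '.../car.4.jpg' -> 1."""
    class_name = os.path.basename(path).split(".")[0]
    return CLASS_TO_LABEL[class_name]


def list_image_paths(image_dir: str) -> List[str]:
    """Collect .jpg paths from an already-extracted directory, sorted for reproducibility."""
    return sorted(
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if name.lower().endswith(".jpg")
    )


def _identity(path: str) -> object:
    # Placeholder loader; a named function (unlike a lambda) can be pickled by DataLoader workers.
    return path


class ImageDataset:
    """Map-style dataset: __len__ and __getitem__ are all a DataLoader needs."""

    def __init__(self, paths: Sequence[str], load_fn: Callable[[str], object] = _identity):
        self.paths = list(paths)
        self.labels = [parse_label(p) for p in self.paths]
        self.load_fn = load_fn  # e.g. open the file and convert it to an array

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Tuple[object, int]:
        return self.load_fn(self.paths[idx]), self.labels[idx]


def stratified_split(
    paths: Sequence[str], val_fraction: float = 0.25, seed: int = 0
) -> Tuple[List[str], List[str]]:
    """Shuffle each class separately and hold out val_fraction of it for validation."""
    rng = random.Random(seed)
    by_label: Dict[int, List[str]] = {}
    for p in paths:
        by_label.setdefault(parse_label(p), []).append(p)
    train, val = [], []
    for label in sorted(by_label):
        group = by_label[label]
        rng.shuffle(group)
        n_val = int(len(group) * val_fraction)
        val.extend(group[:n_val])
        train.extend(group[n_val:])
    return train, val
```

With real files, you would call list_image_paths on the extracted directory, split the result, and wrap each half, for example train_set = ImageDataset(train_paths, load_fn=...). You could then pass it to DataLoader(train_set, batch_size=4, shuffle=True). For batches to stack into arrays, load_fn should return fixed-size image arrays rather than paths.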

