Solution: Distributed Training with JAX and Flax
Let's review the solution.
Solution 1: Load the dataset
First, we define a dataset class, as in previous chapters, but this time for the Cars and Bikes dataset:
In the code above:
- Lines 6–9: The parametrized constructor defines the root directory containing the images, image labels (using the annotation file), and image transformations (transforms).
- Line 11: The __len__ method returns the number of samples in our dataset.
- Lines 14–22: The __getitem__ method finds the image at index idx, converts it to RGB, applies the transformations (if any), and returns it along with its corresponding label, converted to a tensor.
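Because the original code listing is not reproduced here, the following is a minimal sketch of a dataset class with the three parts described above, written against PyTorch's torch.utils.data.Dataset. The class name CarsBikesDataset and the annotation format (a CSV with a header row whose first column is the image filename and second column is an integer label) are assumptions for illustration, so the line numbers cited above do not apply to it.

```python
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class CarsBikesDataset(Dataset):
    """Hypothetical dataset class for the Cars and Bikes images.

    Assumes a CSV annotation file with a header row and two columns:
    image filename and integer label (e.g. 0 = car, 1 = bike).
    """

    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir                    # directory holding the images
        self.labels = pd.read_csv(annotation_file)  # one row per image: filename, label
        self.transform = transform                  # optional callable applied to each image

    def __len__(self):
        # Number of samples = number of rows in the annotation file
        return len(self.labels)

    def __getitem__(self, idx):
        # Locate the image file for sample idx and load it
        img_path = os.path.join(self.root_dir, self.labels.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")  # force 3-channel RGB
        if self.transform is not None:
            image = self.transform(image)
        # Convert the integer label to a tensor
        label = torch.tensor(int(self.labels.iloc[idx, 1]))
        return image, label
```

Calling convert("RGB") on every image means grayscale or RGBA files still come out with three channels, which keeps batch shapes consistent downstream.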
Visualizing the dataset
Let’s visualize one image from each class of the Cars and Bikes dataset.
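One way to show one image per class is sketched below with matplotlib. It assumes a hypothetical folder-per-class layout (root_dir/Car/..., root_dir/Bike/...) and simply plots the first file, in sorted order, from each class folder. The function name show_one_per_class is illustrative and not from the original lesson.

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
from PIL import Image


def show_one_per_class(root_dir, class_names):
    """Plot the first image found in each class folder side by side.

    Assumes a hypothetical layout root_dir/<class_name>/<image files>.
    """
    # squeeze=False keeps axes 2-D even when there is only one class
    fig, axes = plt.subplots(1, len(class_names), figsize=(4 * len(class_names), 4), squeeze=False)
    for ax, class_name in zip(axes[0], class_names):
        class_dir = os.path.join(root_dir, class_name)
        first_file = sorted(os.listdir(class_dir))[0]  # pick one image deterministically
        image = Image.open(os.path.join(class_dir, first_file)).convert("RGB")
        ax.imshow(image)
        ax.set_title(class_name)
        ax.axis("off")
    return fig
```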
Next, we create a pandas DataFrame that stores the image paths and their categories (labels).
In the code above:
- Line 4: We initialize a DataFrame for storing paths and labels of our images.
- Lines 5–11: We populate the DataFrame by getting the image paths from the directory
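A DataFrame of image paths and labels can be built roughly as follows. This is a sketch that assumes a hypothetical folder-per-class layout and uses each class's position in class_names as its integer label; the function name build_image_dataframe and the extension filter are illustrative choices, not the lesson's code.

```python
import os

import pandas as pd

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def build_image_dataframe(root_dir, class_names):
    """Collect (path, label) pairs into a DataFrame.

    Assumes a hypothetical layout root_dir/<class_name>/<image files>;
    the label is the class index in class_names.
    """
    rows = []
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(root_dir, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            # Skip non-image files such as notes or metadata
            if file_name.lower().endswith(IMAGE_EXTENSIONS):
                rows.append({"path": os.path.join(class_dir, file_name), "label": label})
    # Passing columns explicitly keeps the schema even if no images are found
    return pd.DataFrame(rows, columns=["path", "label"])
```

Collecting rows in a list and building the DataFrame once is generally faster than appending to a DataFrame inside the loop.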