Solution: Distributed Training with JAX and Flax

Let's review the solution.

Solution 1: Load the dataset

First, we define the dataset class as we did in previous chapters, this time for the Cars and Bikes dataset:

import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class CarBikeDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # Column 0 holds the image file name; column 1 holds the label.
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))
        if self.transform is not None:
            img = self.transform(img)
        return (img, y_label)

In the code above:

  • Lines 10–13: The constructor stores the root directory containing the images, loads the image labels from the annotation file, and stores the optional image transformations (transform).
  • Lines 15–16: The __len__ function returns the number of samples in our dataset.
  • Lines 18–25: The __getitem__ function finds the image at position index, converts it to RGB, applies the transformations (if any), and returns it along with its label converted to a tensor.
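
To see what __len__ and __getitem__ provide, here is a minimal sketch of the same indexing protocol that needs neither PyTorch nor image files. The class ToyAnnotationDataset and its in-memory rows are hypothetical stand-ins for CarBikeDataset and the annotation file, and the image-loading step is replaced by simply returning the file name.

```python
# Hypothetical stand-in for CarBikeDataset: same protocol, no images or torch.
class ToyAnnotationDataset:
    def __init__(self, annotations, transform=None):
        # annotations: list of (file_name, label) rows, like the CSV rows.
        self.annotations = annotations
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        file_name, label = self.annotations[index]
        item = file_name  # the real class opens the image here
        if self.transform is not None:
            item = self.transform(item)
        return (item, float(label))


# Made-up rows for illustration only.
rows = [("car.0.png", 0), ("bike.0.png", 1), ("car.1.png", 0)]
dataset = ToyAnnotationDataset(rows, transform=str.upper)

print(len(dataset))  # number of samples
print(dataset[1])    # (transformed item, label) for index 1
```

A PyTorch DataLoader over a map-style dataset relies on these same two methods to pick and batch samples, which is why the real class only has to implement them.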

Visualizing the dataset

Let’s visualize one image from each class of the Cars and Bikes dataset.

from PIL import Image
img1 = Image.open("/usercode/cat.1.png")
img1.save("output/img1.png")
img2 = Image.open("/usercode/dog.3.png")
img2.save("output/img2.png")

Next, we create a pandas DataFrame that stores each image's file name together with its category label.

import os
import pandas as pd

train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("cars_and_bikes/")
for idx, file_name in enumerate(train_df["img_path"].tolist()):
    # Label by file name: 0 for cars, 1 for bikes.
    if "car" in file_name:
        train_df.loc[idx, "label"] = 0
    if "bike" in file_name:
        train_df.loc[idx, "label"] = 1
train_df.to_csv("train_csv.csv", index=False, header=True)

In the code above:

  • Line 4: We initialize a DataFrame for storing paths and labels of our images.
  • Lines 5–11: We populate the DataFrame by getting
...
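
The labeling rule used above can also be written as a small function, which makes the unlabeled case explicit. This is only a sketch: the function label_from_name and the file names are hypothetical, and it uses the standard csv module with an in-memory buffer instead of pandas and the cars_and_bikes/ directory.

```python
import csv
import io


def label_from_name(file_name):
    """Return 0 for cars, 1 for bikes, None if neither word appears."""
    label = None
    if "car" in file_name:
        label = 0
    if "bike" in file_name:
        label = 1
    return label


# Made-up file names for illustration only.
file_names = ["car.0.png", "bike.0.png", "notes.txt"]

buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["img_path", "label"])  # same header as train_csv.csv
for name in file_names:
    writer.writerow([name, label_from_name(name)])

print(buffer.getvalue())
```

Files that match neither word end up with an empty label, so it may be worth filtering them out before training.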
