Solution: Distributed Training with JAX and Flax
Let's review the solution.
Solution 1: Load the dataset
First, we define the dataset class as we did in previous chapters, this time for the Cars and Bikes dataset:
import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

class CarBikeDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir  # directory that holds the images
        self.annotations = pd.read_csv(annotation_file)  # column 0: file name, column 1: label
        self.transform = transform  # optional image transformations

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

In the code above:
- Lines 8–11: The constructor stores the root directory containing the images, loads the image labels from the annotation file, and stores the optional image transformations (transform).
- Lines 13–14: The __len__ function returns the number of samples in our dataset.
- Lines 16–24: The __getitem__ function finds the image at position index in the dataset, converts it to RGB, applies the transformations (if any), and returns it along with its corresponding label, converted to a tensor.
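To make the role of __len__ and __getitem__ concrete, here is a minimal sketch of the same protocol using only the standard library. The class ToyAnnotatedDataset and the inline CSV text are hypothetical stand-ins, not part of the lesson: the sketch keeps file names in place of decoded images, and the "transform" is simply applied to the file name.

```python
import csv
import io


class ToyAnnotatedDataset:
    """Hypothetical stand-in for a map-style dataset: rows of (file name, label)."""

    def __init__(self, annotation_csv, transform=None):
        # Parse "file name,label" rows and skip the header line.
        reader = csv.reader(io.StringIO(annotation_csv))
        next(reader)
        self.rows = [(name, float(label)) for name, label in reader]
        self.transform = transform

    def __len__(self):
        # Number of samples, as reported by len(dataset).
        return len(self.rows)

    def __getitem__(self, index):
        # Look up one sample by position, as done by dataset[index].
        name, label = self.rows[index]
        if self.transform is not None:
            name = self.transform(name)
        return name, label


# Assumed toy annotations in the same two-column layout as the lesson's CSV.
toy_csv = "img_path,label\ncar_1.png,0\nbike_1.png,1\n"
toy = ToyAnnotatedDataset(toy_csv, transform=str.upper)
first_item = toy[0]
```

Indexing with toy[0] calls __getitem__, and len(toy) calls __len__. A data loader relies on the same two methods when it draws samples from a map-style dataset such as CarBikeDataset.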
Visualizing the dataset
Let’s visualize one image from each class of the Cars and Bikes dataset.
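import os
from PIL import Image

os.makedirs("output", exist_ok=True)  # make sure the output directory exists

# The file names below are as given in the lesson; point them at one car image and one bike image.
img1 = Image.open("/usercode/cat.1.png")
img1.save("output/img1.png")

img2 = Image.open("/usercode/dog.3.png")
img2.save("output/img2.png")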
Next, we create a pandas DataFrame that stores each image path together with its category label.
import os
import pandas as pd

train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("cars_and_bikes/")
for idx, name in enumerate(train_df["img_path"]):
    if "car" in name:
        train_df.loc[idx, "label"] = 0  # car
    if "bike" in name:
        train_df.loc[idx, "label"] = 1  # bike

train_df.to_csv("train_csv.csv", index=False, header=True)
In the code above:
- Line 4: We initialize a DataFrame for storing paths and labels of our images.
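- Lines 5–11: We populate the DataFrame by getting the file names from the cars_and_bikes/ directory and assigning label 0 to car images and label 1 to bike images.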
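The labeling rule used above can be isolated and checked on its own. The following is a minimal sketch, not the lesson's implementation: the helper label_from_name and the example file names are hypothetical, and lowercasing the name before matching is an added assumption that the original substring checks do not make.

```python
from pathlib import PurePath


def label_from_name(file_name):
    """Return 0 for car images, 1 for bike images, and None if neither matches."""
    # Assumption: match case-insensitively on the file name only.
    name = PurePath(file_name).name.lower()
    # "bike" is checked first so that a name containing both substrings
    # gets label 1, matching the order of the two if statements above.
    if "bike" in name:
        return 1
    if "car" in name:
        return 0
    return None


# Hypothetical file names used only to exercise the rule.
example_names = ["car_001.png", "Bike_17.jpg", "notes.txt"]
example_labels = [label_from_name(n) for n in example_names]
```

Returning None for unmatched names makes stray files easy to spot and filter out before the CSV is written, since a row without a label cannot be converted to a valid class later in __getitem__.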