The MNIST Dataset Class

Since we’re learning about PyTorch, let’s learn to load the MNIST dataset the PyTorch way, using a class.

MNIST dataset class

We’ve previously loaded the MNIST data from a CSV file into a pandas DataFrame. We could use the data directly from the DataFrame, and that would work fine. However, since we’re learning about PyTorch, we should try to load and use data the PyTorch way.

PyTorch can do useful things like automatically shuffling data, loading it with multiple processes in parallel, and providing it in batches. It does this with torch.utils.data.DataLoader, which expects the data itself to come from a torch.utils.data.Dataset object.
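For reference, here is a minimal sketch of those three DataLoader features. It does not use MNIST: the ten random 784-value "images" and their labels are made-up stand-in data, and TensorDataset is used only as a ready-made Dataset so the example stays short. The shuffle, batch_size and num_workers arguments correspond to the shuffling, batching and parallel loading described above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data (not MNIST): 10 flattened 28x28 "images"
# of random pixel values, labelled 0 to 9.
images = torch.rand(10, 784)
labels = torch.arange(10)

# TensorDataset is a ready-made Dataset that pairs tensors row by row.
dataset = TensorDataset(images, labels)

# shuffle=True reshuffles the samples each time the loader is iterated,
# batch_size groups them, and num_workers sets how many worker processes
# load the data (0 means loading happens in the main process).
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_shapes = []
seen_labels = []
for batch_images, batch_labels in loader:
    batch_shapes.append(tuple(batch_images.shape))
    seen_labels.extend(batch_labels.tolist())
    print(batch_images.shape, batch_labels)
```

With 10 samples and a batch size of 4, the loader yields batches of 4, 4 and 2 samples, and every label appears exactly once, in a shuffled order.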

To keep things simple, we won’t be using shuffling or batching, but we will work with the torch.utils.data.Dataset class to gain some experience with PyTorch’s data-loading machinery.

We import the PyTorch torch.utils.data.Dataset class as follows.
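A common way to write that import is from torch.utils.data import Dataset. The sketch below goes one step further and subclasses Dataset to wrap a pandas DataFrame, implementing __len__ and __getitem__, the two methods a map-style Dataset typically provides. The class name MnistDataset and the column layout (label in the first column, followed by 784 pixel values from 0 to 255, as in the usual MNIST CSV files) are assumptions for illustration, not necessarily how this course builds its class.

```python
import pandas as pd
import torch
from torch.utils.data import Dataset


class MnistDataset(Dataset):
    """Hypothetical Dataset wrapping a DataFrame assumed to be laid out as:
    column 0 = digit label, columns 1..784 = pixel values 0-255."""

    def __init__(self, data_df):
        self.data_df = data_df

    def __len__(self):
        # Number of samples = number of rows in the DataFrame
        return len(self.data_df)

    def __getitem__(self, index):
        row = self.data_df.iloc[index]
        label = int(row.iloc[0])
        # Scale pixel values from 0-255 to 0.0-1.0 as a float tensor
        image = torch.tensor(row.iloc[1:].to_numpy(), dtype=torch.float32) / 255.0
        return image, label


# Usage with a tiny synthetic DataFrame standing in for the CSV data:
# a "7" whose pixels are all 0 and a "3" whose pixels are all 255.
df = pd.DataFrame([[7] + [0] * 784, [3] + [255] * 784])
mnist_dataset = MnistDataset(df)

image, label = mnist_dataset[1]
print(len(mnist_dataset), label, image.shape)
```

Because MnistDataset is a Dataset, it could be handed straight to a DataLoader later; here it is simply indexed directly, with no shuffling or batching.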
