Dataset Augmentation with GANs

Learn to augment datasets with Generative Adversarial Networks (GANs) in PyTorch.

You will learn to:

Create Generative Adversarial Networks from scratch.

Define and use losses to train neural networks for a given task.

Train deep learning models with automatic differentiation.

Use the Matplotlib library to visualize the training process.


Deep Learning

Data Visualization

Data Augmentation


Intermediate programming skills in Python

Intermediate knowledge of the PyTorch library

Intermediate understanding of deep learning

Familiarity with creating visualizations using the Matplotlib library





Project Description

Generative Adversarial Networks (GANs) are deep neural networks trained to generate synthetic data that is hard to distinguish from real data. A GAN consists of two neural networks trained in parallel: the generator and the discriminator. The discriminator's job is to classify a sample as either real or synthetic. The generator's job is to generate samples that fool the discriminator into classifying them as real.

Suppose we want to train a classification model on a dataset with a severe class imbalance. The trained model may report an excellent accuracy score, yet its confusion matrix can reveal that it only ever predicts the majority class. Oversampling and undersampling are two common solutions to the class imbalance problem. The former includes the minority class samples multiple times, which puts duplicate samples in the dataset. The latter drops a chunk of the majority class samples, which discards valuable data. To avoid both drawbacks, we can use GANs to generate synthetic data for the minority class. This balances the classes without duplicating samples or discarding any. A model trained on this data is therefore better placed to learn to predict both classes.
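To make the imbalance problem concrete, here is a minimal sketch in plain Python. The 950/50 class split and the always-predict-majority "model" are illustrative assumptions, not data from this project. It shows how accuracy can look excellent while the minority class is never predicted, and what random oversampling and undersampling cost.

```python
import random
from collections import Counter

random.seed(0)

# Hypothetical imbalanced labels: 950 majority (0) and 50 minority (1) samples.
labels = [0] * 950 + [1] * 50

# A degenerate "model" that always predicts the majority class.
predictions = [0] * len(labels)

accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

# Confusion matrix stored as counts keyed by (true label, predicted label).
confusion = Counter(zip(labels, predictions))
minority_recall = confusion[(1, 1)] / labels.count(1)

print(f"accuracy: {accuracy:.2f}")               # accuracy: 0.95
print(f"minority recall: {minority_recall:.2f}")  # minority recall: 0.00

majority_idx = [i for i, y in enumerate(labels) if y == 0]
minority_idx = [i for i, y in enumerate(labels) if y == 1]

# Random oversampling: draw minority samples with replacement until the
# classes are balanced. Only 50 distinct samples exist, so most are repeats.
oversampled = majority_idx + random.choices(minority_idx, k=len(majority_idx))
duplicates = len(oversampled) - len(set(oversampled))

# Random undersampling: keep only as many majority samples as minority ones.
undersampled = random.sample(majority_idx, k=len(minority_idx)) + minority_idx
discarded = len(labels) - len(undersampled)

print(f"duplicate entries after oversampling: {duplicates}")
print(f"samples discarded by undersampling: {discarded}")
```

The 95% accuracy comes entirely from the majority class, which is why the confusion matrix is the more informative check here.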

In this project, we'll create a GAN from scratch for dataset augmentation. We'll use PyTorch to build and train the models and Matplotlib to visualize the training process. The models will be trained with automatic differentiation, and the generator will be trained so that the distribution of its samples closely matches the actual distribution of the data.
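As a rough preview of the training loop described here, the following is a minimal sketch and not the project's actual solution. The 2-D Gaussian standing in for real minority-class data, the layer sizes, `latent_dim`, the learning rates, and the helper names `sample_real` and `train_step` are all assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Assumption: a 2-D Gaussian stands in for the real minority-class data.
def sample_real(n):
    return torch.randn(n, 2) * 0.5 + torch.tensor([2.0, -1.0])

latent_dim = 4  # hypothetical size of the generator's noise input

generator = nn.Sequential(
    nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, 2)
)
# The discriminator outputs a raw logit; BCEWithLogitsLoss applies the sigmoid.
discriminator = nn.Sequential(
    nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1)
)

loss_fn = nn.BCEWithLogitsLoss()
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

def train_step(batch_size=64):
    real = sample_real(batch_size)
    fake = generator(torch.randn(batch_size, latent_dim))
    ones = torch.ones(batch_size, 1)
    zeros = torch.zeros(batch_size, 1)

    # Discriminator update: label real samples 1 and synthetic samples 0.
    # detach() stops these gradients from flowing back into the generator.
    d_loss = (loss_fn(discriminator(real), ones)
              + loss_fn(discriminator(fake.detach()), zeros))
    opt_d.zero_grad()
    d_loss.backward()  # automatic differentiation computes the gradients
    opt_d.step()

    # Generator update: push the discriminator to label synthetic samples 1.
    g_loss = loss_fn(discriminator(fake), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

    return d_loss.item(), g_loss.item()

losses = [train_step() for _ in range(200)]

# Draw synthetic samples that could augment the minority class.
with torch.no_grad():
    synthetic = generator(torch.randn(500, latent_dim))
print(synthetic.shape)  # torch.Size([500, 2])
```

The `losses` list holds one (discriminator loss, generator loss) pair per step, so it can be plotted with Matplotlib to monitor training. A short run like this one is not guaranteed to match the target distribution closely.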

Project Tasks


Getting Started

Task 0: Introduction

Task 1: Import the Libraries

Task 2: Create the Data


The Model

Task 3: Define Generators and Discriminators

Task 4: Define the Function for Discriminator Updates

Task 5: Perform Generator Updates


The Train Function

Task 6: Initialize the Parameters

Task 7: Compute the Losses

Task 8: Display Generated Distributions

Task 9: Display the Losses

Task 10: Create the Training Function