Generating Images from Labels with the CGAN
Explore training a Conditional GAN to generate digit images based on label inputs using PyTorch. Learn how to set up the generator and discriminator, configure training parameters, and evaluate results. This lesson guides you through the code structure and process to produce labeled image outputs from random noise.
We have already defined the architecture of both the generator and discriminator networks of the CGAN. Now, let's write the code for model training. To make the results easy to reproduce, we will use MNIST as the training set and see how the CGAN performs at image generation. Our goal is that, once trained, the model can generate the correct digit image on demand, with plenty of variety.
One-stop model training API
First, let’s create a new Model class that serves as a wrapper for different models and provides the one-stop training API. Create a new file named build_gan.py and import the necessary modules:
Then, let's create the Model class. In this class, we will initialize the Generator and Discriminator modules and provide train and eval methods, so that users can simply call Model.train() or Model.eval() elsewhere to run model training or evaluation.
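A minimal sketch of such a Model wrapper, under stated assumptions: the Generator and Discriminator below are lightweight MLP stand-ins that only mirror the interface of the real CGAN networks from the previous lesson (which embed labels into deeper convolutional stacks), and the optimizer settings are common GAN defaults, not values prescribed by the lesson:

```python
import torch
import torch.nn as nn
import torch.optim as optim


class Generator(nn.Module):
    """Stand-in for the CGAN generator: maps (noise, label) to an image."""

    def __init__(self, classes, channels, img_size, latent_dim):
        super().__init__()
        self.img_shape = (channels, img_size, img_size)
        self.label_emb = nn.Embedding(classes, classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + classes, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, channels * img_size * img_size),
            nn.Tanh(),  # outputs in [-1, 1], matching normalized MNIST
        )

    def forward(self, z, labels):
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.model(x).view(x.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Stand-in for the CGAN discriminator: scores (image, label) pairs."""

    def __init__(self, classes, channels, img_size):
        super().__init__()
        self.label_emb = nn.Embedding(classes, classes)
        self.model = nn.Sequential(
            nn.Linear(channels * img_size * img_size + classes, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        flat = img.view(img.size(0), -1)
        return self.model(torch.cat([flat, self.label_emb(labels)], dim=1))


class Model:
    """One-stop wrapper that owns both networks and their training loop."""

    def __init__(self, name, device, data_loader,
                 classes, channels, img_size, latent_dim):
        self.name = name
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.latent_dim = latent_dim
        self.netG = Generator(classes, channels, img_size, latent_dim).to(device)
        self.netD = Discriminator(classes, channels, img_size).to(device)
        self.loss = nn.BCELoss()
        self.optim_G = optim.Adam(self.netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
        self.optim_D = optim.Adam(self.netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

    def train(self, epochs=1):
        for _ in range(epochs):
            for data, target in self.data_loader:
                real = data.to(self.device)
                labels = target.to(self.device)
                bs = real.size(0)
                ones = torch.ones(bs, 1, device=self.device)
                zeros = torch.zeros(bs, 1, device=self.device)

                # Discriminator step: real images scored as 1, fakes as 0.
                self.optim_D.zero_grad()
                d_real = self.loss(self.netD(real, labels), ones)
                z = torch.randn(bs, self.latent_dim, device=self.device)
                gen_labels = torch.randint(self.classes, (bs,), device=self.device)
                fake = self.netG(z, gen_labels)
                d_fake = self.loss(self.netD(fake.detach(), gen_labels), zeros)
                (d_real + d_fake).backward()
                self.optim_D.step()

                # Generator step: try to get the fakes scored as real.
                self.optim_G.zero_grad()
                g_loss = self.loss(self.netD(fake, gen_labels), ones)
                g_loss.backward()
                self.optim_G.step()

    def eval(self, batch_size=10):
        """Generate one image per requested label, cycling through classes."""
        self.netG.eval()
        with torch.no_grad():
            z = torch.randn(batch_size, self.latent_dim, device=self.device)
            labels = torch.arange(batch_size, device=self.device) % self.classes
            return self.netG(z, labels)
```

Keeping both optimizers inside the wrapper is what makes the API "one-stop": a caller only ever touches Model.train() and Model.eval(), never the alternating update logic.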
Here, the generator network, netG, and the discriminator network, netD, are initialized based on the class number (classes), image channel (channels), image size (img_size), and length of the latent vector (latent_dim). These arguments will be given later. For now, let's assume that these values are already known. Since we ...
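These constructor arguments are typically collected from the command line in an entry script before being forwarded to the wrapper. A minimal sketch, where the flag names and defaults are assumptions (typical MNIST values), not values prescribed by the lesson:

```python
import argparse

# Hypothetical hyperparameter flags; MNIST has 10 digit classes and
# single-channel images, and 100 is a common latent-vector length.
parser = argparse.ArgumentParser(description="CGAN training options (sketch)")
parser.add_argument("--classes", type=int, default=10)
parser.add_argument("--channels", type=int, default=1)
parser.add_argument("--img_size", type=int, default=64)
parser.add_argument("--latent_dim", type=int, default=100)
args = parser.parse_args([])  # empty list -> use the defaults

# These would then be forwarded to the wrapper, e.g.:
# model = Model("cgan", device, dataloader,
#               args.classes, args.channels, args.img_size, args.latent_dim)
print(args.classes, args.channels, args.img_size, args.latent_dim)
```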