Training the Fashion MNIST GAN
Learn how to train the GAN on the Fashion MNIST dataset.
The code is similar to the code we used earlier to train the MNIST GAN.
We’re ready to train the GAN using the three-step training loop. Have a look at the following code:
main.py (the project also contains Generator.py, Discriminator.py, Dataset.py and the training data file fashion-mnist_train.csv)
```python
import torch
import matplotlib.pyplot as plt

from Dataset import FMnistDataset
from Discriminator import Discriminator
from Generator import Generator

# load data
fmnist_dataset = FMnistDataset('fashion-mnist_train.csv')

# functions to generate random data
def generate_random_image(size):
    # uniform values in [0, 1)
    random_data = torch.rand(size)
    return random_data

def generate_random_seed(size):
    # standard normal values, used as the generator's input seed
    random_data = torch.randn(size)
    return random_data

# create Discriminator and Generator
D = Discriminator()
G = Generator()

epochs = 4

for epoch in range(epochs):
    print("epoch = ", epoch + 1)

    # train Discriminator and Generator
    for label, image_data_tensor, target_tensor in fmnist_dataset:
        # train discriminator on real data (target 1.0)
        D.train(image_data_tensor, torch.FloatTensor([1.0]))

        # train discriminator on generated data (target 0.0)
        # use detach() so gradients in G are not calculated
        D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))

        # train generator to make the discriminator output 1.0
        G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))

# plot several outputs from the trained generator:
# a 3 column, 2 row array of generated images
f, axarr = plt.subplots(2, 3, figsize=(16, 8))
for i in range(2):
    for j in range(3):
        output = G.forward(generate_random_seed(100))
        img = output.detach().numpy().reshape(28, 28)
        axarr[i, j].imshow(img, interpolation='none', cmap='Blues')

plt.show()
```
Discriminator and generator objects
We first create fresh discriminator and generator objects, then run the training loop for 4 epochs. Each epoch is one complete pass over every image in the Fashion MNIST training dataset.
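With epochs = 4, the inner loop visits every image in fmnist_dataset four times. Each visit makes two discriminator training calls and one generator training call. The hypothetical helper count_updates below only does that arithmetic. The figure of 60,000 images is an assumption: it is the size of the standard Fashion-MNIST training split, which the CSV is assumed to hold.

```python
def count_updates(num_images, epochs):
    """Hypothetical helper: count the training calls made by the three-step loop."""
    # each image gives one discriminator step on real data
    # and one discriminator step on generated data ...
    d_updates = 2 * num_images * epochs
    # ... plus one generator step
    g_updates = num_images * epochs
    return d_updates, g_updates


# assumes 60,000 training images, as in the standard Fashion-MNIST split
d_steps, g_steps = count_updates(num_images=60000, epochs=4)
print(d_steps, g_steps)  # 480000 240000
```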
Training loop
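Inside the loop are the three steps of the GAN training loop we discussed earlier:

1. The discriminator is trained on a real image with target 1.0.
2. The discriminator is trained on an image from the generator with target 0.0. This step uses detach(), so no gradients are calculated for the generator.
3. The generator is trained, through the discriminator, with target 1.0.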
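To see the three steps in isolation, here is a minimal, dependency-free sketch. It does not use the course's Discriminator or Generator classes. Instead, it swaps in two hypothetical one-dimensional toy models:

- a discriminator D(x) = sigmoid(w*x + b);
- a generator G(z) = a*z + c.

The sketch assumes a binary cross-entropy loss with hand-derived gradients, and it uses made-up "real" data drawn near 3.0. What it shows is the update schedule:

- The discriminator trains on a real sample with target 1.0.
- The discriminator then trains on a generated sample with target 0.0, without changing the generator.
- Finally, the generator trains through the discriminator with target 1.0, and only the generator's parameters change.

```python
import math
import random


def sigmoid(s):
    # numerically stable logistic function
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


class ToyDiscriminator:
    """Hypothetical 1-D discriminator: D(x) = sigmoid(w*x + b)."""

    def __init__(self):
        self.w, self.b = 0.1, 0.0

    def forward(self, x):
        return sigmoid(self.w * x + self.b)

    def train(self, x, target, lr=0.05):
        # for binary cross-entropy, dLoss/d(w*x + b) = D(x) - target
        err = self.forward(x) - target
        self.w -= lr * err * x
        self.b -= lr * err


class ToyGenerator:
    """Hypothetical 1-D generator: G(z) = a*z + c."""

    def __init__(self):
        self.a, self.c = 1.0, 0.0

    def forward(self, z):
        return self.a * z + self.c

    def train(self, D, z, target, lr=0.05):
        # backpropagate through D, but update only the generator's parameters
        err = D.forward(self.forward(z)) - target  # dLoss/d(pre-activation)
        dx = err * D.w                             # dLoss/dG(z)
        self.a -= lr * dx * z
        self.c -= lr * dx


random.seed(0)
D, G = ToyDiscriminator(), ToyGenerator()
real_data = [random.gauss(3.0, 0.1) for _ in range(500)]  # hypothetical "real" samples

for epoch in range(4):
    for x_real in real_data:
        # step 1: train the discriminator on real data, target 1.0
        D.train(x_real, 1.0)
        # step 2: train the discriminator on generated data, target 0.0;
        # G.forward returns a plain float, so this step cannot change G
        # (the job detach() does in the PyTorch version)
        D.train(G.forward(random.gauss(0.0, 1.0)), 0.0)
        # step 3: train the generator to make D output 1.0
        G.train(D, random.gauss(0.0, 1.0), 1.0)

# c is the mean of G(z) because z has mean 0
print("mean of generated samples:", G.c)
```

In the PyTorch version, the separation between steps 2 and 3 comes from detach() and from each network's own optimiser. Here it is explicit: ToyDiscriminator.train only changes w and b, and ToyGenerator.train only reads D.w while changing a and c.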