...


Training The Fashion MNIST GAN

Learn how to train the Fashion MNIST GAN.

The code is similar to the code we used earlier to train the MNIST GAN.

Training the Fashion MNIST GAN

We’re ready to train the GAN using the three-step training loop. Have a look at the following code:

The project is split across main.py, Generator.py, Discriminator.py, Dataset.py, and the data file fashion-mnist_train.csv. The training code lives in main.py:
```python
import torch
import matplotlib.pyplot as plt
from Dataset import FMnistDataset
from Discriminator import Discriminator
from Generator import Generator

# load data
fmnist_dataset = FMnistDataset('fashion-mnist_train.csv')

# functions to generate random data
def generate_random_image(size):
    random_data = torch.rand(size)
    return random_data

def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data

# create Discriminator and Generator
D = Discriminator()
G = Generator()

epochs = 4

for epoch in range(epochs):
    print("epoch = ", epoch + 1)

    # train Discriminator and Generator
    for label, image_data_tensor, target_tensor in fmnist_dataset:
        # train discriminator on true
        D.train(image_data_tensor, torch.FloatTensor([1.0]))

        # train discriminator on false
        # use detach() so gradients in G are not calculated
        D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))

        # train generator
        G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))
        pass

    pass

# plot several outputs from the trained generator
# plot a 3 column, 2 row array of generated images
f, axarr = plt.subplots(2, 3, figsize=(16, 8))
for i in range(2):
    for j in range(3):
        output = G.forward(generate_random_seed(100))
        img = output.detach().numpy().reshape(28, 28)
        axarr[i, j].imshow(img, interpolation='none', cmap='Blues')
        pass
    pass

# display the figure when main.py is run as a standalone script
plt.show()
```
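In the plotting code, each generator output holds 784 values, and `.reshape(28, 28)` lays them out as a 28×28 image, filling one row at a time (NumPy's default row-major order). The helper below is a hypothetical plain-Python sketch of that layout step, not part of the lesson's files, so you can see which flat index ends up at which pixel:

```python
def reshape_row_major(values, rows, cols):
    """Hypothetical helper: split a flat sequence into `rows` lists of `cols` items, row by row."""
    if len(values) != rows * cols:
        raise ValueError(f"expected {rows * cols} values, got {len(values)}")
    return [list(values[r * cols:(r + 1) * cols]) for r in range(rows)]


# Stand-in for one generator output: 784 values (here simply 0..783).
flat_output = list(range(28 * 28))
grid = reshape_row_major(flat_output, 28, 28)
print(len(grid), len(grid[0]))  # 28 28
print(grid[1][:3])              # [28, 29, 30]: row 1 starts at flat index 28
```

Because the values are filled row by row, neighbouring flat indices become horizontally adjacent pixels, which is why a single `reshape` is enough to turn the generator's flat output into a picture `imshow` can draw.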

Discriminator and generator objects

We first create fresh discriminator and generator objects, and then run the training loop for 4 epochs. In each epoch, the inner loop visits every image in the Fashion MNIST training data set once.
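The amount of training is set by two numbers rather than a fixed iteration count: the outer loop runs once per epoch, and the inner loop runs once per image, with every inner iteration calling `D.train()` twice and `G.train()` once. The small helper below is a hypothetical sketch, not part of the lesson's files, that does this arithmetic; the data set size passed to it is made up for illustration:

```python
def count_train_calls(num_images: int, epochs: int) -> dict:
    """Hypothetical helper: count the train() calls made by the nested loops in main.py.

    Each image produces two D.train() calls (real image, then generated image)
    and one G.train() call.
    """
    iterations = num_images * epochs
    return {
        "inner_loop_iterations": iterations,
        "D.train_calls": 2 * iterations,
        "G.train_calls": iterations,
    }


# num_images=1000 is an illustrative value only; substitute the number of
# images in your own training data.
print(count_train_calls(num_images=1000, epochs=4))
# {'inner_loop_iterations': 4000, 'D.train_calls': 8000, 'G.train_calls': 4000}
```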

Training loop

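Inside the loop, we can see the three steps of the GAN training loop we discussed earlier: train the discriminator on a real image with target 1.0, train it on a generated image with target 0.0, and then train the generator with target 1.0.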
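To see what each of the three steps does to the parameters, here is a hypothetical one-dimensional toy GAN written in plain Python. It is not the lesson's `Discriminator` or `Generator` class: it assumes a single-parameter-pair generator `x = a*z + b`, a logistic discriminator `sigmoid(w*x + c)`, a binary cross-entropy loss, and plain gradient descent with hand-derived gradients (the lesson's classes may use a different loss and optimiser). The point is the structure: steps 1 and 2 change only the discriminator, and step 3 changes only the generator, even though its gradient passes through the discriminator.

```python
import math
import random


def sigmoid(t: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


class ToyGAN:
    """Hypothetical 1-D GAN: generator x = a*z + b, discriminator D(x) = sigmoid(w*x + c)."""

    def __init__(self, lr: float = 0.01):
        self.a, self.b = 1.0, 0.0  # generator parameters
        self.w, self.c = 0.1, 0.0  # discriminator parameters
        self.lr = lr

    def generate(self, z: float) -> float:
        return self.a * z + self.b

    def discriminate(self, x: float) -> float:
        return sigmoid(self.w * x + self.c)

    def train_discriminator(self, x: float, target: float) -> None:
        # For BCE after a sigmoid, dLoss/dlogit = D(x) - target.
        # Only the discriminator's parameters change here.
        grad = self.discriminate(x) - target
        self.w -= self.lr * grad * x
        self.c -= self.lr * grad

    def train_generator(self, z: float) -> None:
        # Target 1.0: push D to label the generated value as real.
        # The gradient flows back through D (dlogit/dx = w) into a and b,
        # but D's own parameters are left unchanged.
        x = self.generate(z)
        grad_x = (self.discriminate(x) - 1.0) * self.w
        self.a -= self.lr * grad_x * z
        self.b -= self.lr * grad_x


def train_step(gan: ToyGAN, real_x: float) -> None:
    # Step 1: train the discriminator on a real sample, target 1.0.
    gan.train_discriminator(real_x, 1.0)
    # Step 2: train the discriminator on a generated sample, target 0.0.
    # Passing a plain float means no update reaches the generator,
    # which is the role .detach() plays in main.py.
    gan.train_discriminator(gan.generate(random.gauss(0.0, 1.0)), 0.0)
    # Step 3: train the generator with a fresh random seed, target 1.0.
    gan.train_generator(random.gauss(0.0, 1.0))


random.seed(0)
gan = ToyGAN()
for _ in range(1000):
    # "Real" data for the toy: samples from a normal distribution centred on 4.
    train_step(gan, random.gauss(4.0, 0.5))
print(f"generator offset b = {gan.b:.3f}")
```

As in `main.py`, steps 2 and 3 each draw their own random seed. A linear toy like this tends to oscillate rather than settle, so treat it as an illustration of which parameters each step updates, not as a model of how well the Fashion MNIST GAN trains.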

...