VAEs in Action

Explore how variational autoencoders (VAEs) work by encoding MNIST digit images into a latent space and generating new images from this space. Understand the architecture, training process, and the balance between reconstruction loss and KL divergence essential for successful image generation.

In this lesson, we’ll delve into VAEs and their application to the MNIST dataset. Like traditional autoencoders, VAEs consist of an encoder and a decoder, but they go beyond simple image reconstruction. VAEs allow us to generate new data points by sampling from the learned latent space. Our focus will be on understanding the fascinating concept of generating new images from latent representations.

VAE generated MNIST

Loading and preprocessing the data

In this step, we load the MNIST dataset and prepare it for training the VAE. The dataset is converted to floating-point format, and the pixel values are normalized to the range [0, 1]. Additionally, we perform a visual inspection to understand the data distribution and validate the preprocessing. This ensures our VAE model can work effectively with the data and facilitates any necessary adjustments in the preprocessing pipeline.

Unique training approach in VAEs

In this lesson, we do not split the data into training and test sets. Unlike classification tasks, where we need to evaluate the model’s performance on unseen data, our goal here is to learn a latent representation of the dataset and generate new samples from the learned latent space, so every image in the MNIST training set is used for training. (A held-out set can still be useful for monitoring overfitting, but it is not needed for what follows.)

Python 3.10.4
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as nnF
# Load MNIST training dataset
mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=True)
# Extract images and labels, normalize pixel values to [0,1]
x_train = mnist_train.data
y_train = mnist_train.targets
x_train = x_train.float() / 255.0
# Visualization setup
import matplotlib.pyplot as plt
import numpy as np
num_rows = 2
num_columns = 5
fig, axes = plt.subplots(num_rows, num_columns, figsize=(4, 2))
# Plot one random image per digit (0-9)
for category in range(10):
    category_indices = np.where(y_train == category)[0]  # Get indices of images for this digit
    random_index = np.random.choice(category_indices)  # Pick a random sample
    image_np = np.array(x_train[random_index])  # Convert tensor to NumPy array
    row = category // num_columns
    col = category % num_columns
    ax = axes[row, col]
    ax.imshow(image_np, cmap='gray')  # Display image in grayscale
    ax.axis('off')  # Hide axes for clarity
plt.tight_layout()
plt.show()
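
Because every image is used for training, the normalized tensor could be batched directly. The following is a minimal sketch, not the lesson's own code: the helper name make_loader, the batch size of 128, and the flattening to 784 values per image are assumptions, and a random tensor stands in for x_train so the snippet runs without downloading MNIST.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(images, batch_size=128):
    """Flatten (N, 28, 28) images to (N, 784) and wrap them in a shuffled DataLoader.

    Hypothetical helper: the name and the default batch size are assumptions.
    """
    flat = images.reshape(images.size(0), -1)
    return DataLoader(TensorDataset(flat), batch_size=batch_size, shuffle=True)


# Stand-in for x_train: random values in [0, 1] with MNIST's image shape.
fake_images = torch.rand(1000, 28, 28)
loader = make_loader(fake_images)

# TensorDataset yields one-element tuples, so unpack the single tensor.
(first_batch,) = next(iter(loader))
print(first_batch.shape)  # torch.Size([128, 784])
```

With the real data, make_loader(x_train) would be the equivalent call. No labels are passed because the VAE only needs the images.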

Defining the architecture

The autoencoder’s architecture has three main components.

Encoder

The encoder class is the part of the VAE architecture responsible for compressing the input MNIST image into a lower-dimensional latent space. It consists of three linear layers: self.input, self.hidden_mean, and self.hidden_std. The encoder processes the input image through these layers and outputs the mean (mean) and standard deviation (std) vectors of the latent Gaussian distribution. (A latent Gaussian distribution models the compressed data as a normal distribution, allowing smooth sampling and meaningful interpolation in latent space.) These vectors define a probability ...
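
The encoder described above might look roughly like the sketch below. Only the three layer names (self.input, self.hidden_mean, self.hidden_std) and the mean and std outputs come from the lesson; the layer sizes, the ReLU activation, and the use of softplus to keep std positive are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as nnF


class Encoder(nn.Module):
    """Maps a 28x28 image to the mean and std of a latent Gaussian."""

    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=2):
        super().__init__()
        # Layer names follow the lesson; the dimensions are assumed.
        self.input = nn.Linear(input_dim, hidden_dim)
        self.hidden_mean = nn.Linear(hidden_dim, latent_dim)
        self.hidden_std = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)  # flatten (batch, 28, 28) -> (batch, 784)
        h = nnF.relu(self.input(x))   # shared hidden representation
        mean = self.hidden_mean(h)
        # softplus keeps std strictly positive (an assumed choice)
        std = nnF.softplus(self.hidden_std(h)) + 1e-6
        return mean, std


encoder = Encoder()
batch = torch.rand(8, 28, 28)  # stand-in for normalized MNIST images
mean, std = encoder(batch)

# Draw one latent sample per image: z = mean + std * noise
z = mean + std * torch.randn_like(std)
print(mean.shape, std.shape, z.shape)
```

Writing the sample as mean + std * noise keeps the random draw differentiable with respect to the encoder's outputs, which is the usual way a VAE is trained by backpropagation.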