Image reconstruction with autoencoders

This hands-on Answer shows how autoencoders compress data and then rebuild it. The idea is to turn complex data into a smaller representation and then recover as much of the original detail as possible from it. Working through a concrete example makes it easier to see how autoencoders work and what they can do. We’ll use the MNIST dataset of handwritten digits to practice compressing and rebuilding images.

In image reconstruction, autoencoders are trained to compress images (encode) into a smaller representation (latent space) and then reconstruct the original images from this compressed form (decode).

How image reconstruction with autoencoders works

  1. Encoder: The encoder network compresses the input image into a low-dimensional representation. This can be seen as extracting the most important features from the image.

  2. Latent space: This is the compressed version of the input image, representing its most important features in fewer dimensions.

  3. Decoder: The decoder takes the compressed representation and reconstructs the original image from it. The goal is for the output to be as close to the original image as possible.
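To make the roles of the encoder, latent space, and decoder concrete, here is a toy sketch in plain Python. It is hand-crafted rather than learned, and `encode` and `decode` are hypothetical names. The "encoder" averages neighboring values to halve the length of a 1-D signal, and the "decoder" repeats each latent value to restore the original length. A trained autoencoder replaces these fixed rules with learned layers, but the pipeline is the same.

```python
# Toy, hand-crafted "autoencoder" (nothing is learned here) that mirrors the
# encoder -> latent space -> decoder pipeline on a short 1-D signal.

def encode(x):
    """Compress by averaging each adjacent pair: length n -> n // 2."""
    return [(x[i] + x[i + 1]) / 2 for i in range(0, len(x) - 1, 2)]


def decode(z):
    """Reconstruct by repeating each latent value twice: length m -> 2 * m."""
    out = []
    for v in z:
        out.extend([v, v])
    return out


signal = [0.0, 0.2, 0.8, 1.0, 1.0, 0.9, 0.1, 0.0]
latent = encode(signal)          # 4 numbers instead of 8 (the "latent space")
reconstruction = decode(latent)  # back to 8 numbers, with some detail lost
print("latent:        ", latent)
print("reconstruction:", reconstruction)
```

The reconstruction has the right length but loses some detail. For example, the pair 0.0 and 0.2 comes back as 0.1 and 0.1. Training is what lets a real autoencoder decide which details are worth keeping in the latent space.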

Autoencoder for data compression and reconstruction

Steps for image reconstruction using autoencoders

  1. Data preprocessing: Prepare and normalize your image dataset (e.g., MNIST or CIFAR-10).

  2. Model definition: Define the autoencoder model architecture, which includes the encoder and decoder.

  3. Training: Train the autoencoder using a loss function like Mean Squared Error (MSE) that compares the original image with the reconstructed image.

  4. Reconstruction: Use the trained model to reconstruct images from the test set.
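The MSE loss in step 3 averages the squared pixel differences between the original and the reconstructed image. The plain-Python sketch below works on flattened pixel lists and shows that arithmetic. `mse` is a hypothetical helper, not a PyTorch function. `nn.MSELoss()` with its default `reduction='mean'` computes the same average over every element of its input tensors.

```python
def mse(original, reconstructed):
    """Mean squared error between two flattened images (lists of pixel values)."""
    if len(original) != len(reconstructed) or not original:
        raise ValueError("inputs must be non-empty and the same length")
    squared_diffs = [(o - r) ** 2 for o, r in zip(original, reconstructed)]
    return sum(squared_diffs) / len(squared_diffs)


original = [0.0, 0.5, 1.0, 1.0]
close_reconstruction = [0.1, 0.5, 0.9, 1.0]  # small errors on two pixels
blurry_reconstruction = [0.5, 0.5, 0.5, 0.5]  # everything collapsed to gray

print(mse(original, close_reconstruction))   # small loss
print(mse(original, blurry_reconstruction))  # much larger loss
```

A lower value means the reconstruction is closer to the original. This is the quantity the optimizer drives down during training.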

Coding example

Here’s an example of how to implement an autoencoder for image reconstruction using PyTorch:

Import libraries

We import the libraries needed to build, train, and visualize a PyTorch autoencoder that reconstructs images from the MNIST dataset.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
Importing necessary libraries

In the above code, torch provides tensor computations, torch.nn provides neural network layers and loss functions, and torch.optim provides optimization algorithms. torchvision supplies the MNIST dataset and image transformations, and matplotlib displays the original and reconstructed images side by side.

Define the autoencoder architecture

We are building an autoencoder model using convolutional layers to compress (encode) and reconstruct (decode) images.

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=7)
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Output range should be between 0 and 1
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
Defining the autoencoder architecture

In the above code:

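  • Lines 4-11: self.encoder stacks the encoder layers with nn.Sequential. Two stride-2 convolutions followed by a 7x7 convolution shrink a 1x28x28 MNIST image to a 64x1x1 latent representation (28, then 14, then 7, then 1 pixel per side), with ReLU activations between the layers.

  • Lines 12-20: self.decoder mirrors the encoder. Its transposed convolutions upsample the 64x1x1 representation back to a 1x28x28 image, and the final nn.Sigmoid() keeps every output pixel between 0 and 1, the same range as the input pixels.

  • Lines 22-25: The forward(self, x) method defines how data passes through the model. The input goes through the encoder and then the decoder, and the reconstruction is returned.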
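The kernel sizes, strides, and paddings above are chosen so that a 28x28 image shrinks to 1x1 and then grows back to exactly 28x28. One way to check this without running PyTorch is to apply the output-size formulas from the torch.nn.Conv2d and torch.nn.ConvTranspose2d documentation. This is a sketch; the helper names `conv2d_out` and `conv_transpose2d_out` are hypothetical.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    # Output-size formula from the torch.nn.Conv2d documentation (floor division).
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    # Output-size formula from the torch.nn.ConvTranspose2d documentation.
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# Encoder layers as (kernel_size, stride, padding), matching the Autoencoder class.
size = 28
encoder_sizes = [size]
for k, s, p in [(3, 2, 1), (3, 2, 1), (7, 1, 0)]:
    size = conv2d_out(size, k, s, p)
    encoder_sizes.append(size)

# Decoder layers as (kernel_size, stride, padding, output_padding).
decoder_sizes = [size]
for k, s, p, op in [(7, 1, 0, 0), (3, 2, 1, 1), (3, 2, 1, 1)]:
    size = conv_transpose2d_out(size, k, s, p, op)
    decoder_sizes.append(size)

input_values = 1 * 28 * 28  # channels * height * width of one MNIST image
latent_values = 64 * encoder_sizes[-1] * encoder_sizes[-1]  # 64 channels at 1x1

print("encoder sides:", encoder_sizes)
print("decoder sides:", decoder_sizes)
print(f"compression: {input_values} -> {latent_values} values")
```

If you change a kernel size or stride, rerunning this trace is a quick way to confirm that the decoder still returns 28x28 outputs. Otherwise, the output and input shapes would no longer match in the loss computation.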

Load and preprocess the data

For this example, we'll use the MNIST dataset (grayscale handwritten digits).

transform = transforms.Compose([transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
Loading and preprocessing the data

In the above code:

  • Line 1: Define the transformation. transforms.ToTensor() converts each MNIST image into a 1x28x28 float tensor with pixel values scaled to [0, 1].

  • Lines 3-4: Load the MNIST training set (downloading it to ./data if needed) and wrap it in a DataLoader that shuffles the data and yields batches of 64 images.

  • Lines 6-7: Load the test set and create a DataLoader for it without shuffling, so evaluation always sees the images in the same order.
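Two details of this step are easy to check by hand. First, transforms.ToTensor() scales 8-bit pixel values from 0..255 to floats in [0.0, 1.0], which is the same range the decoder's final Sigmoid produces. Second, a DataLoader with the default drop_last=False keeps a smaller final batch. The plain-Python helpers below mirror that arithmetic for MNIST's 60,000 training and 10,000 test images. `to_unit_range` and `num_batches` are hypothetical names, not torchvision or PyTorch functions.

```python
import math


def to_unit_range(pixels):
    # Mirrors the scaling ToTensor() applies to 8-bit images: 0..255 -> 0.0..1.0.
    return [p / 255.0 for p in pixels]


def num_batches(num_samples, batch_size, drop_last=False):
    # With drop_last=False (DataLoader's default), a final partial batch is kept.
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(to_unit_range([0, 128, 255]))
print(num_batches(60_000, 64))  # batches per training epoch
print(num_batches(10_000, 64))  # batches in one pass over the test set
```

With batch_size=64, every training epoch therefore ends with a partial batch of 60,000 - 937 * 64 = 32 images.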

Train the autoencoder

We initialize the autoencoder, set up the loss function and optimizer, train the model for 10 epochs on the training data, and print a loss value at the end of each epoch.

model = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    for data in train_loader:
        img, _ = data
        # Forward pass
        output = model(img)
        loss = criterion(output, img)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
Training the autoencoder

In the above code:

  • Lines 1-3: Create an Autoencoder model, define the mean squared error loss (nn.MSELoss()), and set up the Adam optimizer with a learning rate of 0.001.

  • Line 5: Set the number of training epochs to 10.

  • Lines 6-16: During each epoch, iterate over batches of training data. The digit labels are discarded (img, _ = data) because the target is the input image itself. For each batch, a forward pass produces the reconstruction, the loss compares it with the original image, and a backward pass computes the gradients. optimizer.zero_grad() clears the gradients left over from the previous batch, and optimizer.step() updates the model’s parameters to reduce the loss.

  • Line 18: After all batches in an epoch are processed, print the epoch number and the loss of the last batch in that epoch. This gives a rough view of training progress.
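The value printed at the end of each epoch is the loss of the last batch only, so it can fluctuate from epoch to epoch. A common alternative is to report the average loss over the whole epoch. Because nn.MSELoss() returns a per-batch mean, each batch should be weighted by its size. Inside the PyTorch loop, this means accumulating loss.item() * img.size(0) and dividing by the number of training images. The sketch below shows only the arithmetic, in plain Python. The helper `epoch_average` and the loss values are hypothetical.

```python
def epoch_average(batch_losses, batch_sizes):
    """Sample-weighted mean of per-batch mean losses (the last batch may be smaller)."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical per-batch losses from one short epoch: three full batches of 64
# images followed by a final partial batch of 32 images.
losses = [0.060, 0.050, 0.040, 0.030]
sizes = [64, 64, 64, 32]

print(f"last-batch loss: {losses[-1]:.4f}")
print(f"epoch average:   {epoch_average(losses, sizes):.4f}")
```

When all batches are the same size, the weighted mean reduces to the plain mean of the batch losses.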

Test and visualize the reconstruction

We display the original and reconstructed images side by side to evaluate the autoencoder’s performance visually.

def show_images(original, reconstructed):
    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(original.squeeze().numpy(), cmap='gray')
    axes[0].set_title('Original Image')

    axes[1].imshow(reconstructed.squeeze().detach().numpy(), cmap='gray')
    axes[1].set_title('Reconstructed Image')
    plt.show()

model.eval()
with torch.no_grad():
    for data in test_loader:
        img, _ = data
        reconstructed = model(img)
        show_images(img[0], reconstructed[0])
        break
Testing and visualizing the reconstruction

In the above code:

  • Lines 1-8: The show_images function takes two images, original and reconstructed, and uses Matplotlib to display them side by side in one figure with two subplots (line 2).

    • Lines 3 and 6: The imshow function renders each image in grayscale (cmap='gray'). squeeze() drops the single channel dimension, so each 1x28x28 tensor becomes a 28x28 array that imshow can display.

    • Lines 4 and 7: Each subplot gets a title, "Original Image" or "Reconstructed Image".

  • Line 10: After defining show_images, the model is switched to evaluation mode with model.eval().

  • Lines 11-16: Inside a torch.no_grad() block, which disables gradient tracking, the code takes the first batch from the test loader, passes it through the autoencoder, and calls show_images on the first image in the batch and its reconstruction. The break statement stops the loop after that first batch, so only one pair of images is displayed.

Implementation

Running the code blocks above in order, for example as cells in a Jupyter notebook, downloads MNIST, trains the autoencoder for 10 epochs, and displays the first test image next to its reconstruction.

Conclusion

In image reconstruction with an autoencoder, the encoder compresses each original image into a latent representation, and the decoder reconstructs the image from that representation. Training minimizes the reconstruction error between each reconstructed image and its original (here, the MSE). This pushes the encoder to keep the information the decoder needs to rebuild the image.


Copyright ©2024 Educative, Inc. All rights reserved