
Introduction to Masked Image Modeling

Explore masked image modeling to understand how self-supervised learning predicts missing parts of images. Learn about masking strategies, encoder design using vision transformers, multi-head self-attention, and how these components work together for image reconstruction and representation learning.

What is masked image modeling?

Inspired by masked language modeling in NLP, masked image modeling (MIM) is a self-supervised learning paradigm in which a portion of the input image is masked and the model learns to predict the masked signal from the visible portion. This approach provides results competitive with approaches like contrastive learning.

Masked image modeling

Applying masked image modeling can be challenging for several reasons:

  • Pixels close to each other are highly correlated. As a result, a masked region can often be reconstructed well enough simply by duplicating nearby pixels. This leads to trivial solutions and inefficient learning.

  • Signals at the pixel level are very raw and contain mostly low-level information.

  • Signals in image data are also continuous, unlike text data, where tokens are discrete.

Thus, the masking strategy and prediction target in masked image modeling must be designed carefully so that the model cannot exploit local pixel correlations and collapse to trivial solutions.

The framework of masked image modeling

Masked image modeling aims to predict the original signals from a masked input. As illustrated below, the framework involves the following components:

  • Masking strategy: The masking strategy determines which area of the image to mask and how the masking is performed. Usually, masking is done at the image patch level rather than the pixel level. We can use various strategies for image masking, like square masking, random patch masking, etc. (a minimal code sketch of random patch masking appears at the end of this section). The masked image is used as the input to the neural network.

  • Encoder: This component is a neural network that takes a masked image as input and extracts useful latent representations for predicting the original signals in the masked areas. Generally, transformer models like the Vision Transformer and the Swin Transformer (discussed subsequently) are used as encoder architectures.

  • Prediction head: Given the encoder features as input, this component reconstructs the original signals in the masked regions of the input.

  • Prediction target: This component calculates the loss function on the prediction head's output. The loss can be a cross-entropy classification loss or an $L_1$/$L_2$ pixel-regression loss. Pixel regression means we predict the pixel values of the masked regions of the input image.
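For example, an $L_2$ pixel-regression loss computed only over the masked regions takes the form:

$$\mathcal{L} = \left\| y - \hat{y} \right\|_2^2$$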

Here, $\hat{y}$ is the model prediction and $y$ is the target.
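To make the masking strategy concrete, here is a minimal sketch of random patch masking in PyTorch. The helper random_patch_mask and its default mask_ratio are hypothetical choices for illustration, not taken from any specific MIM implementation:

import torch

def random_patch_mask(images, patch_size=16, mask_ratio=0.6):
    # images: (B, C, H, W) tensor; returns the masked images and a boolean
    # patch mask of shape (B, P), where True marks a masked patch
    B, C, H, W = images.shape
    rows, cols = H // patch_size, W // patch_size
    num_patches = rows * cols
    num_masked = int(mask_ratio * num_patches)
    # choose `num_masked` patch indices per image uniformly at random
    idx = torch.rand(B, num_patches).argsort(dim=1)[:, :num_masked]
    mask = torch.zeros(B, num_patches, dtype=torch.bool)
    mask.scatter_(1, idx, True)
    # zero out the pixels of the selected patches
    masked = images.clone()
    for b in range(B):
        for p in idx[b].tolist():
            r, c = divmod(p, cols)
            masked[b, :, r * patch_size:(r + 1) * patch_size,
                   c * patch_size:(c + 1) * patch_size] = 0.0
    return masked, mask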

Overview of vision transformers

Most approaches in masked image modeling use masking strategies that operate at the image patch level. Instead of masking an image at each pixel, they mask $N \times N$-sized patches. For this reason, these approaches prefer vision transformers (which also operate at the patch level) over convolutional neural networks as their primary network architectures. So, here is an overview of vision transformers to better understand the concepts of masked image modeling.

Patch embeddings

The first step is to represent the input image $X_i \in [0,1]^{H \times W \times C}$ (height $H$, width $W$, and channels $C$) as a sequence of $N \times N$-sized non-overlapping patches $X_i^{\text{patched}} = [X_i^1, X_i^2, \dots, X_i^P]$. Here, $P = \frac{H \times W}{N^2}$ is the total number of patches, and $X_i^p \in [0,1]^{N \times N \times C}$ represents the $p^{\text{th}}$ patch.

Image patches

The next step is to project each of these patches ($\in [0,1]^{N \times N \times C}$) in the sequence into $d$-dimensional patch embeddings. In other words:
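$$e_i^p = \text{PatchEmbed}\left(X_i^p\right), \qquad p = 1, 2, \dots, P$$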

Here, $\text{PatchEmbed}(\cdot)$ linearly projects a flattened patch $X_i^p \in [0,1]^{N^2 C}$ (after flattening, the shape of the patch changes from $N \times N \times C$ to $N^2 C$) to its patch embedding $e_i^p \in \R^d$.

[CLS] token

After obtaining the patch embeddings $\mathcal{E}_i \in \R^{P \times d}$, we add a special classification ([CLS]) token $e^{\text{[CLS]}} \in \R^d$ to the patch embedding sequence. The [CLS] token aims to capture and summarize the information present in all patch embeddings in a single $d$-dimensional representation. This happens in the multi-head self-attention blocks (discussed later). The final representation of the [CLS] token (i.e., after the multi-head self-attention blocks) is passed through a linear layer for classification. The initial value of this special token is a parameter of the model that needs to be learned. The input $\mathcal{E}_i \in \R^{(P+1) \times d}$ is now:
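$$\mathcal{E}_i = \left[\, e^{\text{[CLS]}},\; e_i^1,\; e_i^2,\; \dots,\; e_i^P \,\right]$$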

Note: The size of the input $\mathcal{E}_i$ increases by one (i.e., $(P+1) \times d$) after adding the [CLS] token $e^{\text{[CLS]}}$.

Positional embeddings

Next, a positional encoding $\mathcal{S} = [s^0, s^1, \dots, s^P]$ ($s^0$ is for the [CLS] token and $s^i \in \R^d$) is added to the input sequence, $\mathcal{E}_i$, which allows the model to understand where each patch, $X_i^p$, is placed in the original image. The positional embedding, $\mathcal{S} \in \R^{(P+1) \times d}$, is also a learnable parameter updated along with the [CLS] token during training. The final input sequence is written as:
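$$\mathcal{E}_i = \mathcal{E}_i + \mathcal{S} = \left[\, e^{\text{[CLS]}} + s^0,\; e_i^1 + s^1,\; \dots,\; e_i^P + s^P \,\right]$$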

Here, $+$ represents vector addition.

Self-attention

Self-attention heads are the building blocks of vision transformers. Self-attention allows a global lookup: the model can use the information present in the embeddings at all other positions of the input sequence. Through this global lookup, a self-attention-based model can focus on the whole image at once and summarize it better. This is not possible in convolutional models, as their convolution operations only act on a local part of the input image at a time.

A self-attention head first projects the input sequence, $\mathcal{E}_i \in \R^{(P+1) \times d}$, into query $\mathcal{E}^q_i$, key $\mathcal{E}^k_i$, and value $\mathcal{E}^v_i$ vector sequences as follows:
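$$\mathcal{E}_i^q = \mathcal{E}_i W_q, \qquad \mathcal{E}_i^k = \mathcal{E}_i W_k, \qquad \mathcal{E}_i^v = \mathcal{E}_i W_v$$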

Here, the matrices $W_q, W_k, W_v \in \R^{d \times l}$ are learnable network parameters. Note that since $W_q \in \R^{d \times l}$ and $\mathcal{E}_i \in \R^{(P+1) \times d}$, their matrix product $\mathcal{E}_i^q = \mathcal{E}_i W_q$ will be in $\R^{(P+1) \times l}$. The same applies to $\mathcal{E}_i^k$ and $\mathcal{E}_i^v$.

Note: Think of these query, key, and value vectors as intermediate outputs used in the computation of the final output.

Next, we compute self-attention scores $\mathcal{A}_i \in [0,1]^{(P+1) \times (P+1)}$ and multiply them with the value vectors $\mathcal{E}^v_i$ to return a sequence of self-attended outputs $\mathcal{E}^o_i \in \R^{(P+1) \times l}$ as follows (using the standard scaled dot-product formulation):
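$$\mathcal{A}_i = \text{softmax}\!\left(\frac{\mathcal{E}_i^q \left(\mathcal{E}_i^k\right)^T}{\sqrt{l}}\right), \qquad \mathcal{E}_i^o = \mathcal{A}_i \, \mathcal{E}_i^v$$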

Note that since both $\mathcal{E}_i^k, \mathcal{E}_i^q \in \R^{(P+1) \times l}$, the matrix product $\mathcal{E}^q_i (\mathcal{E}^k_i)^T$ will be of shape $(P+1) \times (P+1)$. The $(m,n)^{\text{th}}$ entry $\mathcal{A}_i^{(m,n)}$ of the attention score matrix $\mathcal{A}_i$ denotes the similarity between the $m^{\text{th}}$ query vector $\mathcal{E}^q_i[m]$ and the $n^{\text{th}}$ key vector $\mathcal{E}^k_i[n]$.

Multi-head self-attention (MSA)

A multi-head self-attention (MSA) layer comprises more than one self-attention head. The outputs of all these self-attention heads are concatenated together into a single output of the MSA layer. Generally, we use $\frac{d}{l}$ self-attention heads so that the output sequence (after concatenation) is also in $\R^{(P+1) \times d}$ (each self-attention head produces outputs in $\R^{(P+1) \times l}$, and concatenating $d/l$ such sequences gives outputs in $\R^{(P+1) \times d}$). The figure below gives a high-level idea of an MSA layer.
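In equation form, writing $\mathcal{E}_i^{o,h}$ for the output of the $h^{\text{th}}$ self-attention head (a notation introduced here for clarity) and $[\cdot\,;\,\cdot]$ for concatenation along the feature dimension:

$$\text{MSA}(\mathcal{E}_i) = \left[\, \mathcal{E}_i^{o,1} \,;\, \mathcal{E}_i^{o,2} \,;\, \dots \,;\, \mathcal{E}_i^{o,d/l} \,\right] \in \R^{(P+1) \times d}$$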

A vision transformer consists of several such MSA layers stacked one after another, each followed by a multi-layer perceptron. For classification, we take the final embedding corresponding to the [CLS] token in the output sequence.

The code snippet below implements the vision transformer. We import the implementation of the MSA layers (Block) from the timm (PyTorch Image Models) library.

Python 3.8
import math
import torch
import torch.nn as nn
from timm.models.vision_transformer import Block
from PIL import Image
import torchvision.transforms.functional as T
import torchvision
from utils import extract_image_patches


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)


class VisionTransformer(nn.Module):
    """ Vision Transformer """

    def __init__(self, img_size=[224], patch_size=16,
                 in_chans=3, num_classes=0,
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        print("Shape of image:", img_size[0])
        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        print("Total patches :", num_patches)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=0)
        print("Shape of positional embeddings: ", self.pos_embed.shape)
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads)
            for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, nc, w, h = x.shape
        # patch linear embedding
        x = self.patch_embed(x)
        print("Shape of patch embeddings:", x.shape)
        # add the [CLS] token to the embed patch tokens
        print("Shape of [CLS] token :", self.cls_token.shape)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        print("Shape after adding [CLS] token :", x.shape)
        # add positional encoding to each token
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:  # pass through MSA layers
            x = blk(x)
        print("Shape after MSA Layers :", x.shape)
        x = self.norm(x)  # use layer norm
        return x[:, 0]  # return [CLS] representation


model = VisionTransformer(patch_size=16, embed_dim=192, depth=12, num_heads=3)
image = T.to_tensor(Image.open("cat.jpg").resize((224, 224)))[None, :, :, :]
torchvision.utils.save_image(image, "./output/image.png", normalize=True)
patches = extract_image_patches(image, kernel=16, stride=16)
print(patches.shape)
torchvision.utils.save_image(patches[0], "./output/patches.png", normalize=True, nrow=14)
f = model(image)
print("Shape of [CLS] token feature: ", f.shape)
  1. Lines 10–24: We define the PatchEmbed class that takes img_size ($H = W$), patch_size ($N$), and the network embedding size embed_dim ($d$) as input in its __init__ call and calculates the total number of patches self.num_patches ($P$).

  2. Line 20: We define the self.proj layer as a convolutional layer with kernel_size and stride set to patch_size. With this kernel size and stride, the convolutional kernel operates on non-overlapping image patches of size patch_size and projects the image batch of shape $B \times C \times H \times W$ into a feature volume of shape $B \times d \times P^{1/2} \times P^{1/2}$ ($B$ is the batch size). This volume is further flattened and reshaped in the forward() function to give patch embeddings of size $B \times P \times d$.

  3. Line 26: We define the VisionTransformer class that takes img_size, patch_size, embed_dim, depth ($T$), and num_heads ($\frac{d}{l}$) as inputs in its __init__ call.

  4. Lines 34–35: We create an instance, self.patch_embed, of the class PatchEmbed.

  5. Lines 40–41: We define the [CLS] token, self.cls_token ($e^{\text{[CLS]}}$), and the positional encoding self.pos_embed ($\mathcal{S}$) as learnable parameters.

  6. Line 42: We define a dropout layer, self.pos_drop.

  7. Lines 45–47: We define depth ($T$) MSA blocks in self.blocks. Each MSA block is composed of num_heads self-attention heads.

  8. Line 49: We define the LayerNorm layer, self.norm.

  9. Line 51: We implement the forward call that takes an input image, x, as input and converts it to patch embeddings (of size $B \times P \times d$) using the self.patch_embed layer in line 54.

  10. Lines 59–60: We prepend self.cls_token to the patch embedding sequence. The input, x, is now of shape $B \times (P+1) \times d$.

  11. Lines 64–67: We add self.pos_embed and apply the self.pos_drop layer to the input sequence, x.

  12. Lines 69–70: We pass the input sequence to the multi-head self-attention blocks self.blocks. The output of self.blocks is then passed through the LayerNorm self.norm layer.

  13. Line 74: We return the final representation corresponding to the [CLS] token.

The code outputs the original image, the patched image, and the shape of the network embeddings.

Note: The code aims to test the implementation of the vision transformer.