Introduction to Masked Image Modeling
Explore masked image modeling to understand how self-supervised learning predicts missing parts of images. Learn about masking strategies, encoder design using vision transformers, multi-head self-attention, and how these components work together for image reconstruction and representation learning.
What is masked image modeling?
Inspired by masked language modeling in NLP, masked image modeling (MIM) is a self-supervised learning paradigm in which part of the input image is masked out and the model learns to predict the masked signal from the visible portion. This approach achieves results competitive with approaches like contrastive learning.
Applying masked image modeling can be challenging for several reasons:
Pixels close to each other are highly correlated. As a result, an image can sometimes be reconstructed well enough simply by duplicating nearby pixels, which leads to trivial solutions and inefficient learning.
Signals at the pixel level are raw and carry only low-level information.
Signals in image data are continuous, unlike text data, where they are discrete.
Thus, masked image modeling must be designed carefully to avoid trivial solutions that arise from pixel correlation.
The framework of masked image modeling
Masked image modeling aims to predict the original signals from a masked input. As illustrated below, the framework involves the following components:
Masking strategy: A masking strategy selects the area to mask and performs masking on that area. Usually, masking is done at the image patch level rather than the pixel level. We can use various strategies for image masking, such as square-shaped masking, random masking, etc. The masked image is used as the input to the neural network.
Encoder: This component is a neural network that takes a masked image as input and extracts latent representations useful for predicting the original signals in the masked areas. Generally, transformer models such as the Vision Transformer and the Swin Transformer (discussed subsequently) are used as encoder architectures.
Prediction head: This component reconstructs the original signals at the masked regions of the input, given the encoder's features as input.
Prediction target: This component defines the loss function computed on the prediction head's output. The loss can be a cross-entropy classification loss or a pixel regression loss. Pixel regression means we directly predict the pixel values of the masked regions of the input image (a minimal sketch of such a loss follows this list).
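To make the pixel regression target concrete, here is a minimal sketch that computes an L1 reconstruction loss only over the masked patches; the function name, masking ratio, and tensor shapes are illustrative assumptions rather than the recipe of any particular MIM method.

```python
import torch
import torch.nn.functional as F

def masked_pixel_regression_loss(pred, target, mask):
    """Illustrative pixel regression target for masked image modeling.

    pred:   (B, N, P*P*C) reconstructed pixel values for every patch
    target: (B, N, P*P*C) ground-truth pixel values for every patch
    mask:   (B, N) with 1 for masked patches and 0 for visible ones
    """
    loss = F.l1_loss(pred, target, reduction="none").mean(dim=-1)  # per-patch loss, shape (B, N)
    # Average the loss only over the masked patches.
    return (loss * mask).sum() / mask.sum().clamp(min=1)

# Toy usage with random tensors: a batch of 2 images, 196 patches of 16x16x3 pixels each.
pred = torch.randn(2, 196, 16 * 16 * 3)
target = torch.randn(2, 196, 16 * 16 * 3)
mask = (torch.rand(2, 196) < 0.6).float()  # roughly 60% of the patches are masked
print(masked_pixel_regression_loss(pred, target, mask))
```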
Overview of vision transformers
Most approaches in masked image modeling use masking strategies that operate at the image patch level. Instead of masking individual pixels, they mask entire patches of the image.
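As an illustration, the sketch below generates a random patch-level mask; the helper name random_patch_mask and the 60% masking ratio are assumptions made for this example.

```python
import torch

def random_patch_mask(batch_size, num_patches, mask_ratio=0.6):
    """Return a (batch_size, num_patches) boolean mask; True marks masked patches."""
    num_masked = int(num_patches * mask_ratio)
    # Independently shuffle the patch indices for every image and mask the first num_masked.
    noise = torch.rand(batch_size, num_patches)
    ids_shuffle = noise.argsort(dim=1)
    mask = torch.zeros(batch_size, num_patches, dtype=torch.bool)
    mask.scatter_(1, ids_shuffle[:, :num_masked], True)
    return mask

mask = random_patch_mask(batch_size=2, num_patches=196, mask_ratio=0.6)
print(mask.shape, mask.float().mean().item())  # torch.Size([2, 196]), roughly 0.6
```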
Patch embeddings
The first step is to represent the input image as a sequence of non-overlapping patches of a fixed size.
The next step is to flatten each of these patches and project it to an embedding vector using a learnable linear projection, producing one patch embedding per patch.
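Using the notation of the original Vision Transformer paper (the symbols here are an assumption, since different texts use different notation), the patch projection can be written as:

$$
\mathbf{x} \in \mathbb{R}^{H \times W \times C} \;\rightarrow\; \mathbf{x}_p \in \mathbb{R}^{N \times (P^2 C)}, \qquad N = \frac{HW}{P^2},
$$

$$
\mathbf{z}_i = \mathbf{x}_p^i \,\mathbf{E}, \qquad \mathbf{E} \in \mathbb{R}^{(P^2 C) \times D},
$$

where the image of resolution $H \times W$ with $C$ channels is split into $N$ patches of size $P \times P$, and $\mathbf{E}$ is the learnable projection to the embedding dimension $D$.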
[CLS] token
After obtaining the patch embeddings, a learnable [CLS] token is prepended to the sequence. The final representation of this token summarizes the whole image and is later used for classification.
Note: The size of the input sequence increases by one after adding the [CLS] token.
Positional embeddings
Next, a learnable positional embedding is added to each element of the sequence so that the model retains information about the spatial position of each patch.
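Putting the pieces together in the same illustrative notation, the input sequence fed to the transformer encoder is:

$$
\mathbf{z}_0 = \big[\,\mathbf{x}_\text{class};\; \mathbf{x}_p^1\mathbf{E};\; \mathbf{x}_p^2\mathbf{E};\; \dots;\; \mathbf{x}_p^N\mathbf{E}\,\big] + \mathbf{E}_{pos}, \qquad \mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D},
$$

where $\mathbf{x}_\text{class}$ is the [CLS] token and $\mathbf{E}_{pos}$ holds one positional embedding per token.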
Self-attention
Self-attention heads are the building blocks of vision transformers. Self-attention allows a global lookup: the model can use the information contained in the embeddings at all other positions of the input sequence. Through this global lookup, a self-attention-based model can attend to the whole image at once and summarize it better. This is not possible in convolutional models, whose convolution operations only act on local regions of the input image.
A self-attention head first projects each element of the input sequence into a query vector, a key vector, and a value vector.
Here, the matrices used for the query, key, and value projections are learnable parameters of the head.
Note: Think of these query, key, and value vectors as intermediate outputs used in the computation of the final output.
Next, we compute self-attention scores by taking the scaled dot product of each query with all the keys and normalizing the scores with a softmax; the output at each position is the resulting weighted sum of the value vectors.
Note that since both the queries and the keys contain one vector per element of the input sequence, the matrix of attention scores has an entry for every pair of positions.
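In standard transformer notation (again an assumption about symbols), a single self-attention head computes:

$$
Q = XW_Q, \quad K = XW_K, \quad V = XW_V, \qquad \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V,
$$

where $X$ is the input sequence, $W_Q$, $W_K$, and $W_V$ are the learnable projection matrices, and $d$ is the dimension of the query and key vectors.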
Multi-head self-attention (MSA)
A multi-head self-attention (MSA) layer comprises more than one self-attention head. The outputs of all these self-attention heads are concatenated together to form the single output of the MSA layer. Generally, each head operates on a smaller embedding dimension so that the concatenated output has the same dimension as the layer's input.
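In the same illustrative notation, an MSA layer with $h$ heads can be sketched as:

$$
\text{MSA}(X) = \big[\,\text{head}_1;\; \text{head}_2;\; \dots;\; \text{head}_h\,\big]\, W_O,
$$

where each $\text{head}_i$ is the output of one self-attention head and $W_O$ is a learnable output projection that maps the concatenation back to the embedding dimension.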
A vision transformer consists of several such MSA layers stacked one after another, each followed by a multi-layer perceptron. For classification, we take the final embedding corresponding to the [CLS] token in the output sequence.
The code snippet below implements the vision transformer. We import the implementation of the MSA layers from the timm library.
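Below is a minimal sketch of such an implementation, assuming timm's vision_transformer.Block as the MSA (plus MLP) building block and standard ViT-Base defaults; the names (PatchEmbed, VisionTransformer, self.proj, self.cls_token, self.pos_embed, self.pos_drop, self.blocks, self.norm) match the walkthrough below, although the line numbers cited there refer to the lesson's full snippet rather than this sketch.

```python
import torch
import torch.nn as nn
from timm.models.vision_transformer import Block  # transformer block: MSA + MLP


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each to embed_dim."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size acts on non-overlapping patches.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim)
        num_patches = self.patch_embed.num_patches
        # Learnable [CLS] token and positional embeddings.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=0.1)
        # depth MSA blocks, each with num_heads self-attention heads.
        self.blocks = nn.Sequential(*[Block(dim=embed_dim, num_heads=num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)                                 # (B, N, D)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)   # (B, 1, D)
        x = torch.cat((cls_token, x), dim=1)                    # (B, N + 1, D)
        x = self.pos_drop(x + self.pos_embed)
        x = self.norm(self.blocks(x))
        return x[:, 0]                                          # [CLS] representation


model = VisionTransformer()
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 768])
```

The conv-as-patchify trick (kernel size and stride both equal to the patch size) is what lets a single Conv2d produce exactly one embedding per non-overlapping patch.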
Lines 10–24: We define the PatchEmbed class that takes img_size, patch_size, and the network embedding size embed_dim as inputs in its __init__ call and calculates the total number of patches, self.num_patches.
Line 20: We define the self.proj layer as a convolutional layer with kernel_size and stride equal to patch_size. With this kernel size and stride, the convolutional kernel operates on non-overlapping image patches of size patch_size and projects the image batch into a feature volume with embed_dim channels. This volume is further flattened and reshaped in the forward() function to give patch embeddings of shape (batch size, number of patches, embedding size).
Line 26: We define the VisionTransformer class that takes img_size, patch_size, embed_dim, depth, and num_heads as inputs in its __init__ call.
Lines 34–35: We create an instance, self.patch_embed, of the PatchEmbed class.
Lines 40–41: We define the [CLS] token, self.cls_token, and the positional encoding, self.pos_embed, as learnable parameters.
Line 42: We define a dropout layer, self.pos_drop.
Lines 45–47: We define depth number of MSA blocks in self.blocks. Each MSA block is composed of num_heads self-attention heads.
Line 49: We define the LayerNorm layer, self.norm.
Line 51: We implement the forward call that takes an input image, x, as input and converts it to patch embeddings using the self.patch_embed layer in line 54.
Lines 59–60: We prepend self.cls_token to the patch embedding sequence, so the sequence length of x increases by one.
Lines 64–67: We add self.pos_embed and apply the self.pos_drop layer to the input sequence, x.
Lines 69–70: We pass the input sequence through the multi-head self-attention blocks in self.blocks. The output of self.blocks is then passed through the LayerNorm layer, self.norm.
Line 74: We return the final representation corresponding to the [CLS] token.
The code outputs the original image, the patched image, and the shape of the network embeddings.
Note: The code is intended only to test the implementation of the vision transformer.