Introduction to Saliency Maps—The Vanilla Gradient

Get an overview of saliency maps and implement the vanilla gradient saliency algorithm.

Saliency maps

Given an image $X \in [0,1]^{C \times H \times W}$ (where $H$ is the height, $W$ is the width, and $C$ is the number of channels), a saliency map $S \in [0,1]^{H \times W}$ is an image in which the brightness of the $(i,j)^{th}$ pixel, i.e., $S(i,j)$, represents how important or salient that pixel is to the network's prediction. In other words, the greater the value of $S(i,j)$, the more important the pixel. If a particular region of the saliency map is concentrated with bright pixels, that region or feature is important for the prediction. The figure below illustrates what a saliency map looks like. The red pixels in the saliency map denote the bright pixels important for the prediction, while the black pixels are unimportant.
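To make this definition concrete, here's a tiny hypothetical sketch (not part of the lesson's code; the random values and the 0.8 "brightness" threshold are arbitrary choices) that reads off the most salient pixel and the fraction of bright pixels from a saliency map:

import torch

S = torch.rand(224, 224)                      # a hypothetical saliency map in [0, 1]
i, j = divmod(S.argmax().item(), S.shape[1])  # row/column of the most salient pixel
bright = (S > 0.8).float().mean().item()      # fraction of "bright" pixels (arbitrary threshold)
print(f"Most salient pixel: ({i}, {j}); bright fraction: {bright:.2%}")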

[Figure: Saliency map]

A saliency map can therefore be used as a visual explanation to verify the correctness of a model. For example, in the figure above, if the red pixels in the saliency map are concentrated around the “grass,” we can infer that the model relies on irrelevant artifacts, such as the background, to make its prediction. Such a model shouldn’t be trusted and should be deployed only after careful examination.

Note: Saliency maps can also be referred to as pixel-attribution, attribution, or sensitivity maps.

Vanilla gradient saliency

We’ll now learn to implement our first saliency map algorithm: the vanilla gradient saliency.

Let’s assume that $f(X) = [\, f^1(X), f^2(X), \dots, f^K(X) \,]$ represents the neural network output, where $f^i(X)$ denotes the pre-softmax score (logit) of the $i^{th}$ class (there are $K$ classes in total). The prediction based on the output is given by

$$k^* = \arg\max_{i \in \{1, \dots, K\}} f^i(X)$$

where $k^*$ is the label of the class with the highest score. The algorithm first computes the vanilla gradients $G$ as:

$$G = \frac{\partial f^{k^*}(X)}{\partial X}$$
The vanilla gradient $G$ here is a matrix $\in \mathbb{R}^{C \times H \times W}$ representing the direction in which the score of the predicted class, i.e., $f^{k^*}(X)$, increases. A positive gradient means that the pixel is important for increasing the predicted class score, while a negative or zero gradient means that the pixel is not important for the prediction. It therefore makes sense to keep only the positive gradients, $G^+ = \max(0, G)$, when computing the saliency map.
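As a minimal sketch of these two steps, assuming a PyTorch classifier and an input tensor with gradients enabled (the function name positive_gradients and the random stand-in input are hypothetical):

import torch
from torchvision.models import mobilenet_v2

def positive_gradients(model, X):
    # X: a (C, H, W) image tensor with requires_grad=True
    logits = model(X.unsqueeze(0))[0]                # forward pass -> (K,) class scores
    k_star = logits.argmax()                         # predicted class k*
    G = torch.autograd.grad(logits[k_star], [X])[0]  # G = d f^{k*}(X) / dX
    return G.clamp(min=0)                            # G+ = max(0, G)

model = mobilenet_v2().eval()                        # randomly initialized, for illustration only
X = torch.rand(3, 224, 224, requires_grad=True)      # stand-in input
G_plus = positive_gradients(model, X)                # (C, H, W) non-negative gradients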

The unnormalized saliency map $S \in \mathbb{R}^{H \times W}$ is now defined as the maximum value of the gradient $G^+$ along the channel dimension, i.e., $S(i, j) = \max\{G^+(1, i, j), G^+(2, i, j), \dots, G^+(C, i, j)\}$, where $G^+(c, i, j)$ denotes the gradient along the $c^{th}$ channel at the $(i,j)^{th}$ pixel (note that $G^+$ is a $C \times H \times W$ matrix). Finally, the saliency map $S$ is normalized to $[0,1]^{H \times W}$ for visualization.
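A rough sketch of this channel-wise reduction and min-max normalization, using a random stand-in for $G^+$ (hypothetical values):

import torch

G_plus = torch.rand(3, 224, 224)          # stand-in for the positive gradients G+
S = G_plus.max(dim=0).values              # channel-wise maximum -> (H, W)
S = (S - S.min()) / (S.max() - S.min())   # min-max normalize into [0, 1]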

[Figure: Vanilla gradient saliency]

Implementing vanilla gradient saliency

The code below implements and visualizes the vanilla gradient saliency of an image with respect to the prediction made by the MobileNet-V2 network trained on the ImageNet-1K dataset. It outputs the original image, its vanilla gradient saliency, and the network prediction. As before, the red pixels in the saliency map denote the bright pixels important for the prediction, while the black pixels are unimportant.

import torch
import torchvision
import torchvision.transforms as T
from torchvision.models import mobilenet_v2
from PIL import Image
import json
import matplotlib.pyplot as plt

class_idx = json.load(open("imagenet_class_index.json", 'r')) # ImageNet Id to Label Mapping
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

image = Image.open("dog.jpg").resize((224,224)) # original image
transform = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = transform(image) # transform and normalize image
image.requires_grad = True
torchvision.utils.save_image(image, "./output/image.png", normalize=True)

model = mobilenet_v2() # load MobileNetV2 model
ckpt = torch.load("./mobilenet_v2-b0353104.pth", map_location="cpu")
model.load_state_dict(ckpt)
model.eval()

logits = model(image.unsqueeze(0))[0] # output logits, forward pass
prediction = torch.argmax(logits) # network prediction
print("Model predicted : ", idx2label[prediction.item()])

grad = torch.autograd.grad(logits[prediction.item()], [image])[0] # backward pass
grad, _ = torch.max(grad.relu(), 0) # take positive gradients only

plt.imshow(grad, cmap=plt.cm.hot) # plot saliency
plt.axis("off")
plt.tight_layout()
plt.savefig("./output/saliency.png", bbox_inches="tight")
  • Lines 1–7: We import torch for automatic differentiation via torch.autograd.grad(), torchvision for loading MobileNetV2, PIL for image manipulations, json for JSON-related utilities, and matplotlib for plotting graphs/images.

  • Lines 9–10: We load the ImageNet-1K class indexes and create a map idx2label, which maps a numeric label to its corresponding class name in the ImageNet-1K dataset.

  • Lines 12–17: We load a PIL (Python Imaging Library) image and define a transform to resize the input image to 224 ✕ 224 resolution, convert it to a tensor, and normalize it.

  • Lines 19–20: We transform the image and enable its requires_grad attribute for computing gradients.

  • Lines 23–26: We load the MobileNet-V2 model and set it to eval() mode. model.eval() is a switch for specific layers, like batch normalization (which scales a layer’s outputs to have a mean of 0 and a variance of 1 so the network trains faster) and dropout (a mask that nullifies the contribution of some neurons toward the next layer and leaves all others unmodified), that behave differently during training and inference (evaluation) time; a short sketch of this difference follows the list below.

  • Lines 28–29: We calculate the class logits and the network prediction.

  • Line 32: We calculate the vanilla gradient of the network prediction with respect to the input image.

  • Line 33: We take the maximum value of the positive gradient along the channel dimension.

  • Lines 35–38: We plot and visualize the saliency map (plt.imshow min-max normalizes the data to the colormap’s range by default, which produces the normalized map described above).
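To illustrate the train/eval switch mentioned in the notes above, here's a small standalone sketch (hypothetical values, not part of the lesson's code) showing that dropout is active in training mode and becomes a no-op in eval() mode:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()   # training mode: roughly half the entries are zeroed, survivors scaled by 1/(1-p)
print(drop(x))

drop.eval()    # evaluation mode: dropout is disabled, the input passes through unchanged
print(drop(x))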

As we can see, the bright or red pixels are concentrated around the face of the dog, suggesting that facial features are essential for classifying the object as a dog.

The vanilla gradient algorithm is fast because it involves only one forward-backward pass, and it generates decent explanations for most inputs. However, when the input is noisy, the vanilla gradients also become noisy, assigning importance to random or irrelevant pixels in the saliency map. This makes the algorithm vulnerable to adversarial attacks, i.e., inputs (e.g., images, text, or audio) that an attacker has intentionally designed to cause the model to make a mistake or a wrong classification (such inputs are also called adversarial examples).