Self-Attention

Understand in depth how self-attention works.

What is self-attention?

“Self-attention is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence.” ~ Ashish Vaswani et al. from Google Brain

Self-attention enables us to find correlations between different words (tokens) of the input, indicating the syntactic and contextual structure of the sentence.

Let’s take the input sequence “Hello I love you” as an example.

A trained self-attention layer will associate the word “love” with the words “I” and “you” with a higher weight than the word “Hello”. From linguistics, we know that these words share a subject-verb-object relationship and that is an intuitive way to understand what self-attention will capture.

Self-attention probability score matrix

In practice, the original transformer model uses three different representations of the embedding matrix: the Queries, Keys, and Values.

This can easily be implemented by multiplying our input $X \in R^{N \times d_{model}}$ with three different weight matrices $W_Q$, $W_K$, and $W_V \in R^{d_{model} \times d_{k}}$. In essence, it is just a matrix multiplication applied to the original word embeddings.

You can think of it as a linear projection. Here is a visual to help you out:

Input projection in Keys, Queries and Values
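As a minimal sketch of this projection (the sizes below are illustrative assumptions, and nn.Linear layers stand in for the weight matrices $W_Q$, $W_K$, and $W_V$):

import torch
import torch.nn as nn

N, d_model, d_k = 4, 512, 64               # assumed sizes: "Hello I love you" -> 4 tokens
X = torch.randn(N, d_model)                # embedded input sequence

W_q = nn.Linear(d_model, d_k, bias=False)  # plays the role of W_Q
W_k = nn.Linear(d_model, d_k, bias=False)  # plays the role of W_K
W_v = nn.Linear(d_model, d_k, bias=False)  # plays the role of W_V

Q, K, V = W_q(X), W_k(X), W_v(X)           # each has shape (N, d_k)
print(Q.shape, K.shape, V.shape)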

Having the Query, Value, and Key matrices, we can now apply the self-attention layer as:

$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V$

In the original paper, the scaled dot-product attention was chosen as a scoring function to represent the correlation between two words (the attention weight).

Note that we can also utilize another similarity function. The $\sqrt{d_{k}}$ term is simply a scaling factor that keeps the dot products from exploding.

Following the video database-query paradigm that we introduced before, this term simply measures the similarity of the search query with an entry in the database.

Finally, we apply a softmax function to get the final attention weights as a probability distribution.

Remember that we keep the Keys ($K$) and the Values ($V$) as distinct representations. Thus, the final representation is the self-attention matrix $\text{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right)$ multiplied with the Value ($V$) matrix.

Personally, I like to think of the attention matrix $\text{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right)$ as where to look, and of the Value matrix as what we actually want.
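To make this concrete, here is a minimal sketch of the whole formula on unbatched tensors (the sizes are illustrative assumptions; a full, batched module follows later in this lesson):

import math

import torch
import torch.nn.functional as F

N, d_k = 4, 64                                      # assumed toy sizes: 4 tokens, key dim 64
Q, K, V = torch.randn(N, d_k), torch.randn(N, d_k), torch.randn(N, d_k)

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (N, N) query-key similarities
attn = F.softmax(scores, dim=-1)                    # each row is a probability distribution
output = attn @ V                                   # (N, d_k) weighted sum of value vectors

print(attn.sum(dim=-1))                             # every row sums to 1
print(output.shape)                                 # torch.Size([4, 64])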

Notice any differences compared to plain vector similarity?

First, we have matrices instead of vectors and, as a result, matrix multiplications. Second, we don't scale down by the vector magnitudes but by $\sqrt{d_{k}}$, where $d_{k}$ is the dimensionality of the key vectors, a fixed hyperparameter of the model (unlike the sentence length, which varies).
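A brief sketch of why $\sqrt{d_{k}}$ is the right scale (this derivation follows the footnote in the original paper and is added here for clarity): assume the components of a query $q$ and a key $k$ are independent with mean $0$ and variance $1$. Then

$q \cdot k = \sum_{i=1}^{d_{k}} q_{i} k_{i}, \qquad \mathrm{Var}(q \cdot k) = \sum_{i=1}^{d_{k}} \mathrm{Var}(q_{i} k_{i}) = d_{k}$

so dividing the scores by $\sqrt{d_{k}}$ brings them back to unit variance and keeps the softmax from saturating.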

The matrix multiplication can be thought of as many vector-matrix multiplications carried out in parallel. The vectors are simply the individual queries.

We “query” all the projected words together by stacking the embedding vectors in a matrix and projecting them linearly to $Q$.
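Here is a tiny sketch of that equivalence (illustrative tensors, not part of the original text): scoring all stacked queries with a single matrix multiplication gives the same result as scoring each query vector separately.

import torch

N, d_k = 4, 8                                      # assumed toy sizes
Q, K = torch.randn(N, d_k), torch.randn(N, d_k)

scores_matrix = Q @ K.T                            # all queries at once, shape (N, N)
scores_loop = torch.stack([q @ K.T for q in Q])    # one query vector at a time

print(torch.allclose(scores_matrix, scores_loop))  # True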

Isn’t that awesome?

To make things clearer, you can find a PyTorch implementation below:

import math
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, 
                query: torch.FloatTensor, 
                key: torch.FloatTensor,
                value: torch.FloatTensor, 
                mask: Optional[torch.ByteTensor] = None, 
                dropout: Optional[nn.Dropout] = None
                ) -> Tuple[torch.Tensor, Any]:
        """
        Args:
            `query`: shape (batch_size, n_heads, max_len, d_q)
            `key`: shape (batch_size, n_heads, max_len, d_k)
            `value`: shape (batch_size, n_heads, max_len, d_v)
            `mask`: shape (batch_size, 1, 1, max_len)
            `dropout`: nn.Dropout
        Returns:
            `weighted value`: shape (batch_size, n_heads, max_len, d_v)
            `weight matrix`: shape (batch_size, n_heads, max_len, max_len)
        """
        
        d_k = query.size(-1)  # d_k = d_model / n_heads
        # Scaled dot-product scores: (batch_size, n_heads, max_len, max_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Mask out padded positions so they receive (almost) zero attention weight
            scores = scores.masked_fill(mask.eq(0), -1e9)
        p_attn = F.softmax(scores, dim=-1)  # attention weights sum to 1 along the last dim
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn
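
A quick usage sketch of the module above (the batch size, number of heads, and sequence length below are arbitrary assumptions):

# Hypothetical shapes: batch of 2, 8 heads, 10 tokens, d_k = d_v = 64
attention = ScaledDotProductAttention()
q = torch.randn(2, 8, 10, 64)
k = torch.randn(2, 8, 10, 64)
v = torch.randn(2, 8, 10, 64)
mask = torch.ones(2, 1, 1, 10, dtype=torch.uint8)  # 1 = keep, 0 = padded position

output, weights = attention(q, k, v, mask=mask)
print(output.shape)   # torch.Size([2, 8, 10, 64])
print(weights.shape)  # torch.Size([2, 8, 10, 10])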

Before we build fancy transformer blocks, we need to delve into one more critical concept: multi-head self-attention.
