Multi-Head Self-Attention
Explore multi-head self-attention to understand how transformer models capture complex token relationships. Learn the roles of queries, keys, and values, see why multiple heads make the model more expressive, and gain practical insight from implementing both single- and multi-head attention in PyTorch. This lesson prepares you to explain and code this essential mechanism confidently in AI interviews.
A classic question you might encounter when discussing modern architectures is to explain multi-head self-attention in the context of transformers and why we want multiple heads. It’s essential not only to recall the definition but also to convey a clear understanding of why multi-head attention is so powerful.
In the conversation that follows, your ability to articulate both the mechanics—queries, keys, values—and the rationale behind using multiple attention heads demonstrates theoretical depth and practical insight. Let’s walk through everything you need to know about multi-head self-attention, from a concise conceptual overview to implementation details, testing approaches, and common follow-up questions.
Why do transformers use self-attention instead of recurrence?
Imagine every word in a sentence having the power to converse directly with every other word. This is the essence of self-attention: a breakthrough that transforms how models understand context.
In traditional attention models, particularly those used in encoder-decoder architectures for machine translation, there are two distinct sequences: one produces queries (from the decoder), and the other provides keys and values (from the encoder). Essentially, one sequence asks questions while the other offers the answers.
Self-attention revolutionizes this approach by deriving queries, keys, and values all from the same sequence. Every token in the input becomes both the inquirer and the provider. This means that each token can “look at” every other token, enabling the model to directly capture relationships and long-range dependencies that were previously difficult for older models, such as RNNs, to handle.
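To make the contrast concrete, here is a minimal PyTorch sketch (the variable names, dimensions, and toy sequence lengths are illustrative assumptions, not taken from any particular library): in cross-attention the queries come from one sequence while keys and values come from another, whereas in self-attention all three are projected from the same sequence.

```python
import torch
import torch.nn as nn

d = 32
w_q, w_k, w_v = (nn.Linear(d, d, bias=False) for _ in range(3))

# Cross-attention (classic encoder-decoder): queries come from the decoder,
# keys and values come from the encoder.
decoder_states = torch.randn(1, 4, d)   # e.g. 4 target-side tokens
encoder_states = torch.randn(1, 7, d)   # e.g. 7 source-side tokens
Q_cross = w_q(decoder_states)
K_cross, V_cross = w_k(encoder_states), w_v(encoder_states)

# Self-attention: queries, keys, and values are all projected from the SAME
# sequence, so every token can attend to every other token in that sequence.
x = torch.randn(1, 6, d)                # e.g. "The cat sat on the mat"
Q_self, K_self, V_self = w_q(x), w_k(x), w_v(x)
```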
Consider the sentence: “The cat sat on the mat.” Each word (or token) needs to understand which other words in the sentence contribute to its meaning. Self-attention enables every word to “look at” every other word, gathering context. Here’s how it works:
Each word generates three vectors:
Query (Q): Think of this as what the word is looking for in the sentence.
Key (K): This indicates what the word can offer or its characteristics.
Value (V): This carries the actual content or information of the word.
Each word computes a dot product between its own Query vector and the Key vectors of all words in the sentence, including itself, to measure their contextual relevance. This operation produces scores that measure how much each word should focus on the others. For example, if “cat” and “mat” share related features, the score between “cat’s Query” and “mat’s Key” might be high, suggesting a strong contextual relationship.
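As a rough PyTorch sketch of these first two steps (the tensor shapes and the six-token toy sentence are assumptions made for illustration), each token embedding is projected into Q, K, and V with learned linear layers, and the pairwise scores come from a dot product; the original transformer also divides by √d_k so the scores don’t grow with the key dimension:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, d_k = 6, 32, 32            # 6 tokens: "The cat sat on the mat"
x = torch.randn(1, seq_len, d_model)         # (batch, tokens, embedding dim)

# Learned projections: each token gets its own Query, Key, and Value vector.
w_q = nn.Linear(d_model, d_k, bias=False)
w_k = nn.Linear(d_model, d_k, bias=False)
w_v = nn.Linear(d_model, d_k, bias=False)
Q, K, V = w_q(x), w_k(x), w_v(x)             # each: (1, seq_len, d_k)

# Pairwise relevance: dot product of every Query with every Key, divided by
# sqrt(d_k) so the scores stay in a stable range as d_k grows.
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5    # (1, seq_len, seq_len)
print(scores.shape)                               # torch.Size([1, 6, 6])
```

Entry (i, j) of the resulting score matrix indicates how relevant token j looks to token i before any normalization is applied.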
Interview trap: An interviewer might ask, “Why do we need separate Q, K, V projections? Why not just use the embeddings directly?” and candidates sometimes say, “It’s just how transformers are designed.”
However, there’s a fundamental reason! Using separate learned projections allows the model to learn different representations for “what I’m looking for” (Q), “what I offer” (K), and “what information I carry” (V). If we used raw embeddings, a token couldn’t distinguish between its role as a query versus its role as a key—the attention pattern would be symmetric and less expressive. The projections let the model learn asymmetric relationships (e.g., “pronoun looks for noun” is different from “noun looks for pronoun”).
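One quick way to convince yourself of this is to compare the score matrix built from raw embeddings with one built from separate learned projections; the sketch below (with made-up dimensions) shows that the former is always symmetric while the latter generally is not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 6, 32)   # toy batch: 6 token embeddings of size 32

# Raw embeddings as both query and key: x_i . x_j == x_j . x_i, so the score
# matrix is symmetric; "how much i attends to j" equals "how much j attends to i".
raw_scores = x @ x.transpose(-2, -1)
print(torch.allclose(raw_scores, raw_scores.transpose(-2, -1)))    # True

# Separate learned Q and K projections break that symmetry, letting the model
# express directional relationships (e.g. pronoun looks for noun vs. the reverse).
w_q = nn.Linear(32, 32, bias=False)
w_k = nn.Linear(32, 32, bias=False)
proj_scores = w_q(x) @ w_k(x).transpose(-2, -1)
print(torch.allclose(proj_scores, proj_scores.transpose(-2, -1)))  # False (almost surely)
```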
These raw scores can vary widely in size and aren’t directly comparable. In the standard transformer they are first scaled by 1/√d_k (the dimensionality of the keys) to keep them numerically stable, and then passed through softmax, which converts them into weights that add up to 1. Softmax sharpens the differences: tokens with higher scores receive most of the weight, while tokens with lower scores get weights close to 0. This normalization determines the relative weight each token should receive.
Each word then uses these normalized scores to weigh the Value vectors of all other words. By summing these weighted Values, the word gets a new representation that captures all the relevant context from the entire sentence.
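Putting the pieces together, a minimal single-head self-attention module might look like the sketch below (the class name and shapes are my own; a production implementation would typically also handle masking and dropout):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    """Minimal single-head self-attention: scores -> softmax -> weighted sum of Values."""

    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_k, bias=False)
        self.w_k = nn.Linear(d_model, d_k, bias=False)
        self.w_v = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        Q, K, V = self.w_q(x), self.w_k(x), self.w_v(x)
        scores = Q @ K.transpose(-2, -1) / K.size(-1) ** 0.5   # (batch, seq, seq)
        weights = F.softmax(scores, dim=-1)                    # each row sums to 1
        return weights @ V                                     # context-aware representations

attn = SingleHeadSelfAttention(d_model=32, d_k=32)
out = attn(torch.randn(1, 6, 32))                              # 6 toy tokens
print(out.shape)                                               # torch.Size([1, 6, 32])
```

Each output row is a weighted blend of the Value vectors, so every token’s new representation already reflects the context gathered from the whole sentence.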
As a result, single-head self-attention scales as