Attention Mechanism
Sources:
- Transformer from scratch
- Attention Is All You Need 2017 paper
Notation
| Symbol | Type | Explanation |
|---|---|---|
| \(x_i \in \mathbb{R}^{C}\) | Input vector | Vector at position \(i\) in the input sequence |
| \(z_j \in \mathbb{R}^{C}\) | Context vector | Vector at position \(j\) in the context sequence |
| \(y_i \in \mathbb{R}^{C}\) | Output vector | Vector at position \(i\) in the output sequence |
| \(T \in \mathbb{N}\) | Sequence length | Length of the input and output sequences |
| \(S \in \mathbb{N}\) | Sequence length | Length of the context sequence |
| \(C \in \mathbb{N}\) | Vector dimension | Dimension of all vectors |
| \(q_i \in \mathbb{R}^{C}\) | Query vector | Query derived from input vector \(x_i\) |
| \(k_j \in \mathbb{R}^{C}\) | Key vector | Key derived from context vector \(z_j\) |
| \(v_j \in \mathbb{R}^{C}\) | Value vector | Value derived from context vector \(z_j\) |
| \(w_{ij} \in \mathbb{R}\) | Attention weight | Normalized weight indicating influence of \(v_j\) on \(y_i\) |
| \(w_{ij}' \in \mathbb{R}\) | Attention score | Unnormalized attention score between \(q_i\) and \(k_j\) |
| \(W \in \mathbb{R}^{T \times S}\) | Attention matrix | Matrix of all attention weights |
| \(Q \in \mathbb{R}^{T \times C}\) | Query matrix | Matrix of all query vectors (rows are \(q_i^{\top}\)) |
| \(K \in \mathbb{R}^{S \times C}\) | Key matrix | Matrix of all key vectors (rows are \(k_j^{\top}\)) |
| \(V \in \mathbb{R}^{S \times C}\) | Value matrix | Matrix of all value vectors (rows are \(v_j^{\top}\)) |
| \(Y \in \mathbb{R}^{T \times C}\) | Output matrix | Matrix of all output vectors (rows are \(y_i^{\top}\)) |
| \(W_q \in \mathbb{R}^{C \times C}\) | Projection matrix | Weight matrix for generating queries from input |
| \(W_k \in \mathbb{R}^{C \times C}\) | Projection matrix | Weight matrix for generating keys from context |
| \(W_v \in \mathbb{R}^{C \times C}\) | Projection matrix | Weight matrix for generating values from context |
| \(\text{softmax}\) | Function | Softmax normalization function |
| \(\exp\) | Function | Exponential function used in softmax calculation |
| \(-\infty\) | Special value | Used to mask future elements in causality mask |
Attention
Attention is a sequence-to-sequence operation that allows a sequence \(\{x_1, x_2, \dots, x_T\}\) to attend to another sequence \(\{z_1, z_2, \dots, z_S\}\) to produce an output sequence \(\{y_1, y_2, \dots, y_T\}\). The core idea is that each output \(y_i\) is computed by "looking at" all elements of \(\{z_1, z_2, \dots, z_S\}\) and forming a weighted combination based on how relevant each \(z_j\) is to \(x_i\).
To implement this, we represent each element \(x_i\) as a query \(q_i\) and the so-called context sequence \(\{z_j\}_{j=1}^S\) as key-value pairs \(\{(k_j, v_j)\}_{j=1}^S\):
- Query \(q_i\): derived from \(x_i\), represents "what we're looking for"
- Key \(k_j\): derived from \(z_j\), represents "what this context element offers"
- Value \(v_j\): derived from \(z_j\), represents "the actual content to retrieve"
The key-value pairs \(\{(k_j, v_j)\}\) are simply a representation of the original context sequence \(\{z_j\}\).
To compute the output vector \(y_{\textcolor{red}{i}}\), attention forms a linear combination of the value vectors: \[ \begin{equation} \label{eq1} y_{\textcolor{red}{i}} = \sum_{\textcolor{green}{j}=1}^S w_{\textcolor{red}{i}\textcolor{green}{j}} v_{\textcolor{green}{j}}, \end{equation} \] where \(\textcolor{green}{j}\) indexes the key-value sequence of length \(S\), and the weights \(w_{\textcolor{red}{i}\textcolor{green}{j}}\) sum to one over all \(\textcolor{green}{j}\).
These weights, also called attention weights, are computed in two steps.
First, compute the dot product of the query vector \(q_{\textcolor{red}{i}}\) and key vectors \(k_{\textcolor{green}{j}}\): \[ \begin{equation} \label{eq2} w_{\textcolor{red}{i}\textcolor{green}{j}}' = k_{\textcolor{green}{j}}^{\top} q_{\textcolor{red}{i}}. \end{equation} \]
The resulting scores are then normalized using the softmax function to form the final attention weights: \[ \begin{equation} \label{eq3} w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{\textcolor{green}{k}=1}^S \exp(w_{\textcolor{red}{i}\textcolor{green}{k}}')}. \end{equation} \]
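To make equations \(\eqref{eq1}\)–\(\eqref{eq3}\) concrete, here is a minimal sketch that computes a single output vector \(y_i\) in PyTorch; the shapes and tensor names are illustrative only:

```python
import torch

S, C = 4, 8                             # context length and vector dimension
q_i = torch.randn(C)                    # query for position i
keys = torch.randn(S, C)                # k_1 ... k_S, one per row
values = torch.randn(S, C)              # v_1 ... v_S, one per row

scores = keys @ q_i                     # w'_{ij} = k_j^T q_i, shape (S,)
weights = torch.softmax(scores, dim=0)  # w_{ij}, sums to 1 over j
y_i = weights @ values                  # y_i = sum_j w_{ij} v_j, shape (C,)
```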
Variants:
- Cross-attention: When \(\{x_i\}\) and \(\{z_j\}\) are different sequences (the general case above)
- Self-attention: When \(\{x_i\} = \{z_j\}\) and \(T = S\), i.e., a sequence attending to itself, so queries, keys, and values are all derived from the same sequence
Below is a visual illustration of self-attention. Note that the softmax operation over the weights is not illustrated:
The matrix form
The attention equation \(\eqref{eq1}\) can be expressed compactly in matrix form. Let:
- \(Q \in \mathbb{R}^{T \times C}\) be the matrix of query vectors (each row is \(q_i^{\top}\))
- \(K \in \mathbb{R}^{S \times C}\) be the matrix of key vectors (each row is \(k_j^{\top}\))
- \(V \in \mathbb{R}^{S \times C}\) be the matrix of value vectors (each row is \(v_j^{\top}\))
Then the attention weights matrix \(W \in \mathbb{R}^{T \times S}\) and output matrix \(Y \in \mathbb{R}^{T \times C}\) are: \[ \begin{equation} \label{eq4} W = \text{softmax}(QK^{\top}), \end{equation} \] and \[ \begin{equation} \label{eq5} Y = WV \end{equation} \] where the softmax is applied row-wise to \(QK^{\top}\) and:
- \(Y \in \mathbb{R}^{T \times C}\) (output matrix, \(T\) rows for \(T\) output vectors)
- \(W \in \mathbb{R}^{T \times S}\) (attention weights matrix, \(T\) rows for \(T\) query vectors and \(S\) columns for \(S\) key vectors)
- \(V \in \mathbb{R}^{S \times C}\) (value matrix, \(S\) rows for \(S\) value vectors)
Note that all vectors are stored as row vectors in the matrices.
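Equations \(\eqref{eq4}\) and \(\eqref{eq5}\) map almost one-to-one onto tensor operations. A minimal sketch with random matrices, assuming the row-vector convention above:

```python
import torch

T, S, C = 5, 7, 8
Q = torch.randn(T, C)                 # queries as rows
K = torch.randn(S, C)                 # keys as rows
V = torch.randn(S, C)                 # values as rows

W = torch.softmax(Q @ K.T, dim=-1)    # (T, S), softmax applied row-wise
Y = W @ V                             # (T, C)
assert torch.allclose(W.sum(dim=-1), torch.ones(T))   # each row sums to one
```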
Why does the notation change?
The multiplication order differs between the vector formulation \(\eqref{eq2}\) and the matrix formulation \(\eqref{eq4}\) because of different vector conventions:
- Vector case: \(q_i\) and \(k_j\) are column vectors (common in mathematics) → the dot product is written \(k_{\textcolor{green}{j}}^{\top} q_{\textcolor{red}{i}}\)
- Matrix case: \(Q\) and \(K\) store the vectors as rows (common in the deep learning community) → all dot products are computed at once via \(QK^{\top}\)
We could simply transpose all matrices in the matrix formulation to obtain a single consistent convention.
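To see that the two conventions agree, expand a single entry of \(QK^{\top}\) (a short derivation under the row-vector convention above): \[ (QK^{\top})_{\textcolor{red}{i}\textcolor{green}{j}} = \sum_{c=1}^{C} Q_{\textcolor{red}{i}c} K_{\textcolor{green}{j}c} = q_{\textcolor{red}{i}}^{\top} k_{\textcolor{green}{j}} = k_{\textcolor{green}{j}}^{\top} q_{\textcolor{red}{i}} = w_{\textcolor{red}{i}\textcolor{green}{j}}'. \]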
How to generate queries, keys and values
How do we generate the queries \(\{q_i\}_{i=1}^T\) from the input sequence \(\{x_i\}_{i=1}^T\), and the keys/values \(\{(k_j, v_j)\}_{j=1}^S\) from the context sequence \(\{z_j\}_{j=1}^S\)? The answer is surprisingly simple: linear transformations using learnable matrices:
- Cross-attention: \(q_{\textcolor{red}{i}} = W_q x_{\textcolor{red}{i}}\), \(k_{\textcolor{green}{j}} = W_k z_{\textcolor{green}{j}}\), \(v_{\textcolor{green}{j}} = W_v z_{\textcolor{green}{j}}\)
- Self-attention (when \(x_i = z_i\)): \(q_{\textcolor{red}{i}} = W_q x_{\textcolor{red}{i}}\), \(k_{\textcolor{red}{i}} = W_k x_{\textcolor{red}{i}}\), \(v_{\textcolor{red}{i}} = W_v x_{\textcolor{red}{i}}\)
where all vectors are column vectors.
In matrix form: \[ \begin{aligned} & Q = X W_q^{\top} \\ & K = X W_k^{\top} \\ & V = X W_v^{\top} \end{aligned} \] where \(X \in \mathbb{R}^{T \times C}\) stacks the input vectors as rows (shown here for self-attention; for cross-attention, \(K\) and \(V\) are computed from the matrix of context vectors \(z_j^{\top}\) instead). As before, all vectors are stored as row vectors in the matrices.
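A minimal sketch of these projections in PyTorch for the self-attention case; note that `nn.Linear(C, C, bias=False)` stores its weight as a \(C \times C\) matrix and computes \(X W^{\top}\), which matches the matrix form above (the layer names are illustrative):

```python
import torch
import torch.nn as nn

T, C = 5, 8
X = torch.randn(T, C)                 # input vectors as rows

to_q = nn.Linear(C, C, bias=False)    # W_q
to_k = nn.Linear(C, C, bias=False)    # W_k
to_v = nn.Linear(C, C, bias=False)    # W_v

Q, K, V = to_q(X), to_k(X), to_v(X)   # each of shape (T, C)
```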
Note:
- You may wonder why we prefer linear transformations over non-linear projections (e.g., MLPs with ReLU). I personally believe this is largely an engineering choice; that said, simple linear transformations work well enough in practice.
- The attention computation itself \(\eqref{eq5}\) contains no trainable parameters—it's just matrix multiplications and softmax. All learnable parameters are in the \(W_q\), \(W_k\), and \(W_v\) matrices that generate the queries, keys, and values.
Multi-head attention
Rather than applying a single attention function with \(C\)-dimensional keys, values, and queries, the model applies this function \(h\) times independently on different subspaces of dimension \(\frac{C}{h}\). The outputs are then concatenated and linearly transformed to form the final output, allowing the model to capture diverse aspects of the input.
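For intuition about the shapes, here is a small sketch of the split-and-concatenate bookkeeping (the full module appears in the Code section below). The values \(C = 512\) and \(h = 8\) are the defaults from the Attention Is All You Need paper; everything else is illustrative:

```python
import torch

T, C, h = 10, 512, 8          # d_model = 512, h = 8 heads as in the paper
head_dim = C // h             # 64 dimensions per head

Q = torch.randn(T, C)
# split the channel dimension into h heads of size C/h
Q_heads = Q.view(T, h, head_dim).transpose(0, 1)   # (h, T, C/h)
# ... attention is applied independently within each head ...
# concatenate the heads back into a single C-dimensional vector per position
Q_merged = Q_heads.transpose(0, 1).reshape(T, C)   # (T, C)
assert torch.equal(Q_merged, Q)                    # splitting and merging round-trips
```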
Causality mask for autoregressive models
Autoregressive models, e.g., GPT-based models, use the self-attention mechanism and typically require a causality mask to ensure that predictions depend only on preceding elements, i.e., \(x_{\textcolor{red}{i}}\) can NOT attend to any \(z_{\textcolor{green}{j}}\) with \(\textcolor{green}{j} > \textcolor{red}{i}\).
To enforce this constraint, a causality mask prevents attention to future positions, making the attention weights matrix \(W\) in \(\eqref{eq4}\) lower-triangular: \[ W = \begin{pmatrix} w_{11} & 0 & 0 & \cdots & 0 \\ w_{21} & w_{22} & 0 & \cdots & 0 \\ w_{31} & w_{32} & w_{33} & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ w_{T1} & w_{T2} & w_{T3} & \cdots & w_{TS} \end{pmatrix} \] where \(S = T\) since we are dealing with self-attention.
The causality mask is implemented by modifying the dot-product scores: \[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = \begin{cases} k_{\textcolor{green}{j}}^{\top} q_{\textcolor{red}{i}}, & \text{if } \textcolor{green}{j} \leq \textcolor{red}{i}, \\ -\infty, & \text{if } \textcolor{green}{j} > \textcolor{red}{i}. \end{cases} \]
After applying softmax normalization, future positions (\(\textcolor{green}{j} > \textcolor{red}{i}\)) are effectively zeroed out since \(\exp(-\infty) = 0\): \[ w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{\textcolor{green}{k}=1}^{\textcolor{red}{i}} \exp(w_{\textcolor{red}{i}\textcolor{green}{k}}')}. \]
Note: Cross-attention typically does not require causal masking since the decoder attends to a fully available context (e.g., encoder outputs in translation, image features in captioning).
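In code, the mask is usually applied to the raw scores before the softmax, exactly as in the case split above. A minimal sketch for self-attention (names are illustrative):

```python
import torch

T, C = 5, 8
Q = torch.randn(T, C)
K = torch.randn(T, C)
V = torch.randn(T, C)

scores = Q @ K.T                                       # (T, T) raw scores
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # True where j <= i
scores = scores.masked_fill(~mask, float("-inf"))      # block future positions
W = torch.softmax(scores, dim=-1)                      # lower-triangular weights
Y = W @ V
```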
Code
Single self-attention head
```python
class AttentionHead(nn.Module):
    ...
```
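Below is a minimal sketch of one way such a head could be implemented (not necessarily the original code). It assumes a batched input of shape `(B, T, C)`, bias-free `nn.Linear` projections for \(W_q\), \(W_k\), \(W_v\), and an optional `causal` flag; the \(1/\sqrt{C}\) score scaling follows the scaled dot-product attention of the Attention Is All You Need paper and is not part of the equations above.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """A single (optionally causal) self-attention head."""

    def __init__(self, C: int, causal: bool = False):
        super().__init__()
        self.to_q = nn.Linear(C, C, bias=False)   # W_q
        self.to_k = nn.Linear(C, C, bias=False)   # W_k
        self.to_v = nn.Linear(C, C, bias=False)   # W_v
        self.causal = causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C), vectors stored as rows
        B, T, C = x.shape
        Q, K, V = self.to_q(x), self.to_k(x), self.to_v(x)       # each (B, T, C)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(C)          # (B, T, T)
        if self.causal:
            mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            scores = scores.masked_fill(~mask, float("-inf"))    # hide future positions
        W = F.softmax(scores, dim=-1)                            # attention weights
        return W @ V                                             # (B, T, C)
```

For example, `AttentionHead(C=64, causal=True)` maps a `(2, 10, 64)` input to an output of the same shape.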
Multi-head attention block
```python
class MultiHeadAttention(nn.Module):
    ...
```
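And a matching sketch of a multi-head block under the same assumptions: the \(C\) channels are split into `h` heads of dimension \(C/h\), attention runs independently per head, and the concatenated result is passed through a final linear projection (the "concatenated and linearly transformed" step described earlier). The module and parameter names are illustrative.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention: h heads of dimension C // h, concatenated and projected."""

    def __init__(self, C: int, h: int, causal: bool = False):
        super().__init__()
        assert C % h == 0, "C must be divisible by the number of heads"
        self.h, self.head_dim, self.causal = h, C // h, causal
        self.to_q = nn.Linear(C, C, bias=False)       # W_q
        self.to_k = nn.Linear(C, C, bias=False)       # W_k
        self.to_v = nn.Linear(C, C, bias=False)       # W_v
        self.out_proj = nn.Linear(C, C, bias=False)   # final projection after concatenation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        # project, then split the channel dimension into h heads: (B, h, T, C/h)
        def split(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, T, self.h, self.head_dim).transpose(1, 2)

        Q, K, V = split(self.to_q(x)), split(self.to_k(x)), split(self.to_v(x))

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)   # (B, h, T, T)
        if self.causal:
            mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            scores = scores.masked_fill(~mask, float("-inf"))
        W = F.softmax(scores, dim=-1)
        Y = W @ V                                                     # (B, h, T, C/h)

        # concatenate heads back into C channels, then apply the output projection
        Y = Y.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(Y)
```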