Self-attention Mechanism
Sources:
- Transformer from scratch
- Attention Is All You Need (Vaswani et al., 2017)
Notation
Symbol | Type | Explanation |
---|---|---|
\(x_i\) | \(\in \mathbb{R}^{C}\) | Input vector at position \(i\) in the sequence |
\(y_i\) | \(\in \mathbb{R}^{C}\) | Output vector at position \(i\) in the sequence |
\(T\) | \(\in \mathbb{N}\) | Length of the input and output sequence |
\(C\) | \(\in \mathbb{N}\) | Dimension of the input and output vectors |
\(w_{ij}\) | \(\in \mathbb{R}\) | Weight indicating the influence of input vector \(x_j\) on \(y_i\) |
\(w_{ij}'\) | \(\in \mathbb{R}\) | Unnormalized attention score between \(x_i\) and \(x_j\) |
\(w\) | \(\in \mathbb{R}^{T \times T}\) | Weight matrix for the entire sequence |
\(x\) | \(\in \mathbb{R}^{T \times C}\) | Matrix of all input vectors |
\(y\) | \(\in \mathbb{R}^{T \times C}\) | Matrix of all output vectors |
\(W_q\) (single-head) | \(\in \mathbb{R}^{C \times C}\) | Weight matrix for transforming input vectors into queries in single-head attention |
\(W_k\) (single-head) | \(\in \mathbb{R}^{C \times C}\) | Weight matrix for transforming input vectors into keys in single-head attention |
\(W_v\) (single-head) | \(\in \mathbb{R}^{C \times C}\) | Weight matrix for transforming input vectors into values in single-head attention |
\(q_i\) (single-head) | \(\in \mathbb{R}^{C}\) | Query vector for the input vector \(x_i\) in single-head attention |
\(k_i\) (single-head) | \(\in \mathbb{R}^{C}\) | Key vector for the input vector \(x_i\) in single-head attention |
\(v_i\) (single-head) | \(\in \mathbb{R}^{C}\) | Value vector for the input vector \(x_i\) in single-head attention |
\(W_q\) (multi-head) | \(\in \mathbb{R}^{(C/h) \times C}\) | Weight matrix for transforming input vectors into queries in multi-head attention |
\(W_k\) (multi-head) | \(\in \mathbb{R}^{(C/h) \times C}\) | Weight matrix for transforming input vectors into keys in multi-head attention |
\(W_v\) (multi-head) | \(\in \mathbb{R}^{(C/h) \times C}\) | Weight matrix for transforming input vectors into values in multi-head attention |
\(q_i\) (multi-head) | \(\in \mathbb{R}^{C/h}\) | Query vector for the input vector \(x_i\) in multi-head attention |
\(k_i\) (multi-head) | \(\in \mathbb{R}^{C/h}\) | Key vector for the input vector \(x_i\) in multi-head attention |
\(v_i\) (multi-head) | \(\in \mathbb{R}^{C/h}\) | Value vector for the input vector \(x_i\) in multi-head attention |
\(h\) | \(\in \mathbb{N}\) | Number of heads in multi-head attention |
\(\operatorname{softmax}\) | Function | The softmax function |
\(\exp\) | Function | Exponential function used in softmax calculation |
\(-\infty\) | Special value | Used to mask future elements in the causality mask |
Self-Attention
Self-attention is a sequence-to-sequence operation that transforms an input sequence of vectors \(\{x_{1}, x_{2}, \dots, x_{T}\}\) into an output sequence \(\{y_1, y_2, \cdots, y_T\}\), where each vector has dimension \(C\). To compute the output vector \(y_{\textcolor{red}{i}}\), self-attention forms a weighted average of the input vectors:
\[ y_{\textcolor{red}{i}} = \sum_{\textcolor{green}{j}} w_{\textcolor{red}{i}\textcolor{green}{j}} x_{\textcolor{green}{j}}, \]
where \(\textcolor{green}{j}\) indexes the entire sequence, and the weights \(w_{\textcolor{red}{i}\textcolor{green}{j}}\) sum to one over all \(\textcolor{green}{j}\). These weights are derived from the input vectors \(x_{\textcolor{red}{i}}\) and \(x_{\textcolor{green}{j}}\) using a dot product:
\[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = x_{\textcolor{red}{i}}^{\top} x_{\textcolor{green}{j}}. \]
The resulting scores are then normalized using the softmax function to form the final attention weights:
\[ w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{\textcolor{green}{j}} \exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}. \]
Figure: a visual illustration of basic self-attention (the softmax operation over the weights is not shown).
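As a concrete reference, the per-position formulas above translate directly into a loop. The following is a minimal sketch in PyTorch; the function name and the convention of stacking the \(T\) input vectors as rows of a \((T, C)\) tensor are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def basic_self_attention_loop(x):
    """Basic self-attention, computed one output position at a time.

    x: tensor of shape (T, C) holding the input vectors x_1..x_T as rows.
    Returns y of shape (T, C).
    """
    T, C = x.shape
    y = torch.zeros_like(x)
    for i in range(T):
        # Unnormalized scores w'_ij = x_i^T x_j for every position j.
        scores = x @ x[i]                  # shape (T,)
        # Softmax over j turns the scores into weights that sum to one.
        weights = F.softmax(scores, dim=0)
        # y_i is the weighted average of all input vectors.
        y[i] = weights @ x                 # shape (C,)
    return y
```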
In matrix form, the self-attention operation is expressed as: \[ y = w x, \]
where
\[ x = \begin{bmatrix} x_1^{\top} \\ \vdots \\ x_T^{\top} \end{bmatrix} \in \mathbb{R}^{T \times C}, \quad y = \begin{bmatrix} y_1^{\top} \\ \vdots \\ y_T^{\top} \end{bmatrix} \in \mathbb{R}^{T \times C}, \quad w = \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1T} \\ w_{21} & w_{22} & \cdots & w_{2T} \\ \vdots & \vdots & \ddots & \vdots \\ w_{T1} & w_{T2} & \cdots & w_{TT} \end{bmatrix} \in \mathbb{R}^{T \times T}. \]
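In code, the matrix form is nearly a one-liner. A sketch under the same assumptions, with the softmax applied row-wise so that each row of \(w\) sums to one:

```python
import torch
import torch.nn.functional as F

def basic_self_attention(x):
    """Vectorized basic self-attention. x: (T, C) -> y: (T, C)."""
    w_prime = x @ x.transpose(-2, -1)   # raw scores, shape (T, T)
    w = F.softmax(w_prime, dim=-1)      # normalize each row over j
    return w @ x                        # y = w x, shape (T, C)
```

For a random `x = torch.randn(T, C)`, this agrees with the loop version above up to floating-point tolerance (`torch.allclose`).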
Advanced tricks
Queries, keys, and values
Each input vector \(x_{\textcolor{red}{i}}\) in self-attention is used in three ways:
- to determine weights for its own output (query),
- to contribute to weights for others' outputs (key), and
- to participate in the weighted sum (value).
To facilitate these roles, each \(x_{\textcolor{red}{i}}\) is linearly transformed into distinct vectors using learnable matrices \(W_q\), \(W_k\), and \(W_v\): \[ q_{\textcolor{red}{i}} = W_q x_{\textcolor{red}{i}}, \quad k_{\textcolor{red}{i}} = W_k x_{\textcolor{red}{i}}, \quad v_{\textcolor{red}{i}} = W_v x_{\textcolor{red}{i}}. \]
The attention weights are then computed as:
\[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = q_{\textcolor{red}{i}}^{\top} k_{\textcolor{green}{j}}, \quad w_{\textcolor{red}{i}\textcolor{green}{j}} = \operatorname{softmax}(w_{\textcolor{red}{i}\textcolor{green}{j}}'), \]
and the output vector \(y_{\textcolor{red}{i}}\) is:
\[ y_{\textcolor{red}{i}} = \sum_{\textcolor{green}{j}} w_{\textcolor{red}{i}\textcolor{green}{j}} v_{\textcolor{green}{j}}. \]
This mechanism introduces learnable parameters into the self-attention layer, enabling it to adapt the input vectors to each of the three roles.
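A sketch of these projections and the resulting attention computation in PyTorch, assuming bias-free `nn.Linear` layers for the \(C \times C\) matrices \(W_q\), \(W_k\), \(W_v\) of the single-head setup; the dimensions and variable names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

C = 64                                  # assumed embedding dimension
W_q = nn.Linear(C, C, bias=False)       # q_i = W_q x_i
W_k = nn.Linear(C, C, bias=False)       # k_i = W_k x_i
W_v = nn.Linear(C, C, bias=False)       # v_i = W_v x_i

x = torch.randn(8, C)                   # T = 8 input vectors as rows
q, k, v = W_q(x), W_k(x), W_v(x)        # each of shape (T, C)

w_prime = q @ k.transpose(-2, -1)       # w'_ij = q_i^T k_j, shape (T, T)
w = F.softmax(w_prime, dim=-1)          # normalize over j
y = w @ v                               # y_i = sum_j w_ij v_j, shape (T, C)
```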
Multi-head attention
Rather than applying a single attention function with \(C\)-dimensional keys, values, and queries, the model applies this function \(h\) times independently on different subspaces of dimension \(\frac{C}{h}\). The outputs are then concatenated and linearly transformed to form the final output, allowing the model to capture diverse aspects of the input.
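For example, with \(C = 512\) and \(h = 8\), each head projects the inputs down to queries, keys, and values of dimension \(C/h = 64\), and concatenating the \(h\) head outputs restores a \(512\)-dimensional vector before the final linear transformation. A small shape sketch in PyTorch; the per-head outputs are random placeholders, since only the bookkeeping is being illustrated.

```python
import torch
import torch.nn as nn

T, C, h = 10, 512, 8                     # sequence length, model dim, heads
x = torch.randn(T, C)

# One head: a (C/h) x C projection, as in the notation table.
W_q = nn.Linear(C, C // h, bias=False)
q = W_q(x)                               # shape (T, C/h) = (10, 64)

# h head outputs of shape (T, C/h), concatenated along the channel dimension.
head_outputs = [torch.randn(T, C // h) for _ in range(h)]  # placeholders
concat = torch.cat(head_outputs, dim=-1)                    # shape (T, C)

# Final linear transformation mixes information across the heads.
W_o = nn.Linear(C, C, bias=False)
y = W_o(concat)                          # shape (T, C) = (10, 512)
```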
Causality mask for autoregressive models
In autoregressive models, where predictions depend solely on preceding elements, a causality mask is applied to prevent attention to future elements.
Figure: masking self-attention so that each element can only attend to input elements that precede it in the sequence. The multiplication symbol is slightly misleading: the masked-out elements (the white squares) are actually set to \(-\infty\).
The mask modifies the dot product scores:
\[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = \begin{cases} q_{\textcolor{red}{i}}^{\top} k_{\textcolor{green}{j}}, & \text{if } \textcolor{green}{j} \leq \textcolor{red}{i}, \\ -\infty, & \text{if } \textcolor{green}{j} > \textcolor{red}{i}. \end{cases} \]
During the softmax normalization, the denominator then sums only over positions up to \(\textcolor{red}{i}\):
\[ w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{\textcolor{green}{k}=1}^{\textcolor{red}{i}} \exp(w_{\textcolor{red}{i}\textcolor{green}{k}}')}. \]
Since \(\exp(-\infty) = 0\), future elements (\(\textcolor{green}{j} > \textcolor{red}{i}\)) receive zero weight, which maintains the autoregressive property and ensures the model does not "look ahead" in the sequence.
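In code, this mask is typically built from a lower-triangular matrix and applied by overwriting the scores above the diagonal with \(-\infty\) before the softmax. A minimal sketch in PyTorch, with \(q\), \(k\), \(v\) as random stand-ins:

```python
import torch
import torch.nn.functional as F

T, C = 8, 64
q, k, v = torch.randn(T, C), torch.randn(T, C), torch.randn(T, C)

w_prime = q @ k.transpose(-2, -1)                  # raw scores, shape (T, T)

# Lower-triangular mask: entry (i, j) is True iff j <= i.
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))

# Set scores for future positions (j > i) to -inf before the softmax.
w_prime = w_prime.masked_fill(~causal_mask, float("-inf"))

w = F.softmax(w_prime, dim=-1)                     # rows still sum to one
y = w @ v                                          # each y_i uses only j <= i
```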
Code
Single self-attention head
```python
class AttentionHead(nn.Module):
    ...
```
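A minimal sketch of how the `AttentionHead` class above could be filled in, assuming PyTorch, the query/key/value formulation and causality mask from this note, and bias-free `nn.Linear` projections; the constructor arguments (`embed_dim`, `head_dim`, `causal`) are illustrative choices rather than a definitive implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """A single self-attention head with an optional causality mask."""

    def __init__(self, embed_dim, head_dim, causal=False):
        super().__init__()
        # Learnable projections W_q, W_k, W_v (bias-free, as in the formulas).
        self.W_q = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, head_dim, bias=False)
        self.causal = causal

    def forward(self, x):
        # x: (T, C) -> q, k, v: (T, head_dim)
        T, _ = x.shape
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)

        # Raw scores w'_ij = q_i^T k_j. (The original paper additionally
        # scales the scores by 1/sqrt(head_dim); that scaling is not part
        # of the formulas above and is omitted here.)
        scores = q @ k.transpose(-2, -1)             # (T, T)

        if self.causal:
            # Mask out future positions (j > i) with -inf.
            mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            scores = scores.masked_fill(~mask, float("-inf"))

        weights = F.softmax(scores, dim=-1)          # rows sum to one
        return weights @ v                           # (T, head_dim)
```

With `embed_dim=512`, `head_dim=64`, and `causal=True`, a `(10, 512)` input yields a `(10, 64)` output: one \(C/h\)-dimensional head as described above.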
Multi-head attention block
```python
class MultiHeadAttention(nn.Module):
    ...
```
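A minimal sketch of how the `MultiHeadAttention` block could be assembled from \(h\) copies of the `AttentionHead` sketch above, concatenating their outputs and applying a final linear transformation as described in the multi-head attention section; the `nn.ModuleList` structure and names are assumptions for illustration.

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """h independent attention heads, concatenated and linearly mixed."""

    def __init__(self, embed_dim, num_heads, causal=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "C must be divisible by h"
        head_dim = embed_dim // num_heads
        # h heads, each working in a C/h-dimensional subspace.
        # Reuses the AttentionHead sketch from the previous listing.
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim, causal=causal)
             for _ in range(num_heads)]
        )
        # Final linear transformation applied to the concatenated heads.
        self.W_o = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        # Each head maps (T, C) -> (T, C/h); concatenation restores (T, C).
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.W_o(out)
```

`MultiHeadAttention(embed_dim=512, num_heads=8, causal=True)` then maps a `(T, 512)` input to a `(T, 512)` output.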