Attention Mechanism
Notation
| Symbol | Type | Explanation |
|---|---|---|
| $\mathbf{x}_i$ | Input vector | Vector at position $i$ of the input sequence |
| $\mathbf{c}_j$ | Context vector | Vector at position $j$ of the context sequence |
| $\mathbf{y}_i$ | Output vector | Vector at position $i$ of the output sequence |
| $n$ | Sequence length | Length of the input and output sequences |
| $m$ | Sequence length | Length of the context sequence |
| $d$ | Vector dimension | Dimension of all vectors |
| $\mathbf{q}_i$ | Query vector | Query derived from input vector $\mathbf{x}_i$ |
| $\mathbf{k}_j$ | Key vector | Key derived from context vector $\mathbf{c}_j$ |
| $\mathbf{v}_j$ | Value vector | Value derived from context vector $\mathbf{c}_j$ |
| $a_{ij}$ | Attention weight | Normalized weight indicating influence of $\mathbf{v}_j$ on $\mathbf{y}_i$ |
| $s_{ij}$ | Attention score | Unnormalized attention score between $\mathbf{q}_i$ and $\mathbf{k}_j$ |
| $A$ | Attention matrix | Matrix of all attention weights |
| $Q$ | Query matrix | Matrix of all query vectors (rows are $\mathbf{q}_i^\top$) |
| $K$ | Key matrix | Matrix of all key vectors (rows are $\mathbf{k}_j^\top$) |
| $V$ | Value matrix | Matrix of all value vectors (rows are $\mathbf{v}_j^\top$) |
| $Y$ | Output matrix | Matrix of all output vectors (rows are $\mathbf{y}_i^\top$) |
| $W_Q$ | Projection matrix | Weight matrix for generating queries from input |
| $W_K$ | Projection matrix | Weight matrix for generating keys from context |
| $W_V$ | Projection matrix | Weight matrix for generating values from context |
| $\mathrm{softmax}$ | Function | Softmax normalization function |
| $\exp$ | Function | Exponential function used in softmax calculation |
| $-\infty$ | Special value | Used to mask future elements in the causality mask |
Attention
Attention is a sequence-to-sequence operation that allows a sequence of input vectors $(\mathbf{x}_1, \dots, \mathbf{x}_n)$ to selectively gather information from a context sequence $(\mathbf{c}_1, \dots, \mathbf{c}_m)$, producing a sequence of output vectors $(\mathbf{y}_1, \dots, \mathbf{y}_n)$.

To implement this, we derive three kinds of vectors from the input and the context sequences:
- Query $\mathbf{q}_i$: derived from $\mathbf{x}_i$, represents "what we're looking for"
- Key $\mathbf{k}_j$: derived from $\mathbf{c}_j$, represents "what this context element offers"
- Value $\mathbf{v}_j$: derived from $\mathbf{c}_j$, represents "the actual content to retrieve"
The key-value pairs $(\mathbf{k}_j, \mathbf{v}_j)$ behave like a soft dictionary: each query is compared against all keys, and the corresponding values are blended according to how well they match.
To compute the output vector $\mathbf{y}_i$, we take a weighted average over all value vectors:

$$\mathbf{y}_i = \sum_{j=1}^{m} a_{ij} \mathbf{v}_j$$

These weights, also called attention weights, are computed in two steps. First, compute the dot product of the query vector $\mathbf{q}_i$ and the key vectors $\mathbf{k}_j$:

$$s_{ij} = \mathbf{q}_i^\top \mathbf{k}_j$$

The resulting scores are then normalized using the softmax function to form the final attention weights:

$$a_{ij} = \frac{\exp(s_{ij})}{\sum_{j'=1}^{m} \exp(s_{ij'})}$$
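As a quick numeric sanity check, the two steps can be computed directly with NumPy (the query, keys, and values below are made up for illustration):

```python
import numpy as np

# Hypothetical query for one position, and keys/values for a context of length m = 3
q_i = np.array([1.0, 0.0])
K = np.array([[1.0, 0.0],   # k_1
              [0.0, 1.0],   # k_2
              [1.0, 1.0]])  # k_3
V = np.array([[1.0, 2.0],   # v_1
              [3.0, 4.0],   # v_2
              [5.0, 6.0]])  # v_3

# Step 1: unnormalized scores s_ij = q_i . k_j
s = K @ q_i

# Step 2: softmax normalization to attention weights a_ij
a = np.exp(s) / np.exp(s).sum()

# Output: weighted average of the value vectors
y_i = a @ V
```

Here $\mathbf{k}_1$ and $\mathbf{k}_3$ match the query equally well and $\mathbf{k}_2$ less so, so $\mathbf{v}_1$ and $\mathbf{v}_3$ receive equal, larger weights.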
Variants:

- Cross-attention: when $\mathbf{x}$ and $\mathbf{c}$ are different sequences (the general case above)
- Self-attention: when $\mathbf{c} = \mathbf{x}$ and therefore $m = n$ (a sequence attending to itself)
Below is a visual illustration of self-attention. Note that the softmax operation over the weights is not illustrated:
The matrix form

The attention equation can be written compactly in matrix form. Let:

- $Q \in \mathbb{R}^{n \times d}$ be the matrix of query vectors (each row is $\mathbf{q}_i^\top$)
- $K \in \mathbb{R}^{m \times d}$ be the matrix of key vectors (each row is $\mathbf{k}_j^\top$)
- $V \in \mathbb{R}^{m \times d}$ be the matrix of value vectors (each row is $\mathbf{v}_j^\top$)

Then the attention weights matrix and the output are:

$$A = \mathrm{softmax}(QK^\top), \qquad Y = AV$$

where the softmax is applied row-wise, and:

- $Y \in \mathbb{R}^{n \times d}$ (output matrix, rows are output vectors)
- $A \in \mathbb{R}^{n \times m}$ (attention weights matrix, rows correspond to query vectors, columns to key vectors)
- $V \in \mathbb{R}^{m \times d}$ (value matrix, rows are value vectors)
Note that all vectors are stored as row vectors in the matrices.
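In code, the matrix form is just two matrix multiplications and a row-wise softmax; the sizes below are arbitrary but follow the shapes above:

```python
import numpy as np

def softmax(s, axis=-1):
    # Numerically stabilized row-wise softmax
    e = np.exp(s - s.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n, m, d = 4, 5, 8
rng = np.random.default_rng(0)
Q = rng.normal(size=(n, d))  # rows are q_i^T
K = rng.normal(size=(m, d))  # rows are k_j^T
V = rng.normal(size=(m, d))  # rows are v_j^T

A = softmax(Q @ K.T)         # (n, m): each row sums to 1
Y = A @ V                    # (n, d): each output is a weighted average of rows of V
```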
Why does the notation change?
The vector/matrix multiplication notation differs between the vector form and the matrix form:

- Vector case: $\mathbf{q}_i$ and $\mathbf{k}_j$ are column vectors (common in mathematics) → we need $\mathbf{q}_i^\top \mathbf{k}_j$ for the dot product
- Matrix case: $Q$ and $K$ store the vectors as row vectors (common in the deep learning community) → we need $QK^\top$ for the dot products

Simply transposing all matrices in the matrix formulation recovers the column-vector convention, so the two representations are coherent.
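A quick check that the two conventions agree: entry $(i, j)$ of $QK^\top$ equals the dot product of the $i$-th query row and the $j$-th key row.

```python
import numpy as np

rng = np.random.default_rng(1)
Q = rng.normal(size=(3, 4))  # rows are q_i^T
K = rng.normal(size=(5, 4))  # rows are k_j^T

S = Q @ K.T                  # (3, 5) matrix of scores s_ij
for i in range(3):
    for j in range(5):
        # s_ij = q_i . k_j, matching the vector formulation
        assert np.isclose(S[i, j], Q[i] @ K[j])
```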
How to generate queries, keys and values
How do we generate the queries $\mathbf{q}_i$, keys $\mathbf{k}_j$, and values $\mathbf{v}_j$? Through learned linear projections:

- Cross-attention: $\mathbf{q}_i = W_Q \mathbf{x}_i$, $\mathbf{k}_j = W_K \mathbf{c}_j$, $\mathbf{v}_j = W_V \mathbf{c}_j$
- Self-attention (when $\mathbf{c} = \mathbf{x}$): $\mathbf{q}_i = W_Q \mathbf{x}_i$, $\mathbf{k}_i = W_K \mathbf{x}_i$, $\mathbf{v}_i = W_V \mathbf{x}_i$

where all vectors are column vectors.

In matrix form, with $X$ and $C$ storing the input and context vectors as rows:

$$Q = X W_Q^\top, \qquad K = C W_K^\top, \qquad V = C W_V^\top$$
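A sketch of the projections in the row-vector convention (sequence lengths and dimensions are illustrative):

```python
import numpy as np

n, m, d = 4, 6, 8
rng = np.random.default_rng(2)
X = rng.normal(size=(n, d))    # input sequence, rows are x_i^T
C = rng.normal(size=(m, d))    # context sequence, rows are c_j^T
W_Q = rng.normal(size=(d, d))  # learned in practice; random here
W_K = rng.normal(size=(d, d))
W_V = rng.normal(size=(d, d))

# Row-vector form of q_i = W_Q x_i, k_j = W_K c_j, v_j = W_V c_j
Q = X @ W_Q.T
K = C @ W_K.T
V = C @ W_V.T
```

For self-attention, replace `C` with `X` and the three projections all act on the same sequence.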
Note:
- You may wonder why we prefer linear transformations over non-linear projections (e.g., MLPs with ReLU). I personally believe this is a pure engineering choice; in practice, simple linear transformations work well enough.
- The attention computation itself contains no trainable parameters—it's just matrix multiplications and softmax. All learnable parameters are in the $W_Q$, $W_K$, and $W_V$ matrices that generate the queries, keys, and values.
Multi-head attention
Rather than applying a single attention function with $d$-dimensional queries, keys, and values, multi-head attention runs $h$ attention "heads" in parallel, each with its own projection matrices operating in a lower-dimensional subspace of size $d / h$. The outputs of all heads are concatenated and linearly projected back to dimension $d$, allowing the model to attend to different aspects of the context simultaneously.
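The split-and-concatenate bookkeeping can be sketched as follows (head count and dimensions are illustrative, and the final output projection is omitted for brevity):

```python
import numpy as np

n, d, h = 4, 8, 2          # sequence length, model dimension, number of heads
d_head = d // h            # each head works in a d/h-dimensional subspace
rng = np.random.default_rng(3)
X = rng.normal(size=(n, d))

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

outputs = []
for _ in range(h):
    # Per-head projection matrices map d -> d_head
    W_Q, W_K, W_V = (rng.normal(size=(d_head, d)) for _ in range(3))
    Q, K, V = X @ W_Q.T, X @ W_K.T, X @ W_V.T   # each (n, d_head)
    outputs.append(softmax(Q @ K.T) @ V)        # one head's output, (n, d_head)

# Concatenating the h heads restores the model dimension d
Y = np.concatenate(outputs, axis=-1)
```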
Causality mask for autoregressive models
Autoregressive models, e.g., GPT-based models, use the self-attention mechanism and typically require a causality mask to ensure predictions depend only on preceding elements, i.e., the output $\mathbf{y}_i$ may depend only on the inputs $\mathbf{x}_1, \dots, \mathbf{x}_i$.
To enforce this constraint, a causality mask prevents attention to future positions: $a_{ij} = 0$ for all $j > i$, making the attention weights matrix $A$ lower triangular.
The causality mask is implemented by modifying the dot-product scores before the softmax:

$$s_{ij} \leftarrow -\infty \quad \text{for } j > i$$
After applying softmax normalization, future positions ($j > i$) receive attention weight zero, since $\exp(-\infty) = 0$.
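A small sketch of the mask applied to a self-attention score matrix:

```python
import numpy as np

n = 4
rng = np.random.default_rng(4)
S = rng.normal(size=(n, n))                  # self-attention scores s_ij

# Strictly upper-triangular positions are the "future" (j > i)
mask = np.triu(np.ones((n, n), dtype=bool), k=1)
S[mask] = -np.inf                            # s_ij <- -inf for j > i

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

A = softmax(S)
# Future positions get exactly zero weight, so A is lower triangular
```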
Note: Cross-attention typically does not require causal masking since the decoder attends to a fully available context (e.g., encoder outputs in translation, image features in captioning).
Code
Single self-attention head
A minimal PyTorch implementation following the unscaled formulation above (the `d_head` argument lets the same class serve as one head of the multi-head block below):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """A single attention head: Y = softmax(Q K^T) V."""

    def __init__(self, d: int, d_head: int, causal: bool = False):
        super().__init__()
        self.w_q = nn.Linear(d, d_head, bias=False)  # W_Q
        self.w_k = nn.Linear(d, d_head, bias=False)  # W_K
        self.w_v = nn.Linear(d, d_head, bias=False)  # W_V
        self.causal = causal

    def forward(self, x, c=None):
        c = x if c is None else c            # self-attention when no context is given
        q, k, v = self.w_q(x), self.w_k(c), self.w_v(c)
        scores = q @ k.transpose(-2, -1)     # s_ij = q_i . k_j
        if self.causal:
            n, m = scores.shape[-2:]
            mask = torch.triu(
                torch.ones(n, m, dtype=torch.bool, device=scores.device), diagonal=1
            )
            scores = scores.masked_fill(mask, float("-inf"))
        a = F.softmax(scores, dim=-1)        # attention weights a_ij
        return a @ v                         # weighted average of the values
```
Multi-head attention block
Built on `AttentionHead` above; the heads run in parallel subspaces of size `d // num_heads`, and a final linear layer projects the concatenated result back to dimension `d`:

```python
class MultiHeadAttention(nn.Module):
    """h parallel attention heads, concatenated and projected back to d."""

    def __init__(self, d: int, num_heads: int, causal: bool = False):
        super().__init__()
        assert d % num_heads == 0, "d must be divisible by num_heads"
        d_head = d // num_heads
        self.heads = nn.ModuleList(
            AttentionHead(d, d_head, causal) for _ in range(num_heads)
        )
        self.w_o = nn.Linear(d, d, bias=False)  # output projection

    def forward(self, x, c=None):
        # Each head returns (n, d_head); concatenation restores (n, d)
        out = torch.cat([head(x, c) for head in self.heads], dim=-1)
        return self.w_o(out)
```