Attention Mechanism

Sources:

  1. Transformer from scratch
  2. Attention Is All You Need (2017)

Notation

| Symbol | Type | Explanation |
|---|---|---|
| \(x_i \in \mathbb{R}^{C}\) | Input vector | Vector at position \(i\) in the input sequence |
| \(z_j \in \mathbb{R}^{C}\) | Context vector | Vector at position \(j\) in the context sequence |
| \(y_i \in \mathbb{R}^{C}\) | Output vector | Vector at position \(i\) in the output sequence |
| \(T \in \mathbb{N}\) | Sequence length | Length of the input and output sequences |
| \(S \in \mathbb{N}\) | Sequence length | Length of the context sequence |
| \(C \in \mathbb{N}\) | Vector dimension | Dimension of all vectors |
| \(q_i \in \mathbb{R}^{C}\) | Query vector | Query derived from input vector \(x_i\) |
| \(k_j \in \mathbb{R}^{C}\) | Key vector | Key derived from context vector \(z_j\) |
| \(v_j \in \mathbb{R}^{C}\) | Value vector | Value derived from context vector \(z_j\) |
| \(w_{ij} \in \mathbb{R}\) | Attention weight | Normalized weight indicating the influence of \(v_j\) on \(y_i\) |
| \(w_{ij}' \in \mathbb{R}\) | Attention score | Unnormalized attention score between \(q_i\) and \(k_j\) |
| \(W \in \mathbb{R}^{T \times S}\) | Attention matrix | Matrix of all attention weights |
| \(Q \in \mathbb{R}^{T \times C}\) | Query matrix | Matrix of all query vectors (rows are \(q_i^{\top}\)) |
| \(K \in \mathbb{R}^{S \times C}\) | Key matrix | Matrix of all key vectors (rows are \(k_j^{\top}\)) |
| \(V \in \mathbb{R}^{S \times C}\) | Value matrix | Matrix of all value vectors (rows are \(v_j^{\top}\)) |
| \(Y \in \mathbb{R}^{T \times C}\) | Output matrix | Matrix of all output vectors (rows are \(y_i^{\top}\)) |
| \(W_q \in \mathbb{R}^{C \times C}\) | Projection matrix | Weight matrix for generating queries from the input |
| \(W_k \in \mathbb{R}^{C \times C}\) | Projection matrix | Weight matrix for generating keys from the context |
| \(W_v \in \mathbb{R}^{C \times C}\) | Projection matrix | Weight matrix for generating values from the context |
| \(\text{softmax}\) | Function | Softmax normalization function |
| \(\exp\) | Function | Exponential function used in the softmax calculation |
| \(-\infty\) | Special value | Used in the causality mask to block future elements |

Attention

Attention is a sequence-to-sequence operation that allows a sequence \(\{x_1, x_2, \dots, x_T\}\) to attend to another sequence \(\{z_1, z_2, \dots, z_S\}\) to produce an output sequence \(\{y_1, y_2, \dots, y_T\}\). The core idea is that each output \(y_i\) is computed by "looking at" all elements of \(\{z_1, z_2, \dots, z_S\}\) and forming a weighted combination based on how relevant each \(z_j\) is to \(x_i\).

To implement this, we represent each element \(x_i\) as a query \(q_i\) and the so-called context sequence \(\{z_j\}_{j=1}^S\) as key-value pairs \(\{(k_j, v_j)\}_{j=1}^S\):

  • Query \(q_i\): derived from \(x_i\), represents "what we're looking for"
  • Key \(k_j\): derived from \(z_j\), represents "what this context element offers"
  • Value \(v_j\): derived from \(z_j\), represents "the actual content to retrieve"

The key-value pairs \(\{(k_j, v_j)\}\) are simply a representation of the original context sequence \(\{z_j\}\).

To compute the output vector \(y_{\textcolor{red}{i}}\), attention forms a linear combination of the value vectors: \[ \begin{equation} \label{eq1} y_{\textcolor{red}{i}} = \sum_{\textcolor{green}{j}=1}^S w_{\textcolor{red}{i}\textcolor{green}{j}} v_{\textcolor{green}{j}}, \end{equation} \] where \(\textcolor{green}{j}\) indexes the key-value sequence of length \(S\), and the weights \(w_{\textcolor{red}{i}\textcolor{green}{j}}\) sum to one over all \(\textcolor{green}{j}\).

These weights, also called attention weights, are computed in two steps (a short code sketch follows the steps):

  1. First, compute the dot product of the query vector \(q_{\textcolor{red}{i}}\) and key vectors \(k_{\textcolor{green}{j}}\): \[ \begin{equation} \label{eq2} w_{\textcolor{red}{i}\textcolor{green}{j}}' = k_{\textcolor{green}{j}}^{\top} q_{\textcolor{red}{i}}. \end{equation} \]

  2. The resulting scores are then normalized with the softmax function to form the final attention weights: \[ \begin{equation} \label{eq3} w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{r=1}^{S} \exp(w_{\textcolor{red}{i}r}')}. \end{equation} \]
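As a concrete illustration, here is a minimal PyTorch sketch of \(\eqref{eq1}\)–\(\eqref{eq3}\) for a single query. The shapes follow the notation table; the sizes and tensors are arbitrary placeholders chosen for the example:

import torch

# Equations (1)-(3) for a single query q_i over S context vectors.
S, C = 5, 8
q_i = torch.randn(C)     # query derived from x_i
K = torch.randn(S, C)    # rows are k_j^T
V = torch.randn(S, C)    # rows are v_j^T

scores = K @ q_i                        # w'_ij = k_j^T q_i, shape [S]
weights = torch.softmax(scores, dim=0)  # w_ij, sums to one over j
y_i = weights @ V                       # y_i = sum_j w_ij v_j, shape [C]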

Variants:

  • Cross-attention: When \(\{x_i\}\) and \(\{z_j\}\) are different sequences (the general case above)
  • Self-attention: When \(\{x_i\} = \{z_j\}\), so the queries, keys, and values are all derived from the same sequence and \(T = S\) (a sequence attending to itself)

Below is a visual illustration of self-attention. Note that the softmax operation over the weights is not illustrated:

[Figure: self-attention]

The matrix form

The attention equation \(\eqref{eq1}\) can be expressed compactly in matrix form. Let:

  • \(Q \in \mathbb{R}^{T \times C}\) be the matrix of query vectors (each row is \(q_i^{\top}\))
  • \(K \in \mathbb{R}^{S \times C}\) be the matrix of key vectors (each row is \(k_j^{\top}\))
  • \(V \in \mathbb{R}^{S \times C}\) be the matrix of value vectors (each row is \(v_j^{\top}\))

Then the attention weights matrix \(W \in \mathbb{R}^{T \times S}\) and output matrix \(Y \in \mathbb{R}^{T \times C}\) are: \[ \begin{equation} \label{eq4} W = \text{softmax}(QK^{\top}), \end{equation} \] and \[ \begin{equation} \label{eq5} Y = WV \end{equation} \] where the softmax is applied row-wise to \(QK^{\top}\) and:

  • \(Y \in \mathbb{R}^{T \times C}\) (output matrix, \(T\) rows for \(T\) output vectors)
  • \(W \in \mathbb{R}^{T \times S}\) (attention weights matrix, \(T\) rows for \(T\) query vectors and \(S\) columns for \(S\) key vectors)
  • \(V \in \mathbb{R}^{S \times C}\) (value matrix, \(S\) rows for \(S\) value vectors)

Note that all vectors are stored as row vectors in the matrices.
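Below is a minimal sketch of \(\eqref{eq4}\) and \(\eqref{eq5}\) in PyTorch, using random placeholder tensors with the shapes above:

import torch

# Matrix form: the whole output sequence from two matrix products.
T, S, C = 4, 5, 8
Q = torch.randn(T, C)   # rows are q_i^T
K = torch.randn(S, C)   # rows are k_j^T
V = torch.randn(S, C)   # rows are v_j^T

W = torch.softmax(Q @ K.T, dim=-1)  # [T, S], softmax applied row-wise
Y = W @ V                           # [T, C]
assert torch.allclose(W.sum(dim=-1), torch.ones(T))  # each row of W sums to one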

Why does the notation change?

The vector/matrix multiplication notation differs between the vector \(\eqref{eq2}\) formulation and matrix \(\eqref{eq4}\) formulation because of different vector conventions:

  • Vector case: in the vector formulation, \(q_i\) and \(k_j\) are column vectors (the common convention in mathematics) → we need \(k_{\textcolor{green}{j}}^{\top} q_{\textcolor{red}{i}}\) for the dot product
  • Matrix case: \(Q\) and \(K\) store the vectors as row vectors (the common convention in the deep learning community) → we need \(QK^{\top}\) for the dot products.

We could simply transpose all matrices in the matrix formulation to obtain a consistent column-vector representation.
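A quick sanity check of the two conventions, using random placeholder tensors: with row vectors stored in \(Q\) and \(K\), entry \((i, j)\) of \(QK^{\top}\) equals the dot product \(k_{j}^{\top} q_{i}\):

import torch

# Row-vector and column-vector conventions give the same dot products.
T, S, C = 3, 4, 8
Q, K = torch.randn(T, C), torch.randn(S, C)
i, j = 1, 2
assert torch.allclose((Q @ K.T)[i, j], torch.dot(K[j], Q[i]))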

How to generate queries, keys and values

How do we generate the queries \(\{q_i\}_{i=1}^T\) from the input sequence \(\{x_i\}_{i=1}^T\), and the keys/values \(\{(k_j, v_j)\}_{j=1}^S\) from the context sequence \(\{z_j\}_{j=1}^S\)? The answer is surprisingly simple: linear transformations with learnable matrices:

  • Cross-attention: \(q_{\textcolor{red}{i}} = W_q x_{\textcolor{red}{i}}\), \(k_{\textcolor{green}{j}} = W_k z_{\textcolor{green}{j}}\), \(v_{\textcolor{green}{j}} = W_v z_{\textcolor{green}{j}}\)
  • Self-attention (when \(x_i = z_i\)): \(q_{\textcolor{red}{i}} = W_q x_{\textcolor{red}{i}}\), \(k_{\textcolor{red}{i}} = W_k x_{\textcolor{red}{i}}\), \(v_{\textcolor{red}{i}} = W_v x_{\textcolor{red}{i}}\)

where all vectors are column vectors.

In matrix form: \[ \begin{aligned} & Q = X W_q^{\top} \\ & K = Z W_k^{\top} \\ & V = Z W_v^{\top} \end{aligned} \] where all vectors are stored as row vectors in the matrices, and \(Z = X\) in the self-attention case.
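Here is a minimal sketch of these projections using nn.Linear (whose weight is applied as \(x W^{\top}\)). The shapes follow the notation above; the tensors are random placeholders:

import torch
import torch.nn as nn

# Q = X W_q^T, K = Z W_k^T, V = Z W_v^T; self-attention is the case Z = X.
T, S, C = 4, 5, 8
X = torch.randn(T, C)   # input sequence, rows are x_i^T
Z = torch.randn(S, C)   # context sequence, rows are z_j^T

to_q = nn.Linear(C, C, bias=False)  # its weight plays the role of W_q
to_k = nn.Linear(C, C, bias=False)  # W_k
to_v = nn.Linear(C, C, bias=False)  # W_v

Q, K, V = to_q(X), to_k(Z), to_v(Z)  # [T, C], [S, C], [S, C]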

Note:

  1. You may wonder why we prefer linear transformations over non-linear projections (e.g., MLPs with ReLU). I personally believe this is largely an engineering choice; in practice, simple linear transformations work well enough.
  2. The attention computation itself \(\eqref{eq5}\) contains no trainable parameters—it's just matrix multiplications and softmax. All learnable parameters are in the \(W_q\), \(W_k\), and \(W_v\) matrices that generate the queries, keys, and values.

Multi-head attention

Rather than applying a single attention function with \(C\)-dimensional keys, values, and queries, the model applies this function \(h\) times independently on different subspaces of dimension \(\frac{C}{h}\). The outputs are then concatenated and linearly transformed to form the final output, allowing the model to capture diverse aspects of the input.
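As a rough illustration of the head split only (the per-head projection matrices and the final linear transform on the concatenation are omitted), assuming \(h\) divides \(C\) and using random placeholder tensors:

import torch

# h independent attentions on C/h-dimensional subspaces, then concatenation.
T, C, h = 4, 8, 2
Q = torch.randn(T, C).view(T, h, C // h).transpose(0, 1)  # [h, T, C/h]
K = torch.randn(T, C).view(T, h, C // h).transpose(0, 1)
V = torch.randn(T, C).view(T, h, C // h).transpose(0, 1)

W = torch.softmax(Q @ K.transpose(-2, -1), dim=-1)  # [h, T, T], one weight matrix per head
heads = W @ V                                       # [h, T, C/h]
Y = heads.transpose(0, 1).reshape(T, C)             # concatenate the head outputs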

Causality mask for autoregressive models

Autoregressive models, e.g., GPT-based models, use the self-attention mechanism and typically require a causality mask to ensure predictions depend only on preceding elements, i.e., \(x_{\textcolor{red}{i}}\) can NOT attend to any \(z_{\textcolor{green}{j}}\) with \(\textcolor{green}{j} > \textcolor{red}{i}\).

To enforce this constraint, a causality mask prevents attention to future positions:

[Figure: masked self-attention]

This makes the attention weights matrix \(W\) in \(\eqref{eq4}\) lower-triangular: \[ W = \begin{pmatrix} w_{11} & 0 & 0 & \cdots & 0 \\ w_{21} & w_{22} & 0 & \cdots & 0 \\ w_{31} & w_{32} & w_{33} & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ w_{T1} & w_{T2} & w_{T3} & \cdots & w_{TS} \end{pmatrix} \] where \(S = T\) since we are dealing with self-attention.

The causality mask is implemented by modifying the dot-product scores: \[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = \begin{cases} q_{\textcolor{red}{i}}^{\top} k_{\textcolor{green}{j}}, & \text{if } \textcolor{green}{j} \leq \textcolor{red}{i}, \\ -\infty, & \text{if } \textcolor{green}{j} > \textcolor{red}{i}. \end{cases} \]

After applying the softmax normalization, future positions (\(\textcolor{green}{j} > \textcolor{red}{i}\)) are effectively zeroed out since \(\exp(-\infty) = 0\): \[ w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{\textcolor{green}{k}=1}^{\textcolor{red}{i}} \exp(w_{\textcolor{red}{i}\textcolor{green}{k}}')}. \]
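Below is a minimal sketch of the masking step on the score matrix \(QK^{\top}\), using random placeholder tensors:

import torch

# Positions with j > i are set to -inf before the row-wise softmax,
# so they receive zero weight.
T, C = 4, 8
Q, K = torch.randn(T, C), torch.randn(T, C)

scores = Q @ K.T                                   # [T, T]
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~mask, float('-inf'))  # hide future positions
W = torch.softmax(scores, dim=-1)                  # lower-triangular weights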

Note: Cross-attention typically does not require causal masking since the decoder attends to a fully available context (e.g., encoder outputs in translation, image features in captioning).

Code

Single self-attention head

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """
    A single attention head for self-attention in transformer models.

    Attributes:
        head_size (int): Dimension of the output of the attention head.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        to_query, to_key, to_value (nn.Linear): Layers to compute query, key, and value vectors.
        tril (torch.Tensor): Lower triangular matrix for causal masking.
    """

    def __init__(self, head_dim, embd_dim, context_len):
        """
        Initializes the AttentionHead.

        Args:
            head_dim (int): Dimension of the output of the attention head.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()

        self.head_size, self.embd_dim, self.context_len = head_dim, embd_dim, context_len
        self.to_query = nn.Linear(embd_dim, head_dim)
        self.to_key = nn.Linear(embd_dim, head_dim)
        self.to_value = nn.Linear(embd_dim, head_dim)

        # Causal mask for future tokens
        self.register_buffer('tril', torch.tril(torch.ones(self.context_len, self.context_len)))

    def forward(self, x, causality_mask=False):
        """
        Applies single-head self-attention to the input.

        Args:
            x (torch.Tensor): Input of shape [B, T, C].
            causality_mask (bool): If True, applies a mask to prevent attending to future tokens.

        Returns:
            torch.Tensor: Output of shape [B, T, head_dim].
        """
        B, T, C = x.shape
        k = self.to_key(x)    # [B, T, head_dim]
        q = self.to_query(x)  # [B, T, head_dim]
        v = self.to_value(x)  # [B, T, head_dim]

        weight = q @ k.transpose(-2, -1)  # [B, T, T]

        if causality_mask:
            # Mask future tokens
            weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # [B, T, T]

        weight = F.softmax(weight, dim=-1)
        y = weight @ v  # [B, T, head_dim]
        return y
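
A minimal usage sketch (hyperparameters and input are arbitrary placeholders; reuses the imports above):

head = AttentionHead(head_dim=16, embd_dim=32, context_len=128)
x = torch.randn(2, 10, 32)        # [B, T, C]
y = head(x, causality_mask=True)  # [2, 10, 16]
print(y.shape)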

Multi-head attention block

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    This class implements multiple heads of self-attention in parallel,
    allowing the model to attend to information from different representation
    subspaces at different positions.

    Attributes:
        head_num (int): Number of attention heads.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        head_size (int): Dimension of each attention head.
        heads (nn.ModuleList): List of AttentionHead modules.
        proj (nn.Linear): Linear layer to combine the outputs from all heads.
    """

    def __init__(self, head_num, embd_dim, context_len):
        """
        Initializes the MultiHeadAttention.

        Args:
            head_num (int): Number of attention heads.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()
        self.head_num, self.embd_dim, self.context_len = head_num, embd_dim, context_len
        self.head_size = embd_dim // head_num
        self.heads = nn.ModuleList([AttentionHead(self.head_size, self.embd_dim, self.context_len) for _ in range(head_num)])
        self.proj = nn.Linear(self.embd_dim, self.embd_dim)  # Mix the data output by the heads back together.

    def forward(self, x):
        """
        Applies multi-head attention to the input.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, C].

        Returns:
            torch.Tensor: Output tensor of shape [B, T, embd_dim].
        """
        outs = torch.cat([head(x) for head in self.heads], dim=-1)  # Concatenate outputs from all heads
        x = self.proj(outs)  # Project back to the original embedding dimension
        return x
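
A minimal usage sketch (hyperparameters and input are arbitrary placeholders; reuses the imports above):

mha = MultiHeadAttention(head_num=4, embd_dim=32, context_len=128)
x = torch.randn(2, 10, 32)  # [B, T, C]
y = mha(x)                  # [2, 10, 32]
print(y.shape)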