Attention Mechanism

Sources:

  1. Transformer from scratch
  2. Attention Is All You Need (2017 paper)

Notation

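| Symbol | Type | Explanation |
|---|---|---|
| $x_i \in \mathbb{R}^C$ | Input vector | Vector at position $i$ in the input sequence |
| $z_j \in \mathbb{R}^C$ | Context vector | Vector at position $j$ in the context sequence |
| $y_i \in \mathbb{R}^C$ | Output vector | Vector at position $i$ in the output sequence |
| $T \in \mathbb{N}$ | Sequence length | Length of the input and output sequences |
| $S \in \mathbb{N}$ | Sequence length | Length of the context sequence |
| $C \in \mathbb{N}$ | Vector dimension | Dimension of all vectors |
| $q_i \in \mathbb{R}^C$ | Query vector | Query derived from input vector $x_i$ |
| $k_j \in \mathbb{R}^C$ | Key vector | Key derived from context vector $z_j$ |
| $v_j \in \mathbb{R}^C$ | Value vector | Value derived from context vector $z_j$ |
| $w_{ij} \in \mathbb{R}$ | Attention weight | Normalized weight indicating the influence of $v_j$ on $y_i$ |
| $w'_{ij} \in \mathbb{R}$ | Attention score | Unnormalized attention score between $q_i$ and $k_j$ |
| $W \in \mathbb{R}^{T\times S}$ | Attention matrix | Matrix of all attention weights |
| $Q \in \mathbb{R}^{T\times C}$ | Query matrix | Matrix of all query vectors (rows are $q_i$) |
| $K \in \mathbb{R}^{S\times C}$ | Key matrix | Matrix of all key vectors (rows are $k_j$) |
| $V \in \mathbb{R}^{S\times C}$ | Value matrix | Matrix of all value vectors (rows are $v_j$) |
| $Y \in \mathbb{R}^{T\times C}$ | Output matrix | Matrix of all output vectors (rows are $y_i$) |
| $W_q \in \mathbb{R}^{C\times C}$ | Projection matrix | Weight matrix for generating queries from the input |
| $W_k \in \mathbb{R}^{C\times C}$ | Projection matrix | Weight matrix for generating keys from the context |
| $W_v \in \mathbb{R}^{C\times C}$ | Projection matrix | Weight matrix for generating values from the context |
| $\operatorname{softmax}$ | Function | Softmax normalization function |
| $\exp$ | Function | Exponential function used in the softmax calculation |
| $-\infty$ | Special value | Used to mask future elements in the causality mask |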

Attention

Attention is a sequence-to-sequence operation that allows a sequence $x_1, x_2, \dots, x_T$ to attend to another sequence $z_1, z_2, \dots, z_S$ to produce an output sequence $y_1, y_2, \dots, y_T$. The core idea is that each output $y_i$ is computed by "looking at" all elements of the sequence $z_1, z_2, \dots, z_S$ and forming a weighted combination based on how relevant each $z_j$ is to $x_i$.

To implement this, we represent the so-called context sequence $\{z_j\}_{j=1}^{S}$ as key-value pairs $\{(k_j, v_j)\}_{j=1}^{S}$ and each element $x_i$ as a query $q_i$:

  • Query $q_i$: derived from $x_i$, represents "what we're looking for"
  • Key $k_j$: derived from $z_j$, represents "what this context element offers"
  • Value $v_j$: derived from $z_j$, represents "the actual content to retrieve"

The key-value pairs $(k_j, v_j)$ are simply a representation of the original context vectors $z_j$.

To compute the output vector $y_i$, attention forms a linear combination of the value vectors:

$$y_i = \sum_{j=1}^{S} w_{ij}\, v_j, \tag{1}$$

where $j$ indexes the key-value sequence of length $S$, and the weights $w_{ij}$ sum to one over all $j$.

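These weights, also called attention weights, are computed in two steps:

  1. First, compute the dot product of the query vector $q_i$ with each key vector $k_j$ to obtain the unnormalized attention scores:
     $$w'_{ij} = k_j^\top q_i. \tag{2}$$

  2. The scores are then normalized over $j$ with the softmax function to form the final attention weights:
     $$w_{ij} = \frac{\exp(w'_{ij})}{\sum_{j'=1}^{S} \exp(w'_{ij'})}. \tag{3}$$

Attention Is All You Need also scales the scores by $1/\sqrt{d_k}$ (here $d_k = C$) before the softmax. Without this factor, large dot products push the softmax into regions with very small gradients when $C$ is large. The equations here omit the factor for simplicity.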
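The sketch below uses NumPy to trace steps (2) and (3) and then the weighted sum (1) for a single query. The helper names `softmax` and `attend_one`, the toy sizes and the random data are all assumptions made for illustration. The loops mirror the sums in the equations and are not meant to be efficient.

```python
import numpy as np


def softmax(scores):
    """Softmax over a 1-D array of scores, as in (3)."""
    # Subtracting the max leaves the result unchanged but avoids overflow in exp.
    e = np.exp(scores - np.max(scores))
    return e / e.sum()


def attend_one(q_i, K, V):
    """Output y_i for a single query q_i, following (1)-(3).

    q_i has shape (C,); K and V have shape (S, C), with row j holding k_j / v_j.
    """
    scores = np.array([k_j @ q_i for k_j in K])             # (2): w'_ij = k_j^T q_i
    weights = softmax(scores)                               # (3): normalize over j
    y_i = sum(w_ij * v_j for w_ij, v_j in zip(weights, V))  # (1): weighted sum of values
    return y_i, weights


rng = np.random.default_rng(0)
C, S = 4, 3
q_i = rng.normal(size=C)
K = rng.normal(size=(S, C))
V = rng.normal(size=(S, C))

y_i, w = attend_one(q_i, K, V)
print(w.sum())    # approximately 1.0
print(y_i.shape)  # (4,)
```

Subtracting the maximum score before exponentiating is a standard numerical safeguard. It does not change the weights defined by (3).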

Variants:

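  • Cross-attention: when $x_i$ and $z_j$ come from different sequences (the general case above).
  • Self-attention: when the context sequence is the input sequence itself, i.e., $z_i = x_i$ for all $i$ and $S = T$ (a sequence attending to itself). The queries, keys and values are then all derived from the same vectors $x_i$. However, $q_i$, $k_i$ and $v_i$ generally still differ, because each is produced by its own projection.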

Below is a visual illustration of self-attention. Note that the softmax operation over the weights is not illustrated:

[Figure: self-attention]

The matrix form

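The attention equations (1) to (3) can be expressed compactly in matrix form. Let:

  • $Q \in \mathbb{R}^{T\times C}$ be the matrix of query vectors (each row is $q_i$)
  • $K \in \mathbb{R}^{S\times C}$ be the matrix of key vectors (each row is $k_j$)
  • $V \in \mathbb{R}^{S\times C}$ be the matrix of value vectors (each row is $v_j$)

Then the attention weight matrix $W \in \mathbb{R}^{T\times S}$ and the output matrix $Y \in \mathbb{R}^{T\times C}$ are

$$W = \operatorname{softmax}(QK^\top), \tag{4}$$

$$Y = WV, \tag{5}$$

where the softmax is applied row-wise to $QK^\top$, and:

  • $Y \in \mathbb{R}^{T\times C}$ (output matrix, $T$ rows for the $T$ output vectors)
  • $W \in \mathbb{R}^{T\times S}$ (attention weight matrix, $T$ rows for the $T$ query vectors and $S$ columns for the $S$ key vectors)
  • $V \in \mathbb{R}^{S\times C}$ (value matrix, $S$ rows for the $S$ value vectors)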

Note that all vectors are stored as row vectors in the matrices.
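A minimal NumPy sketch of the matrix form (4) and (5) follows. It also checks that each row of $Y$ matches the per-query formulas (1) to (3). The function names `softmax_rows` and `attention` and the toy sizes are hypothetical, chosen only for this illustration.

```python
import numpy as np


def softmax_rows(A):
    """Row-wise softmax: every row of the result sums to one."""
    A = A - A.max(axis=-1, keepdims=True)  # numerical stability only
    E = np.exp(A)
    return E / E.sum(axis=-1, keepdims=True)


def attention(Q, K, V):
    """Matrix-form attention following (4) and (5).

    Q: (T, C); K, V: (S, C). Rows are the individual vectors.
    Returns Y with shape (T, C) and W with shape (T, S).
    """
    W = softmax_rows(Q @ K.T)  # (4): W = softmax(Q K^T), shape (T, S)
    Y = W @ V                  # (5): Y = W V, shape (T, C)
    return Y, W


rng = np.random.default_rng(0)
T, S, C = 2, 5, 4
Q = rng.normal(size=(T, C))
K = rng.normal(size=(S, C))
V = rng.normal(size=(S, C))

Y, W = attention(Q, K, V)
print(Y.shape, W.shape)  # (2, 4) (2, 5)

# Row i of Y matches the per-query computation (1)-(3) for q_i.
for i in range(T):
    scores = K @ Q[i]                            # w'_ij = k_j^T q_i for all j
    w_i = np.exp(scores) / np.exp(scores).sum()  # (3)
    assert np.allclose(Y[i], w_i @ V)            # (1)
```

Here $T \ne S$ on purpose, to show that $W$ has one row per query and one column per key.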

Why does the notation change?

The multiplication order differs between the vector formulation (2) and the matrix formulation (4) because the two use different vector conventions:

  • Vector case: $q_i$ and $k_j$ are column vectors (common in mathematics), so the dot product is written $k_j^\top q_i$.
  • Matrix case: $Q$ and $K$ store the vectors as rows (common in the deep learning community), so all dot products are obtained at once as $QK^\top$, whose $(i, j)$ entry is $k_j^\top q_i$.

Transposing the matrices in the matrix formulation gives an equivalent representation that follows the column-vector convention, e.g., $(QK^\top)^\top = KQ^\top$.
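A quick NumPy check of the two conventions is shown below. The toy sizes and random data are assumptions. It confirms that entry $(i, j)$ of $QK^\top$ is exactly $k_j^\top q_i$. It also confirms that transposing the matrix form only changes which axis indexes the queries.

```python
import numpy as np

rng = np.random.default_rng(1)
T, S, C = 3, 4, 2
Q = rng.normal(size=(T, C))  # row i holds q_i (row-vector convention)
K = rng.normal(size=(S, C))  # row j holds k_j

scores = Q @ K.T             # (T, S): all dot products at once, as in (4)

# Column-vector convention for one pair (i, j): q_i, k_j as (C, 1) matrices.
i, j = 2, 1
q_col = Q[i].reshape(C, 1)
k_col = K[j].reshape(C, 1)
score_ij = (k_col.T @ q_col).item()  # k_j^T q_i is a 1x1 matrix, as in (2)

print(np.isclose(scores[i, j], score_ij))  # True

# Transposing the matrix form: (Q K^T)^T = K Q^T, now with queries along columns.
print(np.allclose(scores.T, K @ Q.T))      # True
```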

How to generate queries, keys and values

How do we generate the queries $\{q_i\}_{i=1}^{T}$ from the input sequence $\{x_i\}_{i=1}^{T}$, and the key-value pairs $\{(k_j, v_j)\}_{j=1}^{S}$ from the context sequence $\{z_j\}_{j=1}^{S}$? The answer is surprisingly simple: linear transformations with learnable matrices.

  • Cross-attention: $q_i = W_q x_i$, $k_j = W_k z_j$, $v_j = W_v z_j$
  • Self-attention (when $z_i = x_i$): $q_i = W_q x_i$, $k_i = W_k x_i$, $v_i = W_v x_i$

where all vectors are column vectors.

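In matrix form, stacking the input vectors as rows of $X \in \mathbb{R}^{T\times C}$, self-attention computes

$$Q = XW_q^\top, \quad K = XW_k^\top, \quad V = XW_v^\top,$$

where all vectors are stored as row vectors in the matrices. The transposes come from the row convention: row $i$ of $XW_q^\top$ is $x_i^\top W_q^\top = (W_q x_i)^\top = q_i^\top$. For cross-attention, $K$ and $V$ are computed from the context matrix $Z \in \mathbb{R}^{S\times C}$ instead: $K = ZW_k^\top$ and $V = ZW_v^\top$.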
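The sketch below checks in NumPy that the row-wise matrix form reproduces the column-vector projections. It assumes the convention $q_i = W_q x_i$ from the list above, so stacking the rows gives $Q = XW_q^\top$; PyTorch's `nn.Linear` applies its weight the same way ($y = xW^\top + b$). The weights here are random stand-ins for learned parameters, and all sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
T, S, C = 3, 5, 4
X = rng.normal(size=(T, C))  # input sequence, rows are x_i
Z = rng.normal(size=(S, C))  # context sequence, rows are z_j

# Hypothetical stand-ins for the learnable projections (random, not trained).
W_q, W_k, W_v = (rng.normal(size=(C, C)) for _ in range(3))

# Cross-attention in matrix form, row-vector convention.
Q = X @ W_q.T  # (T, C)
K = Z @ W_k.T  # (S, C)
V = Z @ W_v.T  # (S, C)

# Same vectors as the column-vector form q_i = W_q x_i, k_j = W_k z_j.
assert np.allclose(Q[0], W_q @ X[0])
assert np.allclose(K[4], W_k @ Z[4])
print(Q.shape, K.shape, V.shape)  # (3, 4) (5, 4) (5, 4)

# Self-attention is the special case Z = X (so S = T).
K_self, V_self = X @ W_k.T, X @ W_v.T
print(K_self.shape)  # (3, 4)
```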

Note:

  1. You may wonder why linear transformations are preferred over non-linear projections (e.g., MLPs with ReLU). I personally believe this is purely an engineering choice. That said, simple linear transformations work well enough in practice.
  2. The attention computation itself, (4) and (5), contains no trainable parameters: it consists only of matrix multiplications and a softmax. All learnable parameters are in the matrices $W_q$, $W_k$ and $W_v$ that generate the queries, keys and values.

Multi-head attention

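Rather than applying a single attention function with $C$-dimensional queries, keys and values, the model uses $h$ heads. Each head has its own learned projections, which map the queries, keys and values into a subspace of dimension $C/h$, and each head applies attention independently in that subspace. The head outputs are then concatenated and linearly transformed to form the final output:

$$\operatorname{MultiHead}(X) = \operatorname{Concat}(\text{head}_1, \dots, \text{head}_h)\, W_o,$$

where each $\text{head}_n$ is computed with (4) and (5) from its own projection matrices $W_q^{(n)}, W_k^{(n)}, W_v^{(n)}$, and $W_o \in \mathbb{R}^{C\times C}$ is a learned output projection. This allows the model to capture diverse aspects of the input.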
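Below is a minimal NumPy sketch of multi-head self-attention, assuming a particular weight layout. Each head's projection is taken as a block of $C/h$ rows of one shared $C \times C$ matrix; this is one common way to organize the weights, not necessarily the paper's or any library's. The function name `multi_head_self_attention` is hypothetical, and the scores are left unscaled to match equations (4) and (5).

```python
import numpy as np


def softmax_rows(A):
    """Row-wise softmax with the usual max-subtraction for stability."""
    A = A - A.max(axis=-1, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=-1, keepdims=True)


def multi_head_self_attention(X, W_q, W_k, W_v, W_o, h):
    """X: (T, C). W_q, W_k, W_v, W_o: (C, C). h must divide C.

    Head n uses rows [n*d, (n+1)*d) of W_q, W_k, W_v as its own projection
    into a d = C/h dimensional subspace (an assumed layout).
    """
    T, C = X.shape
    assert C % h == 0, "C must be divisible by the number of heads h"
    d = C // h
    heads = []
    for n in range(h):
        rows = slice(n * d, (n + 1) * d)
        Q = X @ W_q[rows].T  # (T, d)
        K = X @ W_k[rows].T  # (T, d)
        V = X @ W_v[rows].T  # (T, d)
        heads.append(softmax_rows(Q @ K.T) @ V)  # (4) and (5) per head, (T, d)
    concat = np.concatenate(heads, axis=-1)      # Concat(head_1, ..., head_h), (T, C)
    return concat @ W_o                          # output projection, (T, C)


rng = np.random.default_rng(3)
T, C, h = 5, 8, 2
X = rng.normal(size=(T, C))
W_q, W_k, W_v, W_o = (rng.normal(size=(C, C)) for _ in range(4))
Y = multi_head_self_attention(X, W_q, W_k, W_v, W_o, h)
print(Y.shape)  # (5, 8)
```

With this layout, the $h$ heads together use $3C^2$ projection parameters, the same as a single head of size $C$, plus $C^2$ for $W_o$. Splitting into heads changes what is computed, not how many projection weights there are.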

Causality mask for autoregressive models

Autoregressive models, e.g., GPT-based models, use the self-attention mechanism and typically require a causality mask. The mask ensures that each prediction depends only on the current and preceding elements, i.e., $x_i$ can NOT attend to any $z_j$ with $j > i$.

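To enforce this constraint, a causality mask prevents attention to future positions:

[Figure: masked self-attention]

This makes the attention weight matrix $W$ in (4) lower-triangular:

$$W = \begin{pmatrix} w_{11} & 0 & 0 & \cdots & 0 \\ w_{21} & w_{22} & 0 & \cdots & 0 \\ w_{31} & w_{32} & w_{33} & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ w_{T1} & w_{T2} & w_{T3} & \cdots & w_{TT} \end{pmatrix},$$

where $S = T$ because this is self-attention.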

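The causality mask is implemented by modifying the dot-product scores:

$$w'_{ij} = \begin{cases} k_j^\top q_i, & \text{if } j \le i, \\ -\infty, & \text{if } j > i. \end{cases}$$

After softmax normalization, future positions ($j > i$) are effectively zeroed out because $\exp(-\infty) = 0$:

$$w_{ij} = \frac{\exp(w'_{ij})}{\sum_{j'=1}^{i} \exp(w'_{ij'})} \quad \text{for } j \le i, \qquad w_{ij} = 0 \quad \text{for } j > i.$$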
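A minimal NumPy sketch of the causal mask follows. It sets the scores of future positions to $-\infty$ before the row-wise softmax. The function name `causal_self_attention` and the toy sizes are assumptions for illustration; the scores are unscaled, as in (4).

```python
import numpy as np


def causal_self_attention(Q, K, V):
    """Q, K, V: (T, C). Position i may only attend to positions j <= i."""
    T = Q.shape[0]
    scores = Q @ K.T                                     # w'_ij = k_j^T q_i
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True where j > i
    scores = np.where(future, -np.inf, scores)           # masked scores
    # Each row keeps its diagonal entry, so the row max is always finite.
    scores = scores - scores.max(axis=-1, keepdims=True)
    E = np.exp(scores)                                   # exp(-inf) = 0
    W = E / E.sum(axis=-1, keepdims=True)
    return W @ V, W


rng = np.random.default_rng(4)
T, C = 4, 3
Q, K, V = (rng.normal(size=(T, C)) for _ in range(3))
Y, W = causal_self_attention(Q, K, V)

print(np.allclose(W, np.tril(W)))        # True: W is lower-triangular
print(np.allclose(W.sum(axis=-1), 1.0))  # True: each row still sums to one
print(W[0])                              # [1. 0. 0. 0.]
```

The first position can only attend to itself, so its weight row is one-hot and $y_1 = v_1$.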

Note: Cross-attention typically does not require causal masking since the decoder attends to a fully available context (e.g., encoder outputs in translation, image features in captioning).

Code

Single self-attention head

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """
    A single attention head for self-attention in transformer models.

    Attributes:
        head_size (int): Dimension of the output of the attention head.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        to_query, to_key, to_value (nn.Linear): Layers to compute query, key, and value vectors.
        tril (torch.Tensor): Lower triangular matrix for causal masking.
    """

    def __init__(self, head_dim, embd_dim, context_len):
        """
        Initializes the AttentionHead.

        Args:
            head_dim (int): Dimension of the output of the attention head.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()

        self.head_size, self.embd_dim, self.context_len = head_dim, embd_dim, context_len
        # No bias terms, so the projections are exactly q = W_q x, k = W_k x, v = W_v x.
        self.to_query = nn.Linear(embd_dim, head_dim, bias=False)
        self.to_key = nn.Linear(embd_dim, head_dim, bias=False)
        self.to_value = nn.Linear(embd_dim, head_dim, bias=False)

        # Lower-triangular matrix of ones; its zeros mark future positions (j > i).
        self.register_buffer('tril', torch.tril(torch.ones(context_len, context_len)))

    def forward(self, x, causality_mask=False):
        """
        Computes the attention output.

        Args:
            x (torch.Tensor): Input of shape [B, T, C], with T <= context_len.
            causality_mask (bool): If True, applies a mask to prevent attending to future tokens.

        Returns:
            torch.Tensor: Output of shape [B, T, head_dim].
        """
        B, T, C = x.shape
        k = self.to_key(x)    # [B, T, head_dim]
        q = self.to_query(x)  # [B, T, head_dim]
        v = self.to_value(x)  # [B, T, head_dim]

        # Dot-product scores, scaled by 1/sqrt(head_dim) as in "Attention Is All You Need"
        # (the equations in the text omit this factor).
        weight = q @ k.transpose(-2, -1) * self.head_size ** -0.5  # [B, T, T]

        if causality_mask:
            # Set the scores of future tokens to -inf so softmax gives them zero weight.
            weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # [B, T, T]

        weight = F.softmax(weight, dim=-1)  # row-wise softmax over the keys
        y = weight @ v  # [B, T, head_dim]
        return y
```

Multi-head attention block

```python
# Uses the imports and the AttentionHead class from the previous block.
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    This class implements multiple heads of self-attention in parallel,
    allowing the model to attend to information from different representation
    subspaces at different positions.

    Attributes:
        head_num (int): Number of attention heads.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        head_size (int): Dimension of each attention head.
        heads (nn.ModuleList): List of AttentionHead modules.
        proj (nn.Linear): Linear layer to combine the outputs from all heads.
    """

    def __init__(self, head_num, embd_dim, context_len):
        """
        Initializes the MultiHeadAttention.

        Args:
            head_num (int): Number of attention heads; must divide embd_dim.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()
        # The concatenated head outputs must have exactly embd_dim features for self.proj.
        assert embd_dim % head_num == 0, "embd_dim must be divisible by head_num"
        self.head_num, self.embd_dim, self.context_len = head_num, embd_dim, context_len
        self.head_size = embd_dim // head_num
        self.heads = nn.ModuleList(
            [AttentionHead(self.head_size, self.embd_dim, self.context_len) for _ in range(head_num)]
        )
        # Mix the data output by the heads back together (the output projection W_o).
        self.proj = nn.Linear(self.embd_dim, self.embd_dim)

    def forward(self, x, causality_mask=False):
        """
        Applies multi-head attention to the input.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, C].
            causality_mask (bool): If True, every head masks future tokens.

        Returns:
            torch.Tensor: Output tensor of shape [B, T, embd_dim].
        """
        # Concatenate the [B, T, head_size] outputs of all heads -> [B, T, embd_dim]
        outs = torch.cat([head(x, causality_mask) for head in self.heads], dim=-1)
        x = self.proj(outs)  # Project back to the original embedding dimension
        return x
```