Self-attention Mechanism

Sources:

  1. Transformer from scratch
  2. Attention Is All You Need 2017 paper

Notation

Symbol | Type | Explanation
x_i | ℝ^C | Input vector at position i in the sequence
y_i | ℝ^C | Output vector at position i in the sequence
T | ℕ | Length of the input and output sequence
C | ℕ | Dimension of the input and output vectors
w_ij | ℝ | Weight indicating the influence of input vector x_j on y_i
w'_ij | ℝ | Unnormalized attention score between x_i and x_j
w | ℝ^{T×T} | Weight matrix for the entire sequence
x | ℝ^{T×C} | Matrix of all input vectors
y | ℝ^{T×C} | Matrix of all output vectors
W_q (single-head) | ℝ^{C×C} | Weight matrix for transforming input vectors into queries in single-head attention
W_k (single-head) | ℝ^{C×C} | Weight matrix for transforming input vectors into keys in single-head attention
W_v (single-head) | ℝ^{C×C} | Weight matrix for transforming input vectors into values in single-head attention
q_i (single-head) | ℝ^C | Query vector for the input vector x_i in single-head attention
k_i (single-head) | ℝ^C | Key vector for the input vector x_i in single-head attention
v_i (single-head) | ℝ^C | Value vector for the input vector x_i in single-head attention
W_q (multi-head) | ℝ^{(C/h)×C} | Weight matrix for transforming input vectors into queries in multi-head attention
W_k (multi-head) | ℝ^{(C/h)×C} | Weight matrix for transforming input vectors into keys in multi-head attention
W_v (multi-head) | ℝ^{(C/h)×C} | Weight matrix for transforming input vectors into values in multi-head attention
q_i (multi-head) | ℝ^{C/h} | Query vector for the input vector x_i in multi-head attention
k_i (multi-head) | ℝ^{C/h} | Key vector for the input vector x_i in multi-head attention
v_i (multi-head) | ℝ^{C/h} | Value vector for the input vector x_i in multi-head attention
h | ℕ | Number of heads in multi-head attention
softmax | Function | The softmax function
exp | Function | Exponential function used in the softmax calculation
−∞ | Special value | Used to mask future elements in the causality mask

Self-Attention

Self-attention is a sequence-to-sequence operation that transforms an input sequence of vectors $\{x_1, x_2, \ldots, x_T\}$ into an output sequence $\{y_1, y_2, \ldots, y_T\}$, where each vector has dimension $C$. To compute the output vector $y_i$, self-attention forms a weighted average of the input vectors:

$$y_i = \sum_j w_{ij}\, x_j,$$

where $j$ indexes the entire sequence and the weights $w_{ij}$ sum to one over all $j$. These weights are derived from the input vectors $x_i$ and $x_j$, starting from a raw dot-product score:

$$w'_{ij} = x_i^\top x_j.$$

The resulting scores are then normalized using the softmax function to form the final attention weights:

$$w_{ij} = \frac{\exp(w'_{ij})}{\sum_{j'} \exp(w'_{ij'})}.$$
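
As a quick numerical illustration, here is a toy sketch (arbitrary example vectors, assuming PyTorch) that computes the raw scores and softmax weights for the first position of a short sequence:

import torch
import torch.nn.functional as F

# Toy sequence: T = 3 vectors of dimension C = 2 (arbitrary example values).
x = torch.tensor([[1.0, 0.0],
                  [0.0, 2.0],
                  [1.0, 1.0]])

raw = x @ x[0]               # dot products of x_0 with every x_j -> tensor([1., 0., 1.])
w_0 = F.softmax(raw, dim=0)  # normalized weights, summing to one
y_0 = w_0 @ x                # output y_0: weighted average of the input vectors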

[Figure] A visual illustration of basic self-attention. Note that the softmax operation over the weights is not illustrated.

In matrix form, the self-attention operation is expressed as $y = w x$,

where

$$x = \begin{bmatrix} x_1 \\ \vdots \\ x_T \end{bmatrix} \in \mathbb{R}^{T \times C}, \quad y = \begin{bmatrix} y_1 \\ \vdots \\ y_T \end{bmatrix} \in \mathbb{R}^{T \times C}, \quad w = \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1T} \\ w_{21} & w_{22} & \cdots & w_{2T} \\ \vdots & \vdots & \ddots & \vdots \\ w_{T1} & w_{T2} & \cdots & w_{TT} \end{bmatrix} \in \mathbb{R}^{T \times T}.$$
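
The matrix form takes only a few lines of code (a minimal sketch with random inputs and no learned parameters, assuming PyTorch):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)            # T = 4 input vectors of dimension C = 8
w = F.softmax(x @ x.T, dim=-1)   # [T, T] attention weights; each row sums to one
y = w @ x                        # [T, C] output sequence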

Additional tricks

The actual self-attention used in modern transformers relies on several additional tricks.

Queries, keys, and values

Each input vector $x_i$ in self-attention is used in three ways:

  1. to determine weights for its own output (query),
  2. to contribute to weights for others' outputs (key), and
  3. to participate in the weighted sum (value).

To facilitate these roles, each $x_i$ is linearly transformed into distinct vectors using learnable matrices $W_q$, $W_k$, and $W_v$:

$$q_i = W_q x_i, \quad k_i = W_k x_i, \quad v_i = W_v x_i.$$

The raw attention scores are then computed as

$$w'_{ij} = q_i^\top k_j,$$

normalized with a softmax over $j$,

$$w_{ij} = \operatorname{softmax}_j(w'_{ij}) = \frac{\exp(w'_{ij})}{\sum_{j'} \exp(w'_{ij'})},$$

and the output vector $y_i$ becomes

$$y_i = \sum_j w_{ij}\, v_j.$$

This mechanism introduces learnable parameters into the self-attention layer, enabling it to adaptively modify the input vectors for their roles.

An alternative notation stacks the vectors as columns:

$$X = [x_1, x_2, \ldots, x_T] \in \mathbb{R}^{C \times T}, \quad Q = W_q X \in \mathbb{R}^{C_q \times T}, \quad K = W_k X \in \mathbb{R}^{C_q \times T}, \quad V = W_v X \in \mathbb{R}^{C_v \times T}.$$

Thus

$$Y = V \, \operatorname{softmax}_{\text{col}}(K^\top Q) \in \mathbb{R}^{C_v \times T}, \quad \text{where } K^\top Q \in \mathbb{R}^{T \times T}.$$

The $t$-th column of $K^\top Q$ holds the scores of query $q_t$ against all $T$ keys, and the column-wise softmax normalizes each column to sum to one.
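A minimal sketch of this column-major formulation (random weights and inputs, purely illustrative):

import torch
import torch.nn.functional as F

C, T = 8, 5
Wq, Wk, Wv = (torch.randn(C, C) for _ in range(3))  # single head, so C_q = C_v = C

X = torch.randn(C, T)              # input vectors stacked as columns
Q, K, V = Wq @ X, Wk @ X, Wv @ X   # each [C, T]
S = K.T @ Q                        # [T, T]; column t holds the scores of query q_t against all keys
Y = V @ F.softmax(S, dim=0)        # column-wise softmax over the key index, then [C_v, T] output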

Multi-head attention

Rather than applying a single attention function with $C$-dimensional keys, values, and queries, the model applies this function $h$ times independently on different subspaces of dimension $C/h$. The outputs are then concatenated and linearly transformed to form the final output, allowing the model to capture diverse aspects of the input.
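
A shape-only sketch of this split-and-concatenate view (it uses a fused projection and a reshape rather than the per-head modules of the Code section below; names and sizes are illustrative):

import torch
import torch.nn as nn

B, T, C, h = 2, 10, 64, 8                    # batch, sequence length, embedding dim, number of heads
x = torch.randn(B, T, C)

to_qkv = nn.Linear(C, 3 * C)                 # fused projection for queries, keys, and values
q, k, v = to_qkv(x).chunk(3, dim=-1)         # each [B, T, C]
q = q.view(B, T, h, C // h).transpose(1, 2)  # [B, h, T, C/h]: h heads on subspaces of size C/h
k = k.view(B, T, h, C // h).transpose(1, 2)
v = v.view(B, T, h, C // h).transpose(1, 2)

w = torch.softmax(q @ k.transpose(-2, -1), dim=-1)  # [B, h, T, T] attention weights per head
y = (w @ v).transpose(1, 2).reshape(B, T, C)        # concatenate the heads back to dimension C
y = nn.Linear(C, C)(y)                              # final linear transformation mixing the heads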

Causality mask for autoregressive models

In autoregressive models, where predictions depend solely on preceding elements, a causality mask is applied to prevent attention to future elements.

[Figure] Masking the self-attention to ensure that elements can only attend to input elements that precede them in the sequence. Note that the multiplication symbol is slightly misleading: we actually set the masked-out elements (the white squares) to $-\infty$.

The mask modifies the dot product scores:

$$w'_{ij} = \begin{cases} q_i^\top k_j, & \text{if } j \le i, \\ -\infty, & \text{if } j > i. \end{cases}$$

This ensures that during the softmax calculation:

$$w_{ij} = \frac{\exp(w'_{ij})}{\sum_{k=1}^{i} \exp(w'_{ik})}.$$

Future elements ($j > i$) are effectively ignored, as $\exp(-\infty) = 0$, maintaining the autoregressive property and ensuring the model does not "look ahead" in the sequence.
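
A minimal standalone sketch of the mask (mirroring what the AttentionHead class below does with its tril buffer):

import torch
import torch.nn.functional as F

T = 5
scores = torch.randn(T, T)                              # raw query-key scores
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))   # True where j <= i
scores = scores.masked_fill(~mask, float('-inf'))       # block attention to future positions
weights = F.softmax(scores, dim=-1)                     # zeros above the diagonal; rows sum to one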

Code

Single self-attention head

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """
    A single attention head for self-attention in transformer models.

    Attributes:
        head_size (int): Dimension of the output of the attention head.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        to_query, to_key, to_value (nn.Linear): Layers to compute query, key, and value vectors.
        tril (torch.Tensor): Lower triangular matrix for causal masking.
    """

    def __init__(self, head_dim, embd_dim, context_len):
        """
        Initializes the AttentionHead.

        Args:
            head_dim (int): Dimension of the output of the attention head.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()

        self.head_size, self.embd_dim, self.context_len = head_dim, embd_dim, context_len
        self.to_query = nn.Linear(embd_dim, head_dim)
        self.to_key = nn.Linear(embd_dim, head_dim)
        self.to_value = nn.Linear(embd_dim, head_dim)

        # Causal mask for future tokens
        self.register_buffer('tril', torch.tril(torch.ones(self.context_len, self.context_len)))

    def forward(self, x, causality_mask=False):
        """
        Computes attention scores.

        Args:
            x (torch.Tensor): Input of shape [B, T, C].
            causality_mask (bool): If True, applies a mask to prevent attending to future tokens.

        Returns:
            torch.Tensor: Output of shape [B, T, head_dim].
        """
        B, T, C = x.shape
        k = self.to_key(x)    # [B, T, head_dim]
        q = self.to_query(x)  # [B, T, head_dim]
        v = self.to_value(x)  # [B, T, head_dim]

        weight = q @ k.transpose(-2, -1)  # [B, T, T]

        if causality_mask:
            # Masking future tokens
            weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # [B, T, T]

        weight = F.softmax(weight, dim=-1)
        y = weight @ v  # [B, T, head_dim]
        return y
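
A quick usage sketch (hypothetical sizes, random input):

head = AttentionHead(head_dim=16, embd_dim=64, context_len=128)
x = torch.randn(2, 10, 64)        # [B, T, embd_dim]
y = head(x, causality_mask=True)  # [2, 10, 16]

Note that, unlike the scaled dot-product attention of Attention Is All You Need, this head does not divide the scores by the square root of head_dim before the softmax.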

Multi-head attention block

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    This class implements multiple heads of self-attention in parallel,
    allowing the model to attend to information from different representation
    subspaces at different positions.

    Attributes:
        head_num (int): Number of attention heads.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        head_size (int): Dimension of each attention head.
        heads (nn.ModuleList): List of AttentionHead modules.
        proj (nn.Linear): Linear layer to combine the outputs from all heads.
    """

    def __init__(self, head_num, embd_dim, context_len):
        """
        Initializes the MultiHeadAttention.

        Args:
            head_num (int): Number of attention heads.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()
        self.head_num, self.embd_dim, self.context_len = head_num, embd_dim, context_len
        self.head_size = embd_dim // head_num
        self.heads = nn.ModuleList([AttentionHead(self.head_size, self.embd_dim, self.context_len) for _ in range(head_num)])
        self.proj = nn.Linear(self.embd_dim, self.embd_dim)  # Mix the data output by the heads back together.

    def forward(self, x, causality_mask=False):
        """
        Applies multi-head attention to the input.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, C].
            causality_mask (bool): If True, each head masks attention to future tokens.

        Returns:
            torch.Tensor: Output tensor of shape [B, T, embd_dim].
        """
        outs = torch.cat([head(x, causality_mask) for head in self.heads], dim=-1)  # Concatenate outputs from all heads
        x = self.proj(outs)  # Project back to the original embedding dimension
        return x
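
And a matching usage sketch (hypothetical sizes, random input):

mha = MultiHeadAttention(head_num=8, embd_dim=64, context_len=128)
x = torch.randn(2, 10, 64)  # [B, T, embd_dim]
y = mha(x)                  # [2, 10, 64]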