Self-attention Mechanism

Sources:

  1. Transformer from scratch
  2. Attention Is All You Need (Vaswani et al., 2017)

Notation

| Symbol | Type | Explanation |
|---|---|---|
| \(x_i\) | \(\in \mathbb{R}^{C}\) | Input vector at position \(i\) in the sequence |
| \(y_i\) | \(\in \mathbb{R}^{C}\) | Output vector at position \(i\) in the sequence |
| \(T\) | \(\in \mathbb{N}\) | Length of the input and output sequence |
| \(C\) | \(\in \mathbb{N}\) | Dimension of the input and output vectors |
| \(w_{ij}\) | \(\in \mathbb{R}\) | Weight indicating the influence of input vector \(x_j\) on \(y_i\) |
| \(w_{ij}'\) | \(\in \mathbb{R}\) | Unnormalized attention score between \(x_i\) and \(x_j\) |
| \(w\) | \(\in \mathbb{R}^{T \times T}\) | Weight matrix for the entire sequence |
| \(x\) | \(\in \mathbb{R}^{T \times C}\) | Matrix of all input vectors |
| \(y\) | \(\in \mathbb{R}^{T \times C}\) | Matrix of all output vectors |
| \(W_q\) (single-head) | \(\in \mathbb{R}^{C \times C}\) | Weight matrix for transforming input vectors into queries in single-head attention |
| \(W_k\) (single-head) | \(\in \mathbb{R}^{C \times C}\) | Weight matrix for transforming input vectors into keys in single-head attention |
| \(W_v\) (single-head) | \(\in \mathbb{R}^{C \times C}\) | Weight matrix for transforming input vectors into values in single-head attention |
| \(q_i\) (single-head) | \(\in \mathbb{R}^{C}\) | Query vector for the input vector \(x_i\) in single-head attention |
| \(k_i\) (single-head) | \(\in \mathbb{R}^{C}\) | Key vector for the input vector \(x_i\) in single-head attention |
| \(v_i\) (single-head) | \(\in \mathbb{R}^{C}\) | Value vector for the input vector \(x_i\) in single-head attention |
| \(W_q\) (multi-head) | \(\in \mathbb{R}^{(C/h) \times C}\) | Weight matrix for transforming input vectors into queries in multi-head attention |
| \(W_k\) (multi-head) | \(\in \mathbb{R}^{(C/h) \times C}\) | Weight matrix for transforming input vectors into keys in multi-head attention |
| \(W_v\) (multi-head) | \(\in \mathbb{R}^{(C/h) \times C}\) | Weight matrix for transforming input vectors into values in multi-head attention |
| \(q_i\) (multi-head) | \(\in \mathbb{R}^{C/h}\) | Query vector for the input vector \(x_i\) in multi-head attention |
| \(k_i\) (multi-head) | \(\in \mathbb{R}^{C/h}\) | Key vector for the input vector \(x_i\) in multi-head attention |
| \(v_i\) (multi-head) | \(\in \mathbb{R}^{C/h}\) | Value vector for the input vector \(x_i\) in multi-head attention |
| \(h\) | \(\in \mathbb{N}\) | Number of heads in multi-head attention |
| \(\operatorname{softmax}\) | Function | The softmax function |
| \(\exp\) | Function | Exponential function used in the softmax calculation |
| \(-\infty\) | Special value | Used to mask future elements in the causality mask |

Self-Attention

Self-attention is a sequence-to-sequence operation that transforms an input sequence of vectors \(\{x_{1}, x_{2}, \dots, x_{T}\}\) into an output sequence \(\{y_1, y_2, \dots, y_T\}\), where each vector has dimension \(C\). To compute the output vector \(y_{\textcolor{red}{i}}\), self-attention forms a weighted average of the input vectors:

\[ y_{\textcolor{red}{i}} = \sum_{\textcolor{green}{j}} w_{\textcolor{red}{i}\textcolor{green}{j}} x_{\textcolor{green}{j}}, \]

where \(\textcolor{green}{j}\) indexes the entire sequence, and the weights \(w_{\textcolor{red}{i}\textcolor{green}{j}}\) sum to one over all \(\textcolor{green}{j}\). These weights are derived from the input vectors \(x_{\textcolor{red}{i}}\) and \(x_{\textcolor{green}{j}}\) using a dot product:

\[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = x_{\textcolor{red}{i}}^{\top} x_{\textcolor{green}{j}}. \]

The resulting scores are then normalized using the softmax function to form the final attention weights:

\[ w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{\textcolor{green}{k}=1}^{T} \exp(w_{\textcolor{red}{i}\textcolor{green}{k}}')}. \]

[Figure] A visual illustration of basic self-attention. Note that the softmax operation over the weights is not illustrated.

In matrix form, the self-attention operation is expressed as: \[ y = w x, \]

where

\[ x = \begin{bmatrix} x_1^{\top} \\ \vdots \\ x_T^{\top} \end{bmatrix} \in \mathbb{R}^{T \times C}, \quad y = \begin{bmatrix} y_1^{\top} \\ \vdots \\ y_T^{\top} \end{bmatrix} \in \mathbb{R}^{T \times C}, \quad w = \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1T} \\ w_{21} & w_{22} & \cdots & w_{2T} \\ \vdots & \vdots & \ddots & \vdots \\ w_{T1} & w_{T2} & \cdots & w_{TT} \end{bmatrix} \in \mathbb{R}^{T \times T}. \]
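As a minimal sketch of this matrix form in PyTorch (illustrative values for \(T\) and \(C\); this is not part of the model code shown later):

import torch
import torch.nn.functional as F

T, C = 4, 8                     # illustrative sequence length and vector dimension
x = torch.randn(T, C)           # input sequence, one row per position

w_prime = x @ x.T               # unnormalized scores w'_ij = x_i^T x_j, shape [T, T]
w = F.softmax(w_prime, dim=-1)  # normalize each row so the weights over j sum to one
y = w @ x                       # y = w x, shape [T, C]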

Additional tricks

The actual self-attention used in modern transformers relies on several additional tricks.

Queries, keys, and values

Each input vector \(x_{\textcolor{red}{i}}\) in self-attention is used in three ways:

  1. to determine weights for its own output (query),
  2. to contribute to weights for others' outputs (key), and
  3. to participate in the weighted sum (value).

To facilitate these roles, each \(x_{\textcolor{red}{i}}\) is linearly transformed into distinct vectors using learnable matrices \(W_q\), \(W_k\), and \(W_v\): \[ q_{\textcolor{red}{i}} = W_q x_{\textcolor{red}{i}}, \quad k_{\textcolor{red}{i}} = W_k x_{\textcolor{red}{i}}, \quad v_{\textcolor{red}{i}} = W_v x_{\textcolor{red}{i}}. \]

The attention weights are then computed as before, with the softmax normalizing over \(\textcolor{green}{j}\):

\[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = q_{\textcolor{red}{i}}^{\top} k_{\textcolor{green}{j}}, \quad w_{\textcolor{red}{i}\textcolor{green}{j}} = \operatorname{softmax}(w_{\textcolor{red}{i}\textcolor{green}{j}}'), \]

and the output vector \(y_{\textcolor{red}{i}}\) is:

\[ y_{\textcolor{red}{i}} = \sum_{\textcolor{green}{j}} w_{\textcolor{red}{i}\textcolor{green}{j}} v_{\textcolor{green}{j}}. \]

This mechanism introduces learnable parameters into the self-attention layer, allowing the model to learn a separate representation of each input vector for each of its three roles.
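A minimal sketch of this query/key/value variant in PyTorch (illustrative dimensions, bias-free projections for simplicity; a full module version appears in the Code section below):

import torch
import torch.nn as nn
import torch.nn.functional as F

T, C = 4, 8
x = torch.randn(T, C)

# Learnable projections playing the roles of W_q, W_k, and W_v
to_q = nn.Linear(C, C, bias=False)
to_k = nn.Linear(C, C, bias=False)
to_v = nn.Linear(C, C, bias=False)

q, k, v = to_q(x), to_k(x), to_v(x)   # [T, C] each
w = F.softmax(q @ k.T, dim=-1)        # w_ij = softmax over j of q_i^T k_j
y = w @ v                             # y_i = sum_j w_ij v_j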

Multi-head attention

Rather than applying a single attention function with \(C\)-dimensional keys, values, and queries, the model applies this function \(h\) times independently on different subspaces of dimension \(\frac{C}{h}\). The outputs are then concatenated and linearly transformed to form the final output, allowing the model to capture diverse aspects of the input.
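For example, with \(C = 512\) and \(h = 8\), each head works with 64-dimensional queries, keys, and values, and concatenating the eight 64-dimensional head outputs restores a 512-dimensional vector before the final linear transformation. A minimal sketch of this bookkeeping, reusing the AttentionHead module defined in the Code section below (illustrative sizes):

import torch
import torch.nn as nn

C, h, T = 512, 8, 10
head_dim = C // h                        # each head operates on a 64-dimensional subspace
x = torch.randn(1, T, C)                 # [B, T, C]

heads = [AttentionHead(head_dim, C, context_len=T) for _ in range(h)]
outs = torch.cat([head(x) for head in heads], dim=-1)  # h outputs of size 64 -> [1, T, 512]
proj = nn.Linear(C, C)                   # final linear transformation mixing the heads
y = proj(outs)                           # [1, T, 512]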

Causality mask for autoregressive models

In autoregressive models, where predictions depend solely on preceding elements, a causality mask is applied to prevent attention to future elements.

[Figure] Masking the self-attention to ensure that elements can only attend to input elements that precede them in the sequence. Note that the multiplication symbol is slightly misleading: we actually set the masked-out elements (the white squares) to \(-\infty\).

The mask modifies the dot product scores:

\[ w_{\textcolor{red}{i}\textcolor{green}{j}}' = \begin{cases} q_{\textcolor{red}{i}}^{\top} k_{\textcolor{green}{j}}, & \text{if } \textcolor{green}{j} \leq \textcolor{red}{i}, \\ -\infty, & \text{if } \textcolor{green}{j} > \textcolor{red}{i}. \end{cases} \]

This ensures that, during the softmax normalization,

\[ w_{\textcolor{red}{i}\textcolor{green}{j}} = \frac{\exp(w_{\textcolor{red}{i}\textcolor{green}{j}}')}{\sum_{\textcolor{green}{k}=1}^{\textcolor{red}{i}} \exp(w_{\textcolor{red}{i}\textcolor{green}{k}}')}, \]

future elements (\(\textcolor{green}{j} > \textcolor{red}{i}\)) receive a weight of zero, since \(\exp(-\infty) = 0\). This maintains the autoregressive property and ensures the model does not "look ahead" in the sequence.
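A minimal sketch of the masking step in PyTorch, for a toy sequence of length 3:

import torch
import torch.nn.functional as F

T = 3
scores = torch.randn(T, T)              # unnormalized scores q_i^T k_j
mask = torch.tril(torch.ones(T, T))     # 1 where j <= i, 0 where j > i

scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
# Row i of `weights` now has zeros at positions j > i; row 0 is exactly [1, 0, 0].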

Code

Single self-attention head

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """
    A single attention head for self-attention in transformer models.

    Attributes:
        head_size (int): Dimension of the output of the attention head.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        to_query, to_key, to_value (nn.Linear): Layers to compute query, key, and value vectors.
        tril (torch.Tensor): Lower triangular matrix for causal masking.
    """

    def __init__(self, head_dim, embd_dim, context_len):
        """
        Initializes the AttentionHead.

        Args:
            head_dim (int): Dimension of the output of the attention head.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()

        self.head_size, self.embd_dim, self.context_len = head_dim, embd_dim, context_len
        self.to_query = nn.Linear(embd_dim, head_dim)
        self.to_key = nn.Linear(embd_dim, head_dim)
        self.to_value = nn.Linear(embd_dim, head_dim)

        # Causal mask for future tokens
        self.register_buffer('tril', torch.tril(torch.ones(self.context_len, self.context_len)))

    def forward(self, x, causality_mask=False):
        """
        Computes the self-attention output.

        Args:
            x (torch.Tensor): Input of shape [B, T, C].
            causality_mask (bool): If True, applies a mask to prevent attending to future tokens.

        Returns:
            torch.Tensor: Output of shape [B, T, head_dim].
        """
        B, T, C = x.shape
        k = self.to_key(x)    # [B, T, head_dim]
        q = self.to_query(x)  # [B, T, head_dim]
        v = self.to_value(x)  # [B, T, head_dim]

        weight = q @ k.transpose(-2, -1)  # [B, T, T]

        if causality_mask:
            # Mask out future tokens by setting their scores to -inf
            weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # [B, T, T]

        weight = F.softmax(weight, dim=-1)  # normalize the scores over the key positions
        y = weight @ v  # [B, T, head_dim]
        return y
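
A quick usage sketch for the head above (illustrative shapes, assuming the imports from the class definition):

head = AttentionHead(head_dim=16, embd_dim=64, context_len=32)
x = torch.randn(2, 10, 64)          # [B, T, C]
out = head(x, causality_mask=True)  # [2, 10, 16]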

Multi-head attention block

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    This class implements multiple heads of self-attention in parallel,
    allowing the model to attend to information from different representation
    subspaces at different positions.

    Attributes:
        head_num (int): Number of attention heads.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        head_size (int): Dimension of each attention head.
        heads (nn.ModuleList): List of AttentionHead modules.
        proj (nn.Linear): Linear layer to combine the outputs from all heads.
    """

    def __init__(self, head_num, embd_dim, context_len):
        """
        Initializes the MultiHeadAttention.

        Args:
            head_num (int): Number of attention heads.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()
        self.head_num, self.embd_dim, self.context_len = head_num, embd_dim, context_len
        self.head_size = embd_dim // head_num
        self.heads = nn.ModuleList([AttentionHead(self.head_size, self.embd_dim, self.context_len) for _ in range(head_num)])
        self.proj = nn.Linear(self.embd_dim, self.embd_dim)  # Mix the data output by the heads back together.

    def forward(self, x):
        """
        Applies multi-head attention to the input.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, C].

        Returns:
            torch.Tensor: Output tensor of shape [B, T, embd_dim].
        """
        outs = torch.cat([head(x) for head in self.heads], dim=-1)  # Concatenate outputs from all heads
        x = self.proj(outs)  # Project back to the original embedding dimension
        return x
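
A quick usage sketch for the block above (illustrative shapes, assuming the classes and imports defined earlier):

mha = MultiHeadAttention(head_num=4, embd_dim=64, context_len=32)
x = torch.randn(2, 10, 64)  # [B, T, C]
out = mha(x)                # [2, 10, 64]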