$w_{ij}$: Weight indicating the influence of input vector $x_j$ on output vector $y_i$
$w'_{ij}$: Unnormalized attention score between $x_i$ and $x_j$
$W$: Weight matrix for the entire sequence
$X$: Matrix of all input vectors
$Y$: Matrix of all output vectors
$W_q$: Weight matrix for transforming input vectors into queries in single-head attention
$W_k$: Weight matrix for transforming input vectors into keys in single-head attention
$W_v$: Weight matrix for transforming input vectors into values in single-head attention
$q_i$: Query vector for the input vector $x_i$ in single-head attention
$k_i$: Key vector for the input vector $x_i$ in single-head attention
$v_i$: Value vector for the input vector $x_i$ in single-head attention
$W_q^h$: Weight matrix for transforming input vectors into queries in multi-head attention
$W_k^h$: Weight matrix for transforming input vectors into keys in multi-head attention
$W_v^h$: Weight matrix for transforming input vectors into values in multi-head attention
$q_i^h$: Query vector for the input vector $x_i$ in multi-head attention
$k_i^h$: Key vector for the input vector $x_i$ in multi-head attention
$v_i^h$: Value vector for the input vector $x_i$ in multi-head attention
$h$: Number of heads in multi-head attention
$\mathrm{softmax}$: The softmax function
$\exp$: Exponential function used in the softmax calculation
$-\infty$: Special value used to mask future elements in the causality mask
Self-Attention
Self-attention is a sequence-to-sequence operation that transforms an input sequence of vectors $x_1, \dots, x_T$ into an output sequence $y_1, \dots, y_T$, where each vector has dimension $d$. To compute the output vector $y_i$, self-attention forms a weighted average of the input vectors:

$$y_i = \sum_j w_{ij}\, x_j$$

where $j$ indexes the entire sequence, and the weights $w_{ij}$ sum to one over all $j$. These weights are derived from the input vectors $x_i$ and $x_j$ using a dot product:

$$w'_{ij} = x_i^\top x_j$$

The resulting scores $w'_{ij}$ are then normalized using the softmax function to form the final attention weights:

$$w_{ij} = \frac{\exp(w'_{ij})}{\sum_{j'} \exp(w'_{ij'})}$$
[Figure: A visual illustration of basic self-attention. Note that the softmax operation over the weights is not illustrated.]
In matrix form, with the input vectors $x_i$ as the rows of a matrix $X$ and the output vectors $y_i$ as the rows of a matrix $Y$, the self-attention operation is expressed as:

$$W' = XX^\top, \qquad W = \mathrm{softmax}(W'), \qquad Y = WX,$$

where the softmax is applied row-wise.
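As a minimal PyTorch sketch of this matrix form (the tensor names and sizes here are illustrative and not tied to the implementation at the end of this section):

```python
import torch
import torch.nn.functional as F

# x: batch of input sequences, shape [B, T, d]
x = torch.randn(4, 8, 16)

raw_weights = x @ x.transpose(-2, -1)     # W' = X X^T, shape [B, T, T]
weights = F.softmax(raw_weights, dim=-1)  # row-wise softmax, each row sums to 1
y = weights @ x                           # Y = W X, shape [B, T, d]
```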
Additional tricks
The actual self-attention used in modern transformers relies on several additional tricks.
Queries, keys, and values
Each input vector $x_i$ in self-attention is used in three ways:
to determine weights for its own output (query),
to contribute to weights for others' outputs (key), and
to participate in the weighted sum (value).
To facilitate these roles, each $x_i$ is linearly transformed into three distinct vectors using learnable matrices $W_q$, $W_k$, and $W_v$:

$$q_i = W_q x_i, \qquad k_i = W_k x_i, \qquad v_i = W_v x_i.$$

The attention scores are then computed as:

$$w'_{ij} = q_i^\top k_j,$$

followed by the softmax normalization

$$w_{ij} = \frac{\exp(w'_{ij})}{\sum_{j'} \exp(w'_{ij'})},$$

and the output vector is:

$$y_i = \sum_j w_{ij}\, v_j.$$
This mechanism introduces learnable parameters into the self-attention layer, enabling it to adapt the input vectors to their different roles as queries, keys, and values.
Another notation stacks the $q_i$, $k_i$, and $v_i$ as the rows of matrices $Q$, $K$, and $V$. Thus $Y = \mathrm{softmax}(QK^\top)\,V$, where the softmax is applied row-wise.
Each column of $K^\top$ is the key of one input element, so row $i$ of $QK^\top$ contains the attention scores of query $q_i$ against all $T$ elements of the sequence.
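A minimal sketch of this matrix notation, with randomly initialized projection matrices standing in for learned ones (all names and sizes are illustrative):

```python
import torch
import torch.nn.functional as F

T, d = 8, 16                       # sequence length and embedding dimension
X = torch.randn(T, d)              # input vectors as the rows of X

Wq = torch.randn(d, d)             # projections; random here purely for illustration
Wk = torch.randn(d, d)
Wv = torch.randn(d, d)

Q, K, V = X @ Wq, X @ Wk, X @ Wv   # queries, keys, values, each [T, d]
W = F.softmax(Q @ K.T, dim=-1)     # attention weights, [T, T], rows sum to 1
Y = W @ V                          # output vectors, [T, d]
```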
Multi-head attention
Rather than applying a single attention function with $d$-dimensional queries, keys, and values, the model applies the attention function $h$ times independently on different projected subspaces of dimension $d/h$. The outputs of the heads are then concatenated and linearly transformed to form the final output, allowing the model to capture diverse aspects of the input.
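As a sketch of the dimension bookkeeping, the reshape-based formulation below splits the embedding into $h$ subspaces; this is an alternative to the per-head modules used in the implementation further down, and all sizes are illustrative:

```python
import torch
import torch.nn.functional as F

B, T, d, h = 2, 8, 64, 4           # batch, sequence length, embedding dim, number of heads
head_dim = d // h                  # each head works in a subspace of dimension d/h

# Assume q, k, v were produced by linear projections of the input x
q = torch.randn(B, T, d)
k = torch.randn(B, T, d)
v = torch.randn(B, T, d)

# Split the embedding dimension into h heads: [B, T, d] -> [B, h, T, d/h]
q = q.view(B, T, h, head_dim).transpose(1, 2)
k = k.view(B, T, h, head_dim).transpose(1, 2)
v = v.view(B, T, h, head_dim).transpose(1, 2)

w = F.softmax(q @ k.transpose(-2, -1), dim=-1)  # [B, h, T, T], attention per head
y = w @ v                                       # [B, h, T, d/h]

# Concatenate the heads back: [B, h, T, d/h] -> [B, T, d]; a final linear layer would follow
y = y.transpose(1, 2).contiguous().view(B, T, d)
```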
Causality mask for autoregressive models
In autoregressive models, where predictions depend solely on preceding elements, a causality mask is applied to prevent attention to future elements.
[Figure: Masking the self-attention, to ensure that elements can only attend to input elements that precede them in the sequence. Note that the multiplication symbol is slightly misleading: we actually set the masked-out elements (the white squares) to $-\infty$.]
The mask modifies the dot product scores:

$$w'_{ij} = \begin{cases} q_i^\top k_j & \text{if } j \le i, \\ -\infty & \text{if } j > i. \end{cases}$$

This ensures that, during the softmax calculation, future elements ($j > i$) are effectively ignored: since $\exp(-\infty) = 0$, their attention weights vanish, maintaining the autoregressive property and ensuring the model does not "look ahead" in the sequence.
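A small sketch of the mask in isolation (random scores; only the masking pattern matters):

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.randn(T, T)                              # raw dot-product scores w'_{ij}

mask = torch.tril(torch.ones(T, T))                     # lower-triangular causality mask
scores = scores.masked_fill(mask == 0, float('-inf'))   # set future positions (j > i) to -inf

weights = F.softmax(scores, dim=-1)                     # exp(-inf) = 0, so future weights are exactly 0
print(weights)                                          # row i has zeros in all columns j > i
```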
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """
    A single attention head for self-attention in transformer models.

    Attributes:
        head_size (int): Dimension of the output of the attention head.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        to_query, to_key, to_value (nn.Linear): Layers to compute query, key, and value vectors.
        tril (torch.Tensor): Lower triangular matrix for causal masking.
    """

    def __init__(self, head_dim, embd_dim, context_len):
        """
        Initializes the AttentionHead.

        Args:
            head_dim (int): Dimension of the output of the attention head.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()
        self.head_size, self.embd_dim, self.context_len = head_dim, embd_dim, context_len
        self.to_query = nn.Linear(embd_dim, head_dim)
        self.to_key = nn.Linear(embd_dim, head_dim)
        self.to_value = nn.Linear(embd_dim, head_dim)
        # Causal mask for future tokens
        self.register_buffer('tril', torch.tril(torch.ones(self.context_len, self.context_len)))

    def forward(self, x, causality_mask=False):
        """
        Computes the attention output.

        Args:
            x (torch.Tensor): Input of shape [B, T, C].
            causality_mask (bool): If True, applies a mask to prevent attending to future tokens.

        Returns:
            torch.Tensor: Output of shape [B, T, head_dim].
        """
        B, T, C = x.shape
        k = self.to_key(x)    # [B, T, head_dim]
        q = self.to_query(x)  # [B, T, head_dim]
        v = self.to_value(x)  # [B, T, head_dim]
        weight = q @ k.transpose(-2, -1)  # [B, T, T]
        if causality_mask:
            # Masking future tokens
            weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # [B, T, T]
        weight = F.softmax(weight, dim=-1)
        y = weight @ v  # [B, T, head_dim]
        return y
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    This class implements multiple heads of self-attention in parallel, allowing the model
    to attend to information from different representation subspaces at different positions.

    Attributes:
        head_num (int): Number of attention heads.
        embd_dim (int): Dimension of the input embeddings.
        context_len (int): Maximum sequence length for attention.
        head_size (int): Dimension of each attention head.
        heads (nn.ModuleList): List of AttentionHead modules.
        proj (nn.Linear): Linear layer to combine the outputs from all heads.
    """

    def __init__(self, head_num, embd_dim, context_len):
        """
        Initializes the MultiHeadAttention.

        Args:
            head_num (int): Number of attention heads.
            embd_dim (int): Dimension of the input embeddings.
            context_len (int): Maximum sequence length.
        """
        super().__init__()
        self.head_num, self.embd_dim, self.context_len = head_num, embd_dim, context_len
        self.head_size = embd_dim // head_num
        self.heads = nn.ModuleList([
            AttentionHead(self.head_size, self.embd_dim, self.context_len)
            for _ in range(head_num)
        ])
        self.proj = nn.Linear(self.embd_dim, self.embd_dim)  # Mix the data output by the heads back together.

    def forward(self, x):
        """
        Applies multi-head attention to the input.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, C].

        Returns:
            torch.Tensor: Output tensor of shape [B, T, embd_dim].
        """
        outs = torch.cat([head(x) for head in self.heads], dim=-1)  # Concatenate outputs from all heads
        x = self.proj(outs)  # Project back to the original embedding dimension
        return x
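A quick usage check of the two modules above, with arbitrary hyperparameters chosen only to show the output shapes:

```python
B, T, embd_dim, head_num, context_len = 2, 10, 32, 4, 16

x = torch.randn(B, T, embd_dim)

head = AttentionHead(head_dim=8, embd_dim=embd_dim, context_len=context_len)
print(head(x, causality_mask=True).shape)  # torch.Size([2, 10, 8])

mha = MultiHeadAttention(head_num=head_num, embd_dim=embd_dim, context_len=context_len)
print(mha(x).shape)                        # torch.Size([2, 10, 32])
```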