3. Coding Attention Mechanisms
Overview
Chapter 3 delves into one of the core concepts underpinning modern large language models (LLMs): attention mechanisms. It explores the evolution from earlier architectures to attention-based models and systematically implements several variants of attention mechanisms.
This chapter covers:
- The reasons for using attention mechanisms in neural networks
- A basic self-attention framework, progressing to an enhanced self-attention mechanism
- A causal attention module that allows LLMs to generate one token at a time
- Masking randomly selected attention weights with dropout to reduce overfitting
- Stacking multiple causal attention modules into a multi-head attention module
3.1 The Problem with Modeling Long Sequences
- Traditional architectures like encoder-decoder RNNs face challenges:
- They compress the entire input sequence into a single hidden state.
- The decoder then takes in this hidden state to produce the output.
- Long-range dependencies are lost, leading to poor performance on lengthy sequences.
- Attention mechanisms were developed to address this by allowing models to focus on specific parts of the input dynamically during processing.
3.2 Capturing Data Dependencies with Attention
- A major shortcoming of the RNN encoder-decoder approach is that the RNN must remember the entire encoded input in a single hidden state before passing it to the decoder.
- The first major improvement was the Bahdanau attention mechanism (2014), which enabled selective access to input tokens during decoding.
- This inspired the self-attention mechanism of the transformer architecture, enabling LLMs to weigh all input tokens when computing each token's representation.
3.3 Attending to Different Parts of the Input with Self-Attention
3.3.1 Simplified Self-Attention
- In its basic form, self-attention computes a context vector for each token in the input sequence, incorporating information from all tokens.
- The mechanism computes attention weights based on the relevance between tokens.
Example:
- Input sequence: x(1), x(2), ..., x(T)
- When computing the context vector z(2), attention weights are calculated with respect to input element x(2) and all other inputs.
Implementation:
- Compute attention scores as dot products between token embeddings.
- Normalize scores using softmax to get attention weights.
- Compute context vectors as weighted sums of token embeddings.
import torch

# Six illustrative token embeddings of dimension 3 (placeholder values)
inputs = torch.rand(6, 3)

attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)  # pairwise dot-product scores
print(attn_scores)
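Continuing the snippet above, a minimal sketch of the remaining two steps, normalization and context-vector computation (the variable name all_context_vecs is illustrative, not taken from the chapter):

# Normalize each row of scores into attention weights that sum to 1
attn_weights = torch.softmax(attn_scores, dim=-1)

# Context vectors: attention-weighted sums of the input embeddings
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)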
3.4 Implementing Self-Attention with Trainable Weights
- Adds trainable weight matrices: queries (Wq), keys (Wk), and values (Wv).
- The steps involve:
- Projecting token embeddings into query, key, and value spaces.
- Computing attention scores as dot products of queries and keys.
- Scaling the scores (by the square root of the key dimension) and normalizing them with softmax.
- Computing weighted sums of value vectors to get context vectors.
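Taken together, these steps correspond to the standard scaled dot-product attention computation (d_k is the key dimension; the notation here is the conventional one, not quoted from the chapter):

context = softmax(Q · K^T / sqrt(d_k)) · V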
Compact Implementation:
- nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Attention scores: dot products of queries and keys
        attn_scores = queries @ keys.T
        # Scale by sqrt(d_k), then normalize with softmax
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        # Context vectors: weighted sums of the value vectors
        context_vec = attn_weights @ values
        return context_vec
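For illustration, a short usage sketch; the dimensions (d_in=3, d_out=2) and the random input are assumptions, not values from the chapter:

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
x = torch.rand(6, 3)   # 6 tokens, 3-dimensional embeddings (illustrative)
print(sa_v1(x))        # one 2-dimensional context vector per token, shape (6, 2)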
3.5 Hiding Future Words with Causal Attention
- For tasks like text generation, models must only attend to previous tokens to avoid peeking into the future.
- Causal masking ensures that attention weights for future tokens are zero.
- This is achieved by:
- Applying an upper triangular mask to the attention score matrix, setting masked scores to -inf.
- Normalizing the remaining (non-masked) scores with softmax.
Implementation:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Upper triangular mask marks positions above the diagonal (future tokens)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        # Mask out future tokens by setting their scores to -inf
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        # Scale, normalize, and apply dropout to the attention weights
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
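A brief usage sketch under assumed shapes (batch size, context length, and dimensions are illustrative):

torch.manual_seed(123)
batch = torch.rand(2, 6, 3)   # (batch_size, num_tokens, d_in)
ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)
context_vecs = ca(batch)
print(context_vecs.shape)     # torch.Size([2, 6, 2])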
3.6 Extending Single-Head Attention to Multi-Head Attention
- Multi-head attention runs multiple attention mechanisms in parallel, allowing the model to capture diverse relationships.
- Each head processes a separate subspace of the input data.
- Outputs from all heads are concatenated and linearly transformed.
Implementation:
- A wrapper class stacks multiple attention heads and combines their outputs.
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(
                 d_in, d_out, context_length, dropout, qkv_bias
             ) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
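A usage sketch with assumed dimensions: with num_heads=2 and d_out=2 per head, the concatenated output has 4 dimensions per token:

torch.manual_seed(123)
batch = torch.rand(2, 6, 3)   # (batch_size, num_tokens, d_in), illustrative
mha = MultiHeadAttentionWrapper(
    d_in=3, d_out=2, context_length=6, dropout=0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs.shape)     # torch.Size([2, 6, 4]); 2 heads x d_out=2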
Summary
Chapter 3 provides a detailed exploration of attention mechanisms, starting with basic concepts and progressively adding complexity. By the end, readers will be able to understand and implement:
- Self-attention with and without trainable weights.
- The rationale and coding of causal attention.
- Multi-head attention for parallel processing of input features.
These mechanisms form the backbone of transformer architectures like GPT, enabling efficient handling of long-range dependencies and parallelization for scalability.