Attention Mechanisms Explained
Attention mechanisms have revolutionized deep learning, enabling models to focus on relevant parts of the input when producing outputs. From machine translation to image recognition, attention is everywhere. This post breaks down attention from first principles.
The Intuition Behind Attention
Consider translating "The cat sat on the mat" to French. When generating the French word for "cat," the model should focus primarily on "cat" in the input, not "mat" or "the." Attention provides this selective focus.
Without attention, models compress the entire input into a fixed-size vector, losing information. Attention allows the model to "look back" at all input positions when generating each output.
Self-Attention: The Core Mechanism
Self-attention allows each position in a sequence to attend to all positions, computing weighted combinations.
The Query-Key-Value Framework
For each token, we compute three vectors:
- Query (Q): "What am I looking for?"
- Key (K): "What do I contain?"
- Value (V): "What information do I provide?"
import math
from typing import Optional

import torch
import torch.nn as nn
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Every position in the input sequence attends to every position,
    producing a weighted combination of value vectors.

    Args:
        d_model: Dimensionality of the input embeddings.
        d_k: Dimensionality of the query/key/value projections.
            Defaults to ``d_model`` when omitted.
    """

    def __init__(self, d_model: int, d_k: Optional[int] = None):
        super().__init__()
        # Fall back to the model width when no projection size is given.
        self.d_k = d_k or d_model
        self.W_q = nn.Linear(d_model, self.d_k)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: Input of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_k).
        """
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_k)
        # Pairwise dot-product similarity between all positions.
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq_len, seq_len)
        # Scale so score variance stays ~1 and softmax gradients stay healthy.
        scores = scores / math.sqrt(self.d_k)
        attention_weights = torch.softmax(scores, dim=-1)
        # Each output row is a convex combination of value vectors.
        output = torch.matmul(attention_weights, V)
        return output
Why Scale by √d_k?
Without scaling, as d_k increases, the dot products grow large, pushing softmax into regions with tiny gradients:
# Demonstration of scaling importance
d_k = 512
q = torch.randn(1, d_k)
k = torch.randn(1000, d_k)

# Dot products of i.i.d. unit-variance vectors have variance ~ d_k,
# so the unscaled std is ~ sqrt(512) ~ 22.
unscaled = torch.matmul(q, k.T)
print(f"Unscaled std: {unscaled.std():.2f}")  # ~22

# Dividing by sqrt(d_k) brings the variance back down to ~1.
scaled = unscaled / math.sqrt(d_k)
print(f"Scaled std: {scaled.std():.2f}")  # ~1

# Large-magnitude logits saturate softmax: nearly all mass lands on the
# argmax, leaving vanishing gradients everywhere else.
print(f"Unscaled softmax max: {torch.softmax(unscaled, dim=-1).max():.4f}")  # ~0.99+
print(f"Scaled softmax max: {torch.softmax(scaled, dim=-1).max():.4f}")  # More uniform
Multi-Head Attention
One attention head can only focus on one pattern. Multi-head attention runs multiple attention operations in parallel, each learning different relationships.
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Runs ``num_heads`` attention operations in parallel on slices of the
    model dimension, then concatenates the per-head contexts and applies
    a final output projection.

    Args:
        d_model: Model (embedding) dimensionality; must be divisible by
            ``num_heads``.
        num_heads: Number of parallel attention heads.
        dropout: Dropout rate applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimensionality
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Apply multi-head self-attention.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, num_heads, seq_len, seq_len); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tuple of (output, attention_weights) with shapes
            (batch, seq_len, d_model) and (batch, num_heads, seq_len, seq_len).
        """
        batch_size, seq_len, d_model = x.size()
        # Project, then split the model dimension into heads.
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Shape: (batch, num_heads, seq_len, d_k)
        # Scaled dot-product attention per head.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Masked positions get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context = torch.matmul(attention_weights, V)
        # Merge heads back into the model dimension and project.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_o(context)
        return output, attention_weights
What Do Different Heads Learn?
Research shows different heads specialize:
- Some track syntactic relationships
- Some capture positional patterns
- Some focus on semantic similarity
- Some learn rare but important patterns
Positional Encoding
Attention is permutation-equivariant—it doesn't inherently know token positions. We must inject positional information.
Sinusoidal Encoding
class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine position signals to token embeddings.

    Even embedding channels carry sines, odd channels cosines, each at a
    geometrically decreasing frequency, as in "Attention Is All You Need".
    The table is precomputed once and stored as a non-trainable buffer.
    """

    def __init__(self, d_model: int, max_seq_length: int = 5000):
        super().__init__()
        positions = torch.arange(0, max_seq_length).unsqueeze(1)  # (max_len, 1)
        # Per-channel-pair angular frequencies: 10000^(-2i/d_model).
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_seq_length, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Buffer (not a Parameter): saved with the model, never trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return *x* plus the encodings for its first seq_len positions."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]
Properties:
- Each position has a unique encoding
- Relative positions can be learned (PE[pos+k] can be represented as linear function of PE[pos])
- Extrapolates to longer sequences
Learned Positional Embeddings
class LearnedPositionalEmbedding(nn.Module):
    """Trainable per-position embeddings added to the input sequence.

    Unlike sinusoidal encodings, these vectors are learned end-to-end and
    cannot represent positions beyond ``max_seq_length``.
    """

    def __init__(self, d_model: int, max_seq_length: int = 5000):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_seq_length, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the learned position vector to every timestep of *x*."""
        seq_len = x.size(1)
        idx = torch.arange(seq_len, device=x.device)
        return x + self.position_embeddings(idx)
FlashAttention: Memory-Efficient Attention
Standard attention requires O(N²) memory for the attention matrix. FlashAttention reduces this to O(N).
The Memory Problem
# Standard attention memory usage
def standard_attention_memory(seq_len, batch, heads, head_dim):
    """Count the elements held live by standard (materialized) attention.

    Returns the element count for the Q/K/V projections plus the full
    seq_len x seq_len score matrix — the latter dominates for long
    sequences, which is the quadratic-memory bottleneck.
    """
    projections = 3 * batch * heads * seq_len * head_dim  # Q, K, V
    score_matrix = batch * heads * seq_len ** 2  # the bottleneck!
    return projections + score_matrix
# For seq_len=4096, batch=8, heads=32, head_dim=128:
# Attention matrix alone: 8 * 32 * 4096 * 4096 * 4 bytes = 17 GB!
FlashAttention Solution
def flash_attention_concept(Q, K, V, block_size=64):
    """
    Block-wise attention sketch in the spirit of FlashAttention.

    Streams over key/value blocks while maintaining a running
    (max, denominator, numerator) triple per query row — the online
    softmax recurrence — so the full N x N score matrix is never
    materialized. Q, K, V are 2-D tensors of shape (N, d).
    """
    N, d = Q.shape
    scale = math.sqrt(d)
    acc = torch.zeros_like(Q)                      # running (unnormalized) numerator
    running_max = torch.full((N,), float('-inf'))  # per-row max seen so far
    denom = torch.zeros(N)                         # running softmax denominator
    for lo in range(0, N, block_size):
        hi = min(lo + block_size, N)
        # Scores of every query against this block of keys.
        block_scores = (Q @ K[lo:hi].T) / scale
        local_max = block_scores.max(dim=-1).values
        updated_max = torch.maximum(running_max, local_max)
        # Rescale earlier partial sums to the new max, then fold in this
        # block's shifted exponentials (numerically stable softmax).
        correction = torch.exp(running_max - updated_max)
        probs = torch.exp(block_scores - updated_max.unsqueeze(-1))
        denom = denom * correction + probs.sum(dim=-1)
        acc = acc * correction.unsqueeze(-1) + probs @ V[lo:hi]
        running_max = updated_max
    # Normalize once at the end.
    return acc / denom.unsqueeze(-1)
Sparse Attention: Breaking Quadratic Complexity
For very long sequences, even linear-memory approaches like FlashAttention still pay O(N²) compute, since every query–key pair is scored. Sparse attention patterns reduce the number of pairs that are scored at all, cutting both compute and memory further.
Local Attention (Sliding Window)
def local_attention_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """Build a boolean mask where each token attends only to nearby tokens.

    Token i may attend to token j iff |i - j| <= window_size // 2, i.e. a
    symmetric sliding window (clipped at the sequence boundaries).

    Args:
        seq_len: Sequence length.
        window_size: Total window width; half extends to each side.

    Returns:
        Bool tensor of shape (seq_len, seq_len).
    """
    half = window_size // 2
    positions = torch.arange(seq_len)
    # offset[i, j] = j - i; vectorized replacement for the per-row loop.
    offset = positions.unsqueeze(0) - positions.unsqueeze(1)
    return offset.abs() <= half
Block-Local + Global Attention (Longformer)
def longformer_attention_mask(seq_len: int, window_size: int, global_tokens: int):
    """Longformer-style mask: sliding-window locality plus global tokens.

    Starts from a local (sliding-window) mask, then lets the first
    ``global_tokens`` positions attend everywhere and be attended to
    from everywhere.
    """
    combined = local_attention_mask(seq_len, window_size)
    # Global tokens: full rows (they see everything) and full columns
    # (everything sees them).
    combined[:global_tokens] = True
    combined[:, :global_tokens] = True
    return combined
Causal Attention for Autoregressive Models
In language modeling, each token may attend only to itself and to earlier positions, so predictions never depend on future tokens:
def causal_mask(seq_len: int):
    """Build a (seq_len, seq_len) lower-triangular mask.

    Entry (i, j) is 1 iff position i may attend to position j, i.e.
    j <= i — the standard autoregressive constraint.
    """
    full = torch.ones(seq_len, seq_len)
    return full.tril()
class CausalSelfAttention(nn.Module):
    """Single-head self-attention with a causal (autoregressive) mask.

    Position i attends only to positions j <= i, so outputs never depend
    on future tokens. The original snippet omitted ``__init__`` entirely,
    leaving ``self.W_q``/``W_k``/``W_v``/``d_k`` undefined and the class
    unusable; the constructor below supplies them (mirroring
    ``SelfAttention``) while keeping no-arg construction working.
    """

    def __init__(self, d_model: int = 512, d_k: Optional[int] = None):
        super().__init__()
        # Projection width defaults to the model width.
        self.d_k = d_k or d_model
        self.W_q = nn.Linear(d_model, self.d_k)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x):
        """Apply causally masked self-attention.

        Args:
            x: Input of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_k).
        """
        B, T, C = x.size()
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        # Zero entries in the lower-triangular mask mark future positions;
        # -inf scores there become zero attention weight after softmax.
        mask = torch.tril(torch.ones(T, T, device=x.device))
        scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        output = weights @ V
        return output
Putting It Together: Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Structure: multi-head self-attention with a residual connection and
    LayerNorm, followed by a position-wise feed-forward network with its
    own residual connection and LayerNorm.

    Args:
        d_model: Model (embedding) dimensionality.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout rate used throughout the layer.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Run one encoder layer.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            mask: Optional attention mask forwarded to self-attention.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        # Self-attention sublayer with residual + post-norm.
        attn_output, _ = self.self_attention(x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        # Feed-forward sublayer with residual + post-norm.
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x
Conclusion
Attention mechanisms have transformed deep learning:
- Self-attention enables modeling relationships between all positions
- Scaling by √d_k ensures stable gradients
- Multi-head attention captures diverse patterns
- Positional encoding injects sequence order information
- FlashAttention makes long sequences practical
- Sparse attention enables even longer contexts
Understanding these fundamentals is essential for working with modern NLP models and developing new architectures. The attention mechanism's flexibility and power explain why it has become the dominant approach in deep learning.