Attention Mechanisms Explained

7 min read · attention · transformers · ml

Attention mechanisms have revolutionized deep learning, enabling models to focus on relevant parts of the input when producing outputs. From machine translation to image recognition, attention is everywhere. This post breaks down attention from first principles.

The Intuition Behind Attention

Consider translating "The cat sat on the mat" to French. When generating the French word for "cat," the model should focus primarily on "cat" in the input, not "mat" or "the." Attention provides this selective focus.

Without attention, sequence-to-sequence models must compress the entire input into a single fixed-size vector, losing information as sequences grow. Attention lets the model "look back" at all input positions when generating each output.

Self-Attention: The Core Mechanism

Self-attention allows each position in a sequence to attend to all positions, computing weighted combinations.

The Query-Key-Value Framework

For each token, we compute three vectors:

  • Query (Q): "What am I looking for?"
  • Key (K): "What do I contain?"
  • Value (V): "What information do I provide?"

import torch
import torch.nn as nn
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int | None = None):
        super().__init__()
        self.d_k = d_k or d_model

        self.W_q = nn.Linear(d_model, self.d_k)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_k)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq_len, seq_len)
        scores = scores / math.sqrt(self.d_k)  # Scale

        # Softmax to get attention weights
        attention_weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)

        return output
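As a quick sanity check, here is the same computation in plain functional form with arbitrary toy sizes; note that every row of the attention weights sums to 1:

```python
import math
import torch

torch.manual_seed(0)
batch, seq_len, d_k = 2, 5, 16

# Random stand-ins for the projected Q, K, V from the module above
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)
output = weights @ V                               # (batch, seq_len, d_k)

print(output.shape)         # torch.Size([2, 5, 16])
print(weights.sum(dim=-1))  # each row sums to 1
```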

Why Scale by √d_k?

Without scaling, as d_k increases, the dot products grow large, pushing softmax into regions with tiny gradients:

# Demonstration of scaling importance
d_k = 512
q = torch.randn(1, d_k)
k = torch.randn(1000, d_k)

# Unscaled: dot products have variance ≈ d_k
unscaled = q @ k.T
print(f"Unscaled std: {unscaled.std():.2f}")  # ~22

# Scaled: variance ≈ 1
scaled = (q @ k.T) / math.sqrt(d_k)
print(f"Scaled std: {scaled.std():.2f}")  # ~1

# Softmax behavior
print(f"Unscaled softmax max: {torch.softmax(unscaled, dim=-1).max():.4f}")  # ~0.99+
print(f"Scaled softmax max: {torch.softmax(scaled, dim=-1).max():.4f}")  # More uniform

Multi-Head Attention

A single attention head computes only one weighting over the sequence per position, so it can capture only one kind of relationship at a time. Multi-head attention runs several attention operations in parallel, each with its own projections, letting different heads learn different relationships.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        batch_size, seq_len, d_model = x.size()

        # Project and reshape for multi-head
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Shape: (batch, num_heads, seq_len, d_k)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context = torch.matmul(attention_weights, V)

        # Concatenate heads and project
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_o(context)

        return output, attention_weights
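The view/transpose reshaping above is easy to get wrong; a small standalone check (with illustrative sizes) confirms that splitting into heads and merging back is lossless:

```python
import torch

batch, seq_len, num_heads, d_k = 2, 7, 4, 8
d_model = num_heads * d_k

x = torch.randn(batch, seq_len, d_model)

# Split into heads: (batch, num_heads, seq_len, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

# Merge back exactly as in the forward pass above
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

assert torch.equal(x, merged)  # split + merge is an exact round trip
```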

What Do Different Heads Learn?

Research shows different heads specialize:

  • Some track syntactic relationships
  • Some capture positional patterns
  • Some focus on semantic similarity
  • Some learn rare but important patterns

Positional Encoding

Self-attention is permutation-equivariant: permuting the input tokens simply permutes the outputs, so the mechanism has no inherent notion of token order. We must inject positional information explicitly.

Sinusoidal Encoding

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_length: int = 5000):
        super().__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length).unsqueeze(1)

        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1), :]

Properties:

  • Each position has a unique encoding
  • Relative positions can be learned (PE[pos+k] can be represented as linear function of PE[pos])
  • Can, in principle, extrapolate to sequences longer than any seen during training
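The uniqueness claim is easy to verify numerically. This sketch rebuilds the same encoding table as the module above and checks that no two positions coincide:

```python
import math
import torch

d_model, max_len = 64, 128
position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
)

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# Minimum distance between any two distinct position encodings
dists = torch.cdist(pe, pe)
dists.fill_diagonal_(float('inf'))
assert dists.min() > 0  # every position gets a unique encoding
```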

Learned Positional Embeddings

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, d_model: int, max_seq_length: int = 5000):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_seq_length, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.position_embeddings(positions)

FlashAttention: Memory-Efficient Attention

Standard attention materializes an N × N attention matrix, requiring O(N²) memory. FlashAttention computes the same result block by block, reducing the extra memory to O(N).

The Memory Problem

# Standard attention memory usage
def standard_attention_memory(seq_len, batch, heads, head_dim):
    q_k_v = 3 * batch * heads * seq_len * head_dim  # Q, K, V
    attention_matrix = batch * heads * seq_len * seq_len  # The bottleneck!
    return q_k_v + attention_matrix

# For seq_len=4096, batch=8, heads=32, head_dim=128:
# Attention matrix alone: 8 * 32 * 4096 * 4096 * 4 bytes = 17 GB!
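Plugging the numbers from the comment into the helper (restated here so the snippet runs on its own):

```python
def standard_attention_memory(seq_len, batch, heads, head_dim):
    # Restated from above so this snippet is self-contained
    q_k_v = 3 * batch * heads * seq_len * head_dim        # Q, K, V activations
    attention_matrix = batch * heads * seq_len * seq_len  # the bottleneck
    return q_k_v + attention_matrix

total = standard_attention_memory(seq_len=4096, batch=8, heads=32, head_dim=128)
attn_only = 8 * 32 * 4096 * 4096

print(f"Attention matrix: {attn_only * 4 / 1e9:.1f} GB (fp32)")       # 17.2 GB
print(f"Q/K/V activations: {(total - attn_only) * 4 / 1e9:.1f} GB")   # 1.6 GB
```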

FlashAttention Solution

def flash_attention_concept(Q, K, V, block_size=64):
    """
    Process attention in blocks without materializing full attention matrix.
    Uses online softmax for numerical stability.
    """
    N, d = Q.shape
    output = torch.zeros_like(Q)
    row_max = torch.full((N,), float('-inf'))
    row_sum = torch.zeros(N)

    for k_start in range(0, N, block_size):
        k_end = min(k_start + block_size, N)
        K_block = K[k_start:k_end]
        V_block = V[k_start:k_end]

        # Compute block scores
        scores = Q @ K_block.T / math.sqrt(d)

        # Online softmax update
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)

        # Rescale and accumulate
        scale_old = torch.exp(row_max - new_max)
        scale_new = torch.exp(scores - new_max.unsqueeze(-1))

        row_sum = row_sum * scale_old + scale_new.sum(dim=-1)
        output = output * scale_old.unsqueeze(-1) + scale_new @ V_block
        row_max = new_max

    return output / row_sum.unsqueeze(-1)
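Since the blocked routine never forms the full matrix, it is worth confirming it matches the naive computation. The function is repeated here so the check is self-contained; the sizes are arbitrary:

```python
import math
import torch

torch.manual_seed(0)

def flash_attention_concept(Q, K, V, block_size=64):
    # Same blocked online-softmax routine as above, repeated for standalone use
    N, d = Q.shape
    output = torch.zeros_like(Q)
    row_max = torch.full((N,), float('-inf'))
    row_sum = torch.zeros(N)
    for k_start in range(0, N, block_size):
        k_end = min(k_start + block_size, N)
        K_block, V_block = K[k_start:k_end], V[k_start:k_end]
        scores = Q @ K_block.T / math.sqrt(d)
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)
        scale_old = torch.exp(row_max - new_max)
        scale_new = torch.exp(scores - new_max.unsqueeze(-1))
        row_sum = row_sum * scale_old + scale_new.sum(dim=-1)
        output = output * scale_old.unsqueeze(-1) + scale_new @ V_block
        row_max = new_max
    return output / row_sum.unsqueeze(-1)

N, d = 200, 32
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

# Reference: materialize the full N x N attention matrix
reference = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1) @ V

assert torch.allclose(flash_attention_concept(Q, K, V), reference, atol=1e-4)
```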

Sparse Attention: Breaking Quadratic Complexity

For very long sequences, even O(N) memory is expensive. Sparse attention patterns reduce complexity further.

Local Attention (Sliding Window)

def local_attention_mask(seq_len: int, window_size: int):
    """Each token attends only to nearby tokens."""
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = True
    return mask
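The loop above is clear but slow in Python; an equivalent vectorized form produces the same banded pattern:

```python
import torch

def local_attention_mask_vec(seq_len: int, window_size: int):
    # Vectorized restatement of the loop above: token i attends to j
    # whenever |i - j| <= window_size // 2
    idx = torch.arange(seq_len)
    return (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= window_size // 2

mask = local_attention_mask_vec(6, 3)
print(mask.int())  # banded matrix: each row allows itself plus one neighbor each side
```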

Block-Local + Global Attention (Longformer)

def longformer_attention_mask(seq_len: int, window_size: int, global_tokens: int):
    """Combine local attention with global tokens."""
    mask = local_attention_mask(seq_len, window_size)

    # Global tokens can attend to everything
    mask[:global_tokens, :] = True
    mask[:, :global_tokens] = True

    return mask
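A self-contained sketch of the combined pattern (restating both masks in vectorized form) shows how sparse it stays even at toy sizes:

```python
import torch

def longformer_pattern(seq_len: int, window_size: int, global_tokens: int):
    # Standalone restatement of the local + global combination above
    idx = torch.arange(seq_len)
    mask = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= window_size // 2
    mask[:global_tokens, :] = True  # global tokens attend to everything
    mask[:, :global_tokens] = True  # everything attends to global tokens
    return mask

mask = longformer_pattern(seq_len=8, window_size=3, global_tokens=1)
print(f"Attended pairs: {int(mask.sum())}/64")  # 34/64 vs. 64/64 for full attention
```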

Causal Attention for Autoregressive Models

In language modeling, tokens should only attend to previous tokens:

def causal_mask(seq_len: int):
    """Lower triangular mask for autoregressive attention."""
    return torch.tril(torch.ones(seq_len, seq_len))

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int | None = None):
        super().__init__()
        self.d_k = d_k or d_model
        self.W_q = nn.Linear(d_model, self.d_k)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x):
        B, T, C = x.size()

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)

        # Apply causal mask
        mask = torch.tril(torch.ones(T, T, device=x.device))
        scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        output = weights @ V

        return output
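The defining property of causal attention is that perturbing a future token cannot change earlier outputs. A standalone check, using random tensors in place of the learned projections:

```python
import math
import torch

torch.manual_seed(0)
T, d_k = 8, 16

def causal_attend(Q, K, V):
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    mask = torch.tril(torch.ones(T, T))
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return torch.softmax(scores, dim=-1) @ V

Q, K, V = torch.randn(T, d_k), torch.randn(T, d_k), torch.randn(T, d_k)
out1 = causal_attend(Q, K, V)

# Perturb the last token's key/value; positions 0..T-2 must be unchanged
K2, V2 = K.clone(), V.clone()
K2[-1], V2[-1] = torch.randn(d_k), torch.randn(d_k)
out2 = causal_attend(Q, K2, V2)

assert torch.allclose(out1[:-1], out2[:-1])      # past positions unaffected
assert not torch.allclose(out1[-1], out2[-1])    # last position does change
```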

Putting It Together: Transformer Encoder Layer

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # Self-attention with residual
        attn_output, _ = self.self_attention(x, mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # Feed-forward with residual
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))

        return x
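For real projects, PyTorch's built-in nn.TransformerEncoderLayer implements the same post-norm residual structure; batch_first=True gives the (batch, seq, d_model) layout used throughout this post. The sizes below are arbitrary:

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=128, nhead=8, dim_feedforward=512,
    dropout=0.1, batch_first=True,
)
layer.eval()  # disable dropout for a deterministic shape check

x = torch.randn(2, 10, 128)
with torch.no_grad():
    out = layer(x)

print(out.shape)  # torch.Size([2, 10, 128])
```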

Conclusion

Attention mechanisms have transformed deep learning:

  1. Self-attention enables modeling relationships between all positions
  2. Scaling by √d_k ensures stable gradients
  3. Multi-head attention captures diverse patterns
  4. Positional encoding injects sequence order information
  5. FlashAttention makes long sequences practical
  6. Sparse attention enables even longer contexts

Understanding these fundamentals is essential for working with modern NLP models and developing new architectures. The attention mechanism's flexibility and power explain why it has become the dominant approach in deep learning.