
Understanding Transformer Inference Optimization

6 min read · transformers · optimization · ml


Transformer models have become the backbone of modern AI, from GPT to BERT to Vision Transformers. While training these models gets most of the attention, inference optimization is what determines whether your model can be deployed in production. This guide covers the key techniques for making transformer inference fast and memory-efficient.

The KV Cache: Foundation of Efficient Autoregressive Inference

Understanding the KV cache is fundamental to transformer inference optimization.

Why KV Cache Exists

During autoregressive generation, each new token needs to attend to all previous tokens. Without caching, generating N tokens requires:

  • Token 1: 1 attention computation
  • Token 2: 2 attention computations
  • Token N: N attention computations
  • Total: N(N+1)/2 = O(N²) computations

With KV caching, we store the key and value projections for each token and reuse them:

  • Token 1: Compute and cache K1, V1
  • Token 2: Compute K2, V2; attend to cached K1,V1 plus K2,V2
  • Token N: Compute KN, VN; attend to all cached K,V pairs
  • Total: O(N) new computations per token
import torch
from torch import Tensor
from typing import Tuple

class KVCache:
    def __init__(self, max_seq_len: int, num_layers: int,
                 num_heads: int, head_dim: int):
        self.cache_k = torch.zeros(num_layers, max_seq_len, num_heads, head_dim)
        self.cache_v = torch.zeros(num_layers, max_seq_len, num_heads, head_dim)
        self.seq_len = 0

    def update(self, layer_idx: int, new_k: Tensor, new_v: Tensor) -> None:
        # Write the new token's K/V at the current position for this layer
        self.cache_k[layer_idx, self.seq_len] = new_k
        self.cache_v[layer_idx, self.seq_len] = new_v

    def advance(self) -> None:
        # Call once per generated token, after every layer has been updated
        self.seq_len += 1

    def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
        # All cached K/V for this layer, including the token just written
        return (self.cache_k[layer_idx, :self.seq_len + 1],
                self.cache_v[layer_idx, :self.seq_len + 1])

KV Cache Memory Requirements

For a typical LLM:

KV cache size = 2 × num_layers × seq_len × num_kv_heads × head_dim × bytes_per_param

For LLaMA-2 70B (80 layers, 8 KV heads via grouped-query attention, head dimension 128) at a 2048-token sequence length:

= 2 × 80 × 2048 × 8 × 128 × 2 bytes ≈ 0.67 GB per request

Without grouped-query attention (64 KV heads), the same cache would take roughly 5.4 GB per request.

This is why KV cache management is critical for serving efficiency.
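The formula above is easy to wrap in a helper for capacity planning (a small sketch; `kv_cache_bytes` is a name introduced here, not a library function):

```python
def kv_cache_bytes(num_layers: int, seq_len: int, num_kv_heads: int,
                   head_dim: int, bytes_per_param: int = 2) -> int:
    # Leading factor of 2 accounts for storing both keys and values
    return 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_param

# LLaMA-2 70B with GQA: 80 layers, 8 KV heads, head dim 128, FP16
print(kv_cache_bytes(80, 2048, 8, 128) / 1e9)  # ≈ 0.67 GB per request
```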

FlashAttention: Memory-Efficient Attention

Standard attention materializes an N×N attention matrix, which is prohibitive for long sequences. FlashAttention restructures the computation to avoid this.

The Problem

import math
import torch

# Standard attention: O(N²) memory
def standard_attention(Q, K, V):
    head_dim = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
    # scores is N×N - this is the memory bottleneck
    attention = torch.softmax(scores, dim=-1)
    return attention @ V

FlashAttention Solution

FlashAttention uses tiling and online softmax to compute attention in O(N) memory:

  1. Tiling: Process Q, K, V in blocks that fit in SRAM
  2. Online softmax: Compute softmax incrementally without materializing full matrix
  3. Kernel fusion: Perform all operations in a single GPU kernel
import math
import torch

def flash_attention_conceptual(Q, K, V, block_size=64):
    N, d = Q.shape
    output = torch.zeros_like(Q)

    # Running statistics for the online softmax
    row_max = torch.full((N,), float('-inf'))
    row_sum = torch.zeros(N)

    for k_start in range(0, N, block_size):
        K_block = K[k_start:k_start + block_size]
        V_block = V[k_start:k_start + block_size]

        scores = Q @ K_block.T / math.sqrt(d)

        # Online softmax update: track the running row-wise max
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)

        # Rescale the previous accumulator to the new max
        scale_prev = torch.exp(row_max - new_max)
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))

        row_sum = row_sum * scale_prev + exp_scores.sum(dim=-1)
        row_max = new_max

        # Accumulate the unnormalized output
        output = output * scale_prev.unsqueeze(-1) + exp_scores @ V_block

    # Normalize once at the end
    return output / row_sum.unsqueeze(-1)

Results: 2-4x speedup, 10-20x memory reduction, exact attention (no approximation).
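The online-softmax trick can be verified in isolation: scanning a vector block by block while tracking only a running max and running sum recovers exactly the normalizer of a full-pass softmax (a small self-contained check):

```python
import torch

def online_softmax_normalizer(x, block_size=4):
    # Running max and sum over blocks, as in FlashAttention
    running_max = torch.tensor(float('-inf'))
    running_sum = torch.tensor(0.0)
    for start in range(0, x.numel(), block_size):
        block = x[start:start + block_size]
        new_max = torch.maximum(running_max, block.max())
        # Rescale the old sum to the new max before adding the new block
        running_sum = running_sum * torch.exp(running_max - new_max) \
                      + torch.exp(block - new_max).sum()
        running_max = new_max
    return running_max, running_sum

x = torch.randn(10)
m, s = online_softmax_normalizer(x)
# Matches the normalizer of a single full-pass softmax
assert torch.allclose(s, torch.exp(x - x.max()).sum())
```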

Quantization: Trading Precision for Speed

Quantization reduces model size and increases inference speed by using lower-precision arithmetic.

INT8 Quantization

import torch
from torch import Tensor
from typing import Tuple

def quantize_int8(tensor: Tensor) -> Tuple[Tensor, Tensor]:
    # Symmetric per-tensor quantization: map [-max|x|, max|x|] onto [-127, 127]
    scale = tensor.abs().max() / 127.0
    quantized = torch.round(tensor / scale).to(torch.int8)
    return quantized, scale

def dequantize(quantized: Tensor, scale: Tensor) -> Tensor:
    return quantized.float() * scale
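A quick round trip through this symmetric scheme shows the precision cost: the reconstruction error is bounded by half a quantization step (a self-contained sketch, inlining the same quantize/dequantize logic):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1024)

scale = x.abs().max() / 127.0
q = torch.round(x / scale).to(torch.int8)   # quantize
x_hat = q.float() * scale                   # dequantize

# Worst-case round-trip error is half a quantization step
max_err = (x - x_hat).abs().max()
assert max_err <= scale / 2 + 1e-6
```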

INT4/GPTQ Quantization

For even more aggressive compression:

# GPTQ: Accurate Post-Training Quantization
# Quantizes to 4 bits with Hessian-guided rounding
# Achieves <1% accuracy loss on most models

Memory savings:

  • FP16 → INT8: 2x reduction
  • FP16 → INT4: 4x reduction
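The 4x figure comes from packing two 4-bit values into each byte. A minimal packing sketch (the `pack_int4`/`unpack_int4` helpers are illustrative, not GPTQ itself):

```python
import torch

def pack_int4(values: torch.Tensor) -> torch.Tensor:
    # values: integer tensor in [-8, 7]; pack pairs into single bytes
    assert values.numel() % 2 == 0
    u = (values + 8).to(torch.uint8)       # shift to unsigned [0, 15]
    return u[0::2] | (u[1::2] << 4)        # low nibble, high nibble

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    low = (packed & 0x0F).to(torch.int16) - 8
    high = (packed >> 4).to(torch.int16) - 8
    out = torch.empty(packed.numel() * 2, dtype=torch.int16)
    out[0::2], out[1::2] = low, high
    return out

vals = torch.tensor([-8, 7, 0, 3])
assert torch.equal(unpack_int4(pack_int4(vals)), vals.to(torch.int16))
```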

Performance: On modern GPUs with Tensor Cores:

  • INT8: 2x faster than FP16
  • INT4: Up to 4x faster with specialized kernels

Speculative Decoding: Parallelizing Sequential Generation

Autoregressive generation is inherently sequential. Speculative decoding circumvents this using a draft model.

def speculative_decode(target_model, draft_model, prompt, K=5, max_len=256):
    """
    Use the draft model to propose K tokens, then verify them with the
    target model. Typical speedup: 2-4x at 80%+ acceptance rates.
    (Conceptual sketch: generate_one, forward, matches, and sample
    stand in for real model and sampling APIs.)
    """
    tokens = list(prompt)

    while len(tokens) < max_len:
        # Step 1: Draft model generates K candidate tokens autoregressively
        draft_tokens = []
        for _ in range(K):
            draft_tokens.append(draft_model.generate_one(tokens + draft_tokens))

        # Step 2: Target model scores ALL K candidates in ONE forward pass
        target_logits = target_model.forward(tokens + draft_tokens)

        # Step 3: Accept the longest matching prefix; on the first
        # mismatch, resample from the target and go back to drafting
        for i, draft_token in enumerate(draft_tokens):
            if matches(target_logits[i], draft_token):
                tokens.append(draft_token)
            else:
                tokens.append(sample(target_logits[i]))
                break

    return tokens
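How much this helps depends on the per-token acceptance rate α. Assuming each drafted token is accepted independently with probability α, the expected number of tokens produced per target-model pass is (1 − α^(K+1)) / (1 − α):

```python
def expected_tokens_per_pass(alpha: float, K: int) -> float:
    # Geometric series: the first i drafts are all accepted with
    # probability alpha**i, and each pass always yields at least one token
    return (1 - alpha ** (K + 1)) / (1 - alpha)

print(round(expected_tokens_per_pass(0.8, 5), 2))  # 3.69 tokens per target pass
```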

Batching Strategies

Static Batching

Wait for batch to fill, then process all together. Problem: short sequences wait for long ones.

Continuous Batching

Process one iteration at a time, dynamically adding/removing sequences:

class ContinuousBatcher:
    def __init__(self, model, max_batch_size):
        self.model = model
        self.max_batch_size = max_batch_size
        self.active_sequences = []
        self.waiting_queue = []

    async def process_step(self):
        # Admit new sequences while capacity is available
        while len(self.active_sequences) < self.max_batch_size and self.waiting_queue:
            self.active_sequences.append(self.waiting_queue.pop(0))

        # Generate one token for every active sequence
        if self.active_sequences:
            self.model.forward_batch(self.active_sequences)

            # Retire completed sequences, freeing their slots immediately
            self.active_sequences = [s for s in self.active_sequences if not s.done]
Result: 2-3x throughput improvement over static batching.

Hardware Considerations

Memory Hierarchy

Level        Latency      Capacity
Registers    ~1 cycle     256 KB per SM
L1/Shared    ~30 cycles   100 KB per SM
L2 Cache     ~200 cycles  40-60 MB total
HBM          ~400 cycles  40-80 GB total

Compute vs Memory Bound

Compute-bound (large batches): optimize for FLOPS.
Memory-bound (small batches): optimize for memory bandwidth.

def is_memory_bound(batch_size, param_count, bytes_per_param,
                    memory_bandwidth, compute_flops):
    # Memory time: the whole model is streamed from HBM once per decode step
    memory_time = param_count * bytes_per_param / memory_bandwidth
    # Compute time: ~2 FLOPs per parameter per sequence in the batch
    compute_time = batch_size * 2 * param_count / compute_flops
    return memory_time > compute_time
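Plugging rough A100-class numbers into this reasoning (~2 TB/s HBM bandwidth and ~312 FP16 TFLOPS, figures assumed here for illustration) shows why small-batch decoding is memory-bound:

```python
param_count = 70e9        # 70B parameters
bytes_per_param = 2       # FP16
bandwidth = 2.0e12        # ~2 TB/s HBM (assumed)
compute = 312e12          # ~312 TFLOPS FP16 (assumed)

memory_time = param_count * bytes_per_param / bandwidth   # ~70 ms per step
compute_time_b1 = 2 * param_count / compute               # ~0.45 ms at batch 1

# Crossover batch size where compute time catches up with memory time
crossover = memory_time / compute_time_b1
print(round(crossover))  # 156: below this batch size, decoding is memory-bound
```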

Practical Optimization Checklist

Phase 1: Foundation

  1. Enable KV caching (10-100x speedup)
  2. Use FP16/BF16 (2x speedup vs FP32)
  3. Set optimal batch size

Phase 2: Memory Optimization

  1. Enable FlashAttention (2-3x for long sequences)
  2. Quantize to INT8 (2x memory reduction)
  3. Implement continuous batching

Phase 3: Advanced

  1. INT4 quantization if memory-constrained
  2. Speculative decoding for latency-critical workloads
  3. Profile and optimize specific bottlenecks

Conclusion

Transformer inference optimization is a multi-faceted challenge:

  • KV caching eliminates redundant computation
  • FlashAttention enables long sequences
  • Quantization reduces memory and increases throughput
  • Speculative decoding parallelizes sequential generation
  • Continuous batching maximizes GPU utilization

The right combination depends on your constraints: latency vs throughput, memory vs compute, accuracy vs speed. Understanding these trade-offs is key to deploying transformers effectively.