Understanding Transformer Inference Optimization
Transformer models have become the backbone of modern AI, from GPT to BERT to Vision Transformers. While training these models gets most of the attention, inference optimization is what determines whether your model can be deployed in production. This guide covers the key techniques for making transformer inference fast and memory-efficient.
The KV Cache: Foundation of Efficient Autoregressive Inference
Understanding the KV cache is fundamental to transformer inference optimization.
Why KV Cache Exists
During autoregressive generation, each new token needs to attend to all previous tokens. Without caching, generating N tokens requires:
- Token 1: 1 attention computation
- Token 2: 2 attention computations
- Token N: N attention computations
- Total: N(N+1)/2 = O(N²) computations
With KV caching, we store the key and value projections for each token and reuse them:
- Token 1: Compute and cache K1, V1
- Token 2: Compute K2, V2; attend to cached K1,V1 plus K2,V2
- Token N: Compute KN, VN; attend to all cached K,V pairs
- Total: O(N) new computations per token
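A quick arithmetic check of the two totals above:

```python
def attention_cost(n: int) -> tuple[int, int]:
    """Return (K/V computations without caching, with caching) for n tokens."""
    without_cache = n * (n + 1) // 2  # token t recomputes all t positions
    with_cache = n                    # one new K/V projection per token
    return without_cache, with_cache

print(attention_cost(1024))  # (524800, 1024): ~500x less work at 1K tokens
```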
import torch
from torch import Tensor
from typing import Tuple

class KVCache:
    def __init__(self, max_seq_len: int, num_layers: int,
                 num_heads: int, head_dim: int):
        self.cache_k = torch.zeros(num_layers, max_seq_len, num_heads, head_dim)
        self.cache_v = torch.zeros(num_layers, max_seq_len, num_heads, head_dim)
        self.seq_len = 0  # position currently being written

    def update(self, layer_idx: int, new_k: Tensor, new_v: Tensor):
        self.cache_k[layer_idx, self.seq_len] = new_k
        self.cache_v[layer_idx, self.seq_len] = new_v

    def advance(self):
        # Call once per generated token, after all layers have updated
        self.seq_len += 1

    def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
        # Include the position written this step, hence seq_len + 1
        return (self.cache_k[layer_idx, :self.seq_len + 1],
                self.cache_v[layer_idx, :self.seq_len + 1])
KV Cache Memory Requirements
For a typical LLM:
KV cache size = 2 × num_layers × seq_len × num_kv_heads × head_dim × bytes_per_param
For LLaMA-2 70B (80 layers, 8 KV heads via grouped-query attention, head_dim 128) at 2048 sequence length in FP16:
= 2 × 80 × 2048 × 8 × 128 × 2 bytes ≈ 0.67 GB per request
Without grouped-query attention (64 KV heads), the same cache would be roughly 5.4 GB, and it grows linearly with both context length and the number of concurrent requests.
This is why KV cache management is critical for serving efficiency.
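The formula is straightforward to evaluate directly (helper name is illustrative):

```python
def kv_cache_bytes(num_layers: int, seq_len: int, num_kv_heads: int,
                   head_dim: int, bytes_per_param: int = 2) -> int:
    # Leading factor of 2: one tensor for keys, one for values
    return 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_param

# LLaMA-2 70B: 80 layers, 8 KV heads (grouped-query attention), head_dim 128, FP16
print(kv_cache_bytes(80, 2048, 8, 128) / 1e9)  # ≈ 0.67 GB per 2048-token request
```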
FlashAttention: Memory-Efficient Attention
Standard attention materializes an N×N attention matrix, which is prohibitive for long sequences. FlashAttention restructures the computation to avoid this.
The Problem
import math
import torch
from torch.nn.functional import softmax

# Standard attention: O(N²) memory
def standard_attention(Q, K, V):
    head_dim = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
    # scores is N×N - this is the memory bottleneck
    attention = softmax(scores, dim=-1)
    return attention @ V
FlashAttention Solution
FlashAttention uses tiling and online softmax to compute attention in O(N) memory:
- Tiling: Process Q, K, V in blocks that fit in SRAM
- Online softmax: Compute softmax incrementally without materializing full matrix
- Kernel fusion: Perform all operations in a single GPU kernel
import math
import torch

def flash_attention_conceptual(Q, K, V, block_size=64):
    N, d = Q.shape
    output = torch.zeros_like(Q)
    # Running statistics for online softmax
    row_max = torch.full((N,), float('-inf'))
    row_sum = torch.zeros(N)
    for k_start in range(0, N, block_size):
        K_block = K[k_start:k_start + block_size]
        V_block = V[k_start:k_start + block_size]
        scores = Q @ K_block.T / math.sqrt(d)
        # Online softmax update
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)
        # Rescale previous accumulator to the new running max
        scale_prev = torch.exp(row_max - new_max)
        scale_curr = torch.exp(block_max - new_max)
        row_sum = row_sum * scale_prev + (scale_curr * torch.exp(scores - block_max.unsqueeze(-1))).sum(dim=-1)
        # Accumulate unnormalized output, rescaling the old contribution
        output = output * scale_prev.unsqueeze(-1) + torch.exp(scores - new_max.unsqueeze(-1)) @ V_block
        row_max = new_max
    return output / row_sum.unsqueeze(-1)
Results: 2-4x speedup, 10-20x memory reduction, exact attention (no approximation).
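The online-softmax update at the heart of this loop can be verified on a single row split into two chunks: rescaling the running exp-sum by exp(old_max - new_max) reproduces the full softmax exactly.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8)
a, b = x[:4], x[4:]  # process the row in two "blocks"

# Block 1: local max and exp-sum
m1 = a.max()
s1 = torch.exp(a - m1).sum()
# Block 2: merge with the running statistics
m2 = torch.maximum(m1, b.max())
s2 = s1 * torch.exp(m1 - m2) + torch.exp(b - m2).sum()

online = torch.exp(x - m2) / s2
full = torch.softmax(x, dim=0)
assert torch.allclose(online, full, atol=1e-6)
```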
Quantization: Trading Precision for Speed
Quantization reduces model size and increases inference speed by using lower-precision arithmetic.
INT8 Quantization
import torch
from torch import Tensor
from typing import Tuple

def quantize_int8(tensor: Tensor) -> Tuple[Tensor, Tensor]:
    scale = tensor.abs().max() / 127.0
    quantized = torch.round(tensor / scale).to(torch.int8)
    return quantized, scale

def dequantize(quantized: Tensor, scale: Tensor) -> Tensor:
    return quantized.float() * scale
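A quick property check on absmax INT8 quantization: with round-to-nearest, the round-trip error is bounded by half a quantization step.

```python
import torch

torch.manual_seed(0)
w = torch.randn(256, 256)          # a weight matrix in FP32
scale = w.abs().max() / 127.0      # per-tensor absmax scale
q = torch.round(w / scale).to(torch.int8)
err = (q.float() * scale - w).abs().max()
assert err <= scale / 2 + 1e-8     # round-to-nearest error ≤ half a step
print(f"max roundtrip error: {err:.5f}, step size: {scale:.5f}")
```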
INT4/GPTQ Quantization
For even more aggressive compression:
# GPTQ: Accurate Post-Training Quantization
# Quantizes to 4 bits with Hessian-guided rounding
# Achieves <1% accuracy loss on most models
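GPTQ itself requires calibration data to estimate Hessians, but the weight format it targets can be illustrated with a plain round-to-nearest group quantizer (a sketch, not GPTQ):

```python
import torch

def quantize_int4_grouped(w: torch.Tensor, group_size: int = 128):
    """Round-to-nearest 4-bit quantization with one FP scale per group.
    This is the storage format GPTQ targets; GPTQ additionally uses
    Hessian (second-order) information to choose better roundings."""
    groups = w.reshape(-1, group_size)
    # Symmetric scheme: map each group's absmax onto the int4 level 7
    scales = groups.abs().amax(dim=1, keepdim=True) / 7.0
    q = torch.round(groups / scales).clamp(-8, 7).to(torch.int8)
    return q, scales

def dequantize_int4(q: torch.Tensor, scales: torch.Tensor, shape) -> torch.Tensor:
    return (q.float() * scales).reshape(shape)

torch.manual_seed(0)
w = torch.randn(512, 512)
q, s = quantize_int4_grouped(w)
w_hat = dequantize_int4(q, s, w.shape)
```

Per-group scales keep the step size small where weights are small, which is why group sizes of 64-128 are common in 4-bit schemes.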
Memory savings:
- FP16 → INT8: 2x reduction
- FP16 → INT4: 4x reduction
Performance: On modern GPUs with Tensor Cores:
- INT8: 2x faster than FP16
- INT4: Up to 4x faster with specialized kernels
Speculative Decoding: Parallelizing Sequential Generation
Autoregressive generation is inherently sequential. Speculative decoding circumvents this using a draft model.
def speculative_decode(target_model, draft_model, prompt, K=5):
    """
    Use the draft model to propose K tokens, then verify them with the
    target model. Typical speedup: 2-4x with an 80%+ acceptance rate.
    (Pseudocode: `done`, `matches`, and `sample` stand in for stopping
    logic, the acceptance test, and sampling from the target distribution.)
    """
    tokens = list(prompt)
    while not done:
        # Step 1: Draft model generates K tokens autoregressively
        draft_tokens = []
        for _ in range(K):
            draft_token = draft_model.generate_one(tokens + draft_tokens)
            draft_tokens.append(draft_token)
        # Step 2: Target model scores ALL K tokens in ONE forward pass
        target_logits = target_model.forward(tokens + draft_tokens)
        # Step 3: Accept the longest matching prefix; at the first
        # mismatch, take the target model's token instead and stop.
        # (A full implementation also samples a bonus token from the
        # target distribution when all K drafts are accepted.)
        for i, draft_token in enumerate(draft_tokens):
            if matches(target_logits[i], draft_token):
                tokens.append(draft_token)
            else:
                tokens.append(sample(target_logits[i]))
                break
    return tokens
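Under the simplifying assumption that each draft token is accepted independently with probability α (and ignoring the draft model's own cost), the expected number of tokens emitted per target forward pass has a closed form:

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens emitted per target-model forward pass when each of
    the k draft tokens is accepted i.i.d. with probability alpha
    (0 <= alpha < 1). Closed form: (1 - alpha**(k+1)) / (1 - alpha)."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)

print(round(expected_tokens_per_pass(0.8, 5), 2))  # 3.69 tokens per pass
```

At an 80% acceptance rate with K=5 drafts, each expensive target pass yields ~3.7 tokens instead of 1, which is where the 2-4x speedup figure comes from.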
Batching Strategies
Static Batching
Wait for batch to fill, then process all together. Problem: short sequences wait for long ones.
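The cost of that waiting is easy to quantify: with hypothetical generation lengths, count how many batch slot-steps do useful work versus how many are reserved until the longest request finishes.

```python
# Hypothetical generation lengths (tokens to produce) for one static batch
lengths = [12, 180, 25, 60]

# Every slot stays occupied until the longest request finishes
static_steps = max(lengths) * len(lengths)  # 720 slot-steps reserved
useful_steps = sum(lengths)                 # 277 slot-steps of real work
print(f"slot utilization: {useful_steps / static_steps:.0%}")  # slot utilization: 38%
```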
Continuous Batching
Process one iteration at a time, dynamically adding/removing sequences:
class ContinuousBatcher:
    def __init__(self, model, max_batch_size):
        self.model = model
        self.max_batch_size = max_batch_size
        self.active_sequences = []
        self.waiting_queue = []

    async def process_step(self):
        # Add new sequences if capacity is available
        while (len(self.active_sequences) < self.max_batch_size
               and self.waiting_queue):
            self.active_sequences.append(self.waiting_queue.pop(0))
        # Generate one token for all active sequences; each sequence
        # records its new token and sets its own `done` flag
        if self.active_sequences:
            outputs = self.model.forward_batch(self.active_sequences)
        # Remove completed sequences so waiting ones can join next step
        self.active_sequences = [s for s in self.active_sequences if not s.done]
Result: 2-3x throughput improvement over static batching.
Hardware Considerations
Memory Hierarchy
Level        Latency        Capacity
Registers    ~1 cycle       256 KB per SM
L1/Shared    ~30 cycles     ~100 KB per SM
L2 Cache     ~200 cycles    40-60 MB total
HBM          ~400 cycles    40-80 GB total
Compute vs Memory Bound
- Compute-bound (large batches): optimize for FLOPS
- Memory-bound (small batches): optimize for bandwidth
def is_memory_bound(batch_size, model_bytes, flops_per_token,
                    memory_bandwidth, compute_flops):
    # Memory time: stream the weights from HBM once per decode step
    memory_time = model_bytes / memory_bandwidth
    # Compute time: FLOPs for the whole batch at peak throughput
    compute_time = (batch_size * flops_per_token) / compute_flops
    return memory_time > compute_time
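Plugging ballpark hardware and model numbers into this comparison (all figures are rough assumptions, not measurements) shows why small-batch decoding is bandwidth-limited:

```python
# Ballpark figures: roughly an A100 running a 7B model in FP16
model_bytes = 14e9        # ~7B params x 2 bytes
flops_per_token = 14e9    # ~2 FLOPs per parameter per token
mem_bw = 2e12             # ~2 TB/s HBM bandwidth
peak_flops = 312e12       # ~312 TFLOP/s FP16 tensor-core peak

memory_time = model_bytes / mem_bw                       # 7.0 ms to stream weights once
compute_time_b1 = 1 * flops_per_token / peak_flops       # ~45 us: batch 1 is memory-bound
compute_time_b256 = 256 * flops_per_token / peak_flops   # ~11.5 ms: compute-bound
crossover = memory_time / (flops_per_token / peak_flops)
print(round(crossover))  # 156: batch size where the two costs balance
```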
Practical Optimization Checklist
Phase 1: Foundation
- Enable KV caching (10-100x speedup)
- Use FP16/BF16 (2x speedup vs FP32)
- Set optimal batch size
Phase 2: Memory Optimization
- Enable FlashAttention (2-3x for long sequences)
- Quantize to INT8 (2x memory reduction)
- Implement continuous batching
Phase 3: Advanced
- INT4 quantization if memory-constrained
- Speculative decoding for latency-critical workloads
- Profile and optimize specific bottlenecks
Conclusion
Transformer inference optimization is a multi-faceted challenge:
- KV caching eliminates redundant computation
- FlashAttention enables long sequences
- Quantization reduces memory and increases throughput
- Speculative decoding parallelizes sequential generation
- Continuous batching maximizes GPU utilization
The right combination depends on your constraints: latency vs throughput, memory vs compute, accuracy vs speed. Understanding these trade-offs is key to deploying transformers effectively.