Most explanations of attention are written from the training side: here's the intuition, here's the softmax, here's why it beats RNNs. That framing never helped me much when I was staring at a serving engine trying to figure out where the GPU memory went. This post is attention as I actually use it — as the thing that decides how big the KV cache is, how many requests fit on a GPU, and why decode is slow.
The QKV math in one screen
For each token, the model computes a query, key, and value vector per attention head. Attention for one head is:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\text{head}}}}\right) V$$
For a sequence of n tokens, Q Kᵀ is an n × n matrix per head. That's the famous O(n²). During training (and prefill), you really do pay O(n²) compute, because every token attends to every earlier token in one big batched matmul.
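Here is a minimal, unoptimized sketch of one head of causal attention in plain Python. It assumes Q, K and V arrive as lists of equal-length float vectors, and the names `softmax` and `causal_attention` are illustrative, not from any library. The function counts the query-key dot products so the quadratic term is visible. Real implementations do the same math as batched GPU matmuls.

```python
import math


def softmax(xs):
    # Subtract the max before exponentiating so large scores don't overflow.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(Q, K, V):
    """Single-head causal attention over lists of float vectors.

    Returns (outputs, n_scores). n_scores counts the query-key dot products:
    token i scores against tokens 0..i, so n_scores = n(n+1)/2, i.e. O(n^2).
    """
    scale = 1.0 / math.sqrt(len(Q[0]))  # 1 / sqrt(d_head)
    d_v = len(V[0])
    outputs, n_scores = [], 0
    for i, q in enumerate(Q):
        # One row of Q K^T, masked to the current and earlier positions.
        scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in range(i + 1)]
        n_scores += len(scores)
        weights = softmax(scores)
        # Weighted sum of the visible value vectors.
        outputs.append([sum(w * V[j][c] for j, w in enumerate(weights)) for c in range(d_v)])
    return outputs, n_scores


Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [1.0, 1.0]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.5], [0.1, -0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]]
out, n_scores = causal_attention(Q, K, V)
print(n_scores)  # 10 = 4 * 5 / 2
```

Doubling the sequence to 8 tokens would raise the count from 10 to 36, which is the n(n+1)/2 growth in miniature.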
At decode time, the shape changes. You're generating one token, so Q has a single row: the per-step cost is O(n) compute against the n cached keys and values. One row of dot products is a tiny amount of math, but those cached K and V tensors have to come out of GPU HBM on every step for every layer. The matmul degenerates into a matrix-vector product, and a matrix-vector product does too little arithmetic per byte loaded to keep the GPU's compute units busy.
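The decode step can be sketched the same way. This version assumes a hypothetical `KVCache` that holds one layer's keys and values for one KV head as Python lists. `attend_one` computes a single row of dot products, so the work is linear in the number of cached tokens. Every cached K and V vector is still touched on every step.

```python
import math


def attend_one(q, keys, values):
    """Attention for a single query row against n cached keys and values.

    Work is O(n * d_head): one dot product per cached key, then one weighted
    sum over the cached values. Every cached vector is read exactly once.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # unnormalized softmax weights
    total = sum(exps)
    d_v = len(values[0])
    return [sum(e * v[c] for e, v in zip(exps, values)) / total for c in range(d_v)]


class KVCache:
    """Hypothetical cache for one layer and one KV head."""

    def __init__(self):
        self.keys, self.values = [], []

    def decode_step(self, q, k, v):
        # K and V for the new token are computed once and appended; the query
        # is used for this step only and never cached.
        self.keys.append(k)
        self.values.append(v)
        return attend_one(q, self.keys, self.values)


cache = KVCache()
print(cache.decode_step([1.0, 0.0], [0.0, 0.0], [1.0, 0.0]))  # [1.0, 0.0]
print(cache.decode_step([1.0, 0.0], [0.0, 0.0], [0.0, 1.0]))  # [0.5, 0.5]
```

The query is the only thing that is new each step. The keys and values grow by one entry per token and are re-read in full.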
So the O(n²) headline hides the operational truth: at inference time, attention cost is mostly a memory-bandwidth statement, not a FLOPs statement. The FLOPs are concentrated in prefill; decode is dominated by reading state.
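To put a rough number on "not enough arithmetic per byte", here is a back-of-the-envelope estimate of arithmetic intensity for one head. It assumes FP16 storage, counts only the two matmuls, and counts only the unavoidable HBM traffic. The function names are illustrative, and the accounting deliberately ignores softmax work and cache effects.

```python
def decode_attention_intensity(n, d_head, bytes_per_elem=2):
    """Approximate FLOPs per HBM byte for one head's attention at one decode step.

    FLOPs: q.K^T and p.V are each n*d_head multiply-adds (2 FLOPs apiece).
    Bytes: the n cached keys and n cached values are each read once.
    The query, the output, and the softmax are O(n + d_head) and ignored.
    """
    flops = 2 * (2 * n * d_head)
    bytes_read = 2 * n * d_head * bytes_per_elem
    return flops / bytes_read


def prefill_attention_intensity(n, d_head, bytes_per_elem=2):
    """Same accounting for prefill, assuming a fused kernel that reads Q, K, V
    and writes O exactly once. The causal mask would halve the FLOPs; ignored."""
    flops = 2 * (2 * n * n * d_head)
    bytes_moved = 4 * n * d_head * bytes_per_elem
    return flops / bytes_moved


print(decode_attention_intensity(4096, 128))   # 1.0
print(prefill_attention_intensity(4096, 128))  # 2048.0
```

Decode sits at about 1 FLOP per byte no matter how long the context is, while prefill grows as n/2. GPUs built for FP16 matmuls need on the order of a hundred or more FLOPs per byte before compute, rather than bandwidth, becomes the limit.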
The KV cache is attention's shadow
That cached state has a name. If you recomputed K and V for the whole prefix on every decode step, each token would cost as much as a full prefill — quadratic work per token. Nobody does that. Instead you compute each token's K and V once and keep them around. That's the KV cache, and its size is pure arithmetic:
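$$\text{kv\_bytes} = 2 \times n_{\text{layers}} \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{bytes\_per\_elem}$$
The leading 2 is for K and V, and bytes_per_elem is the storage width of one cached value (2 bytes for FP16). Run it for a classic 7B configuration (32 layers, 32 KV heads, head dim 128, FP16):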
$$2 \times 32 \times 32 \times 128 \times 2\ \text{bytes} = 512\ \text{KiB per token}$$
Half a megabyte per token. A single 4,096-token conversation holds 2 GiB of cache. A 24 GB card that also has to hold ~14 GB of weights has about 10 GB left. That caps you at four such conversations on paper, and roughly three once activations and runtime overhead take their share. Every "why is my batch size 4?" question I've debugged eventually reduces to this formula.
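The same arithmetic works as a small calculator. It assumes "24 GB" means 24×10⁹ bytes and that a 7B model's FP16 weights take about 14×10⁹ bytes. `kv_bytes` and `max_conversations` are illustrative helpers. The capacity figure is a ceiling that ignores activations and runtime overhead.

```python
KiB, GiB = 1024, 1024**3


def kv_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem=2):
    # Leading 2 = one K and one V tensor; bytes_per_elem=2 means FP16/BF16.
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem


def max_conversations(card_bytes, weight_bytes, per_conversation_bytes):
    # Upper bound: ignores activations, the CUDA context, and allocator slack.
    return (card_bytes - weight_bytes) // per_conversation_bytes


per_token = kv_bytes(32, 32, 128, seq_len=1)
per_conversation = kv_bytes(32, 32, 128, seq_len=4096)
print(per_token // KiB)         # 512
print(per_conversation / GiB)   # 2.0
print(max_conversations(24 * 10**9, 14 * 10**9, per_conversation))  # 4
```

That ceiling of four shrinks once everything else on the card claims its memory, which is where "roughly three" comes from.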
MHA → MQA → GQA
The formula also explains why modern architectures changed. Of its factors, n_kv_heads is the one you can shrink without shortening the context, making the model shallower, or narrowing each head, so that's where everyone went to cut costs.
- Multi-Head Attention (MHA): every query head has its own K and V, so n_kv_heads = n_heads. That's the 512 KiB/token above.
- Multi-Query Attention (MQA): all query heads share one K/V head. Setting n_kv_heads = 1 gives 2 × 32 × 1 × 128 × 2 = 16 KiB per token, a 32× cut, with measurable quality loss on some tasks.
- Grouped-Query Attention (GQA): the compromise. Llama-3-8B has the same 32 layers and head dim 128, with 32 query heads but only 8 KV heads:
$$2 \times 32 \times 8 \times 128 \times 2\ \text{bytes} = 128\ \text{KiB per token}$$
That's 4× less cache than MHA at near-MHA quality. The same 4,096-token conversation drops from 2 GiB to 512 MiB. For the same cache budget, that is the difference between serving 3 concurrent users and serving 12. From the inference side, GQA is one of the highest-leverage architecture decisions of the last few years.
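Here is a sketch comparing the three variants with the KV-cache formula. It holds the shape fixed at 32 layers, 32 query heads and head dim 128 in FP16, and varies only the number of KV heads. `kv_head_for` shows how GQA lets several query heads read one K/V head. It assumes the common contiguous grouping, where query head h uses KV head h // group_size. Both helper names are hypothetical.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    # 2 (K and V) x layers x KV heads x head dim x bytes per element.
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem


def kv_head_for(q_head, n_heads, n_kv_heads):
    # Contiguous grouping: query heads 0..g-1 read KV head 0, g..2g-1 read KV head 1, ...
    group_size = n_heads // n_kv_heads
    return q_head // group_size


# 32 layers, 32 query heads, head dim 128, FP16; only the number of KV heads changes.
for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    per_token_kib = kv_bytes_per_token(32, n_kv, 128) // 1024
    per_conv_mib = kv_bytes_per_token(32, n_kv, 128) * 4096 // 2**20
    print(f"{name}: {per_token_kib} KiB/token, {per_conv_mib} MiB per 4,096 tokens, "
          f"{32 // n_kv} query heads per KV head")
```

It prints 512, 128 and 16 KiB per token (2048, 512 and 64 MiB per 4,096-token conversation). MQA is simply the limiting case of GQA in which one group contains every query head.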
There's a second-order win, too. Decode reads the whole cache every step, so a 4× smaller cache also means 4× fewer KV bytes streamed from HBM per generated token. GQA makes decode faster, not just smaller.
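Here is a rough lower bound on how long one decode step spends just streaming KV caches. It assumes each sequence in the batch reads its whole cache once per step. The 1 TB/s bandwidth is a hypothetical round number rather than any particular GPU's spec, and `kv_read_floor_ms` is an illustrative helper.

```python
def kv_read_floor_ms(seq_lens, kv_bytes_per_token, hbm_bytes_per_s):
    """Lower bound on one decode step's time spent only streaming KV caches.

    Each sequence in the batch reads its entire cache (all layers) once per
    step, so bytes = sum(seq_lens) * kv_bytes_per_token. Weight reads, compute,
    and launch overheads are ignored, so a real step is slower than this.
    """
    return 1000.0 * sum(seq_lens) * kv_bytes_per_token / hbm_bytes_per_s


HYPOTHETICAL_HBM = 1e12  # 1 TB/s, an illustrative figure, not a specific GPU
batch = [4096] * 3        # three 4,096-token conversations
mha = kv_read_floor_ms(batch, 512 * 1024, HYPOTHETICAL_HBM)
gqa = kv_read_floor_ms(batch, 128 * 1024, HYPOTHETICAL_HBM)
print(f"MHA {mha:.2f} ms/step, GQA {gqa:.2f} ms/step")
```

Whatever bandwidth you plug in, the MHA/GQA ratio is the same 4× as the cache-size cut, because the bound is linear in bytes per token.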
FlashAttention, in two paragraphs
The naive implementation of softmax(Q Kᵀ / √d) V materializes the n × n score matrix in HBM, reads it back for the softmax, and reads it again for the weighted sum. For an 8K prompt, that intermediate is 64M entries per head — written and read at HBM speeds, which is exactly the resource we just established is precious.
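The size of that intermediate is easy to check. This sketch assumes FP16 scores and 32 heads per layer. Real kernels often keep scores in FP32, which would double these figures. `naive_score_bytes` is an illustrative helper.

```python
MiB, GiB = 2**20, 2**30


def naive_score_bytes(seq_len, n_heads, bytes_per_elem=2):
    # One seq_len x seq_len score matrix per head, materialized in HBM.
    return seq_len * seq_len * n_heads * bytes_per_elem


print(8192 * 8192)                        # 67108864 entries per head (64 Mi)
print(naive_score_bytes(8192, 1) / MiB)   # 128.0 MiB per head in FP16
print(naive_score_bytes(8192, 32) / GiB)  # 4.0 GiB per layer with 32 heads
# Following the paragraph's count (one write, two reads), HBM traffic for the
# intermediate is about 3x its size per layer.
print(3 * naive_score_bytes(8192, 32) / GiB)  # 12.0
```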
FlashAttention's observation is that the score matrix never needs to exist. It tiles Q, K, and V into blocks that fit in on-chip SRAM, computes attention block by block with a running (online) softmax, and only writes the final output to HBM. Less memory traffic, exact same math. The catch people miss: this matters mostly for prefill, where n × n is real. At decode, Q is one row, the "matrix" is 1 × n, and there's nothing meaningful to tile — decode attention stays bandwidth-bound on the KV cache no matter how clever the kernel is.
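The running softmax is the heart of the method, and it fits in a few lines for a single query row. This is a pure-Python illustration, not FlashAttention itself. The real kernel also tiles Q, runs on SRAM-resident tiles, and fuses everything into one GPU kernel. `block` stands in for a hypothetical tile size, and the naive version is included only as a reference to compare against.

```python
import math


def naive_attention_row(q, K, V):
    # Reference: materialize the full row of scores, then normalize.
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    return [sum(pi * v[c] for pi, v in zip(p, V)) / z for c in range(len(V[0]))]


def online_attention_row(q, K, V, block=2):
    """Attention for one query, visiting K/V in blocks with a running softmax.

    Only m (running max), l (running denominator) and acc (unnormalized output)
    persist between blocks, so the full row of scores is never stored.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l, acc = -math.inf, 0.0, [0.0] * len(V[0])
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
        m_new = max(m, max(s))
        # Rescale everything accumulated so far to the new running max.
        correction = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first block
        p = [math.exp(x - m_new) for x in s]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[c] for pi, v in zip(p, Vb))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]


q = [1.0, 0.5]
K = [[0.2, -0.1], [0.4, 0.3], [-0.5, 0.9], [0.0, 0.1], [1.0, -1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, 2.0], [-1.0, 3.0]]
print(naive_attention_row(q, K, V))
print(online_attention_row(q, K, V, block=2))  # same values, up to float rounding
```

The loop still visits every K and V row once. Tiling removes the score matrix, not the cache reads, which is why the decode case stays bandwidth-bound.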
What this looks like inside a serving engine
Once you've internalized "the KV cache is the scarce resource," most serving-engine design choices read as cache management strategies:
- Paged attention stops allocating the cache as one contiguous tensor per request. Like virtual memory, the cache is split into fixed-size blocks (say, 16 tokens each) with an indirection table, so a request's cache doesn't need to be contiguous and short requests don't strand memory. Fragmentation drops from "kills your batch size" to a rounding error.
- Prefix caching / RadixAttention (the part I got to see up close working with SGLang) exploits the fact that K and V for a token depend only on that token and the tokens before it. Two requests sharing a system prompt can share that prefix's cache pages outright. SGLang organizes cached prefixes in a radix tree, so a fleet of requests with the same 2,000-token system prompt pays for that prefill once and stores the cache once.
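Here is a sketch of the bookkeeping behind both ideas. It uses a hypothetical `PagedKVCache` that tracks only block ids, with no actual K/V tensors and no attention kernel. Each request gets a block table that maps logical blocks to physical ones. Full blocks are indexed in a trie keyed by their tokens along the path from the start of the sequence, so reuse happens only when the entire prefix matches. This version shares at block granularity and never evicts. SGLang's actual radix tree is more general, so treat this as an illustration of the idea rather than its implementation.

```python
class PagedKVCache:
    """Hypothetical paged KV allocator with block-level prefix sharing."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))  # physical block ids not in use
        self.refcount = {}                    # block id -> number of holders
        self.trie = {}                        # block tokens -> (block id, children)

    def _alloc(self):
        if not self.free:
            raise MemoryError("out of KV blocks")
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def add_sequence(self, tokens):
        """Build a block table for a new prompt; return (table, reused_tokens)."""
        table, node, reused = [], self.trie, 0
        for start in range(0, len(tokens), self.block_size):
            chunk = tuple(tokens[start:start + self.block_size])
            if len(chunk) == self.block_size and chunk in node:
                # Same tokens after the same prefix: the cached K/V are valid,
                # so point this request at the existing physical block.
                block, node = node[chunk]
                self.refcount[block] += 1
                reused += len(chunk)
            else:
                block = self._alloc()
                if len(chunk) == self.block_size:
                    # Index the new full block; the trie holds its own reference.
                    self.refcount[block] += 1
                    child = {}
                    node[chunk] = (block, child)
                    node = child
            table.append(block)
        return table, reused

    def free_sequence(self, table):
        """Drop a finished request's references; unindexed blocks return to the pool."""
        for block in table:
            self.refcount[block] -= 1
            if self.refcount[block] == 0:
                del self.refcount[block]
                self.free.append(block)


cache = PagedKVCache(num_blocks=8, block_size=4)
system_prompt = list(range(8))  # two full blocks of shared prefix
t1, reused1 = cache.add_sequence(system_prompt + [100, 101])
t2, reused2 = cache.add_sequence(system_prompt + [200])
print(reused1, reused2)   # 0 8
print(t1[:2] == t2[:2])   # True: the prefix blocks are shared
```

The trie keeps its own reference to every indexed block, so a finished request's shared prefix stays cached for the next one and only its private tail returns to the free list. A real engine would add an eviction policy for cached blocks when the pool runs dry.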
None of this changes the attention math. It changes who pays for it, and how often. That's the inference side of attention in one sentence: the formula is fixed; the engineering is deciding what you cache, where it lives, and what you can afford never to recompute.