Transformer internals
Every modern LLM — Llama 4, DeepSeek V3, Claude, GPT — is a stack of one block repeated. Master the block and you master the model. This chapter walks the block end-to-end at Sr Staff depth: attention math, the MHA→MLA evolution, RoPE/YaRN, KV cache, MoE routing, normalization, tokenization, parallelism.
What you'll learn
- The transformer block — anatomy of one decoder layer
- Self-attention — math, complexity, masking
- The KV evolution — MHA → MQA → GQA → MLA
- Positional encodings — sinusoidal, learned, ALiBi, RoPE, YaRN, NoPE
- KV cache mechanics — what to cache, how big it gets
- Feed-forward block — MLPs, gated activations, SwiGLU
- Mixture of Experts — routing, balancing, DeepSeek's tricks
- Normalization & residuals — pre- vs post-norm, the residual stream
- Tokenization — BPE, SentencePiece, byte fallback
- Sequence & context parallelism — when training breaks the GPU
This chapter dissects a single pre-norm decoder block: the residual stream, the two sub-layers, and the parameter accounting that lets you estimate any model size on the back of a napkin. It is the architectural substrate that every subsequent chapter builds on.
Strip away everything and one transformer block is:
x' = x + Attention(LayerNorm(x)) # sub-layer 1
y = x' + FFN(LayerNorm(x')) # sub-layer 2
Repeat these two lines L times (Llama 3 70B: L=80; DeepSeek V3: L=61), then project to vocabulary logits. That is the complete model. Everything else — rotary embeddings, GQA, SwiGLU, MoE — is a component that slots into one of the two sub-layers or the projection step.
Plain words: every token starts as a vector x ∈ ℝ^d. Each sub-layer reads from it, computes a delta, and adds the delta back. The vector is never replaced, only updated by accumulation.
Concrete 4-dimensional example. Token "cat", d=4, starts as x=[0.1, 0.5, −0.2, 0.8]. After attention: delta_attn=[0.0, 0.1, 0.3, −0.1], so x'=[0.1, 0.6, 0.1, 0.7]. After FFN: delta_ffn=[0.2, −0.1, 0.0, 0.2], so y=[0.3, 0.5, 0.1, 0.9]. Neither sub-layer ever zeros the vector; they only nudge it.
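The worked example can be replayed in a few lines. This is a minimal sketch: the two deltas are hard-coded stand-ins for the attention and FFN outputs, and the helper name `add` is illustrative only.

```python
def add(u, v):
    """Element-wise sum of two equal-length vectors."""
    return [a + b for a, b in zip(u, v)]

x = [0.1, 0.5, -0.2, 0.8]            # residual stream for "cat"
delta_attn = [0.0, 0.1, 0.3, -0.1]   # stand-in for Attention(LayerNorm(x))
delta_ffn = [0.2, -0.1, 0.0, 0.2]    # stand-in for FFN(LayerNorm(x'))

x_prime = add(x, delta_attn)         # sub-layer 1: x' = x + delta
y = add(x_prime, delta_ffn)          # sub-layer 2: y = x' + delta
print([round(v, 2) for v in y])
```

Both sub-layers only ever add to the running vector, which is the accumulation property the text describes.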
Formal: The mech-interp framing (Elhage et al., Anthropic 2021) treats the residual stream as a high-dimensional linear communication bus. Sub-layers read via LayerNorm(x), write via x ← x + f(LN(x)). Linearity of the bus means later layers can read information placed there by any earlier layer.
Why it matters: Without residuals, the gradient through L sub-layers is a product of L Jacobians and vanishes exponentially. With residuals, ∂(x + f(x))/∂x = I + ∂f/∂x, so an identity component always passes through and you can train 80-layer networks without the gradient signal dying out.
Per block at hidden dim d with 4d FFN expansion:
- Attention (W_Q, W_K, W_V, W_O): 4 · d² params
- FFN (two d × 4d matrices): 8 · d² params
- LayerNorms (two, each with γ and β): 4d params, negligible

Worked example: Llama 3 70B. d=8192, L=80. One block ≈ 12 · 8192² ≈ 805M params. Stack 80 blocks → ≈ 64B from blocks alone. Add embeddings (vocab 128k × 8192 ≈ 1B) and GQA/SwiGLU adjustments → land at 70B. The formula gives you the right ballpark in under 10 seconds.
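The napkin estimate is easy to script. The sketch below assumes the plain 12·d² block (no GQA or SwiGLU adjustments) and a single vocab × d embedding table; the function names are illustrative.

```python
def block_params(d: int) -> int:
    """Approximate parameters in one block with a 4d FFN (norms ignored)."""
    attention = 4 * d * d      # W_Q, W_K, W_V, W_O
    ffn = 8 * d * d            # d x 4d up-projection plus 4d x d down-projection
    return attention + ffn

def model_params(d: int, n_layers: int, vocab: int) -> int:
    """All blocks plus one vocab x d embedding table."""
    return n_layers * block_params(d) + vocab * d

per_block = block_params(8192)
total = model_params(8192, 80, 128_000)
print(f"{per_block / 1e6:.0f}M per block, {total / 1e9:.1f}B total")
```

The result lands a few billion short of 70B, which is the gap the text attributes to the GQA and SwiGLU adjustments.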
Original paper (2017) used post-norm: LN(x + Sublayer(x)). The LayerNorm is applied after the residual merge — gradient passes through the norm's reciprocal-std, which can shrink or spike, making deep networks fragile.
Pre-norm (2019–present): x + Sublayer(LN(x)). LayerNorm is only in the sub-layer branch; the residual path is pure identity. The gradient through the residual is exactly I + ∂f/∂x — no norm in the way. GPT-2 moved here and trained 48 layers stably without warmup tricks. Every major LLM since follows suit.
The cost: the residual stream itself is never normalized, so a pre-norm model needs one extra LayerNorm on the final block's output before the vocabulary projection. There is also a small quality hit compared to a perfectly tuned post-norm model, which is worth it for trainability at depth.
"The residual connection is just a shortcut to avoid vanishing gradients." It is that — but it is also the core communication mechanism. In the mech-interp lens, sub-layers are attention heads and FFN neurons writing into a shared linear bus. Each head's output is added, not concatenated or overwritten. The linearity is the feature: it means different sub-circuits compose additively and any layer can read what any earlier layer wrote. Understanding this changes how you read ablation studies.
Trigger: "Walk me through a single transformer block" or "What is the residual stream?"
- State the two equations: x' = x + Attn(LN(x)) and y = x' + FFN(LN(x')).
- Name pre-norm by name and say why (identity gradient path for deep training).
- Name the residual stream as a linear bus: sub-layers add deltas, never overwrite.
- Give the parameter count: ≈ 12d² per block, FFN is 2× attention.
Never: describe the block without mentioning why pre-norm exists, or give params without explaining why FFN dominates.
- One block = x + Attn(LN(x)) then x' + FFN(LN(x')). Repeat L times.
- Residual stream = linear communication bus; sub-layers add deltas, never replace.
- Pre-norm: identity gradient path through residual → stable training at L=80+.
- Per-block params ≈ 12d²; FFN costs 2× attention; total model ≈ 12·d²·L + embeddings.
A modern LLM block is two equations repeated L times. The residual stream is a linear bus that sub-layers add to — never overwrite. Pre-norm places LayerNorm inside each branch, giving a clean identity gradient path that makes deep training stable. Each block costs ≈ 12d² parameters, with FFN holding 2× the weight of attention.
Q1. Write the two equations for a pre-norm transformer block and explain each term.
x' = x + Attention(LayerNorm(x)): the residual stream x passes through LN to stabilize scale, then attention computes a delta that is added back. y = x' + FFN(LayerNorm(x')): same pattern for the feed-forward sub-layer. The outer addition is the residual; LN is inside the branch. Critical insight: the residual is never normalized, so gradients flow through it without distortion.
Q2. Why does pre-norm outperform post-norm for deep networks?
Post-norm: LN(x + f(x)). Gradient must pass through LN's reciprocal-std term, which can shrink or amplify at each layer, giving cumulative instability for L>24. Pre-norm: x + f(LN(x)). The residual gradient path is identity + Jacobian of f. The identity component always survives regardless of depth, enabling L=80+ without exotic warmup.
Q3. Estimate the parameter count of a 7B model with d=4096, L=32.
Q4. What does the "residual stream as a communication bus" framing mean?
Q5. Why does FFN dominate parameters more than attention?
Q6. What information does LN's γ (scale) and β (bias) carry after pre-norm?
Q7. How many blocks deep are the biggest models, and why doesn't the residual break?
Q8. If you double d from 4096 to 8192 but keep L constant, how does parameter count change?
Q9. What breaks without the residual connection? (Be specific about the gradient.)
Q10. The original Transformer used both an encoder and a decoder. Why did GPT-style decoder-only dominate?
This chapter builds self-attention from scratch: why the formula is what it is, what each division and masking step does, and where the O(L²) cost comes from. Every attention variant in the rest of the transformer section — MQA, GQA, MLA, FlashAttention — is a direct response to the costs analyzed here.
Plain words. Imagine a dictionary where each entry has a key and a value. To look something up, you don't need an exact key match — you score your query against every key, turn the scores into probabilities, and return the weighted average of all values. That is exactly what attention computes. The "dictionary" is the set of past tokens; the "query" is the current token's need.
4-token worked example. Tokens: ["The", "cat", "sat", "down"]. Processing "sat" (query). Dot products with keys: The→0.2, cat→0.9, sat→0.3, down→0.0 (future, masked to −∞). After softmax: The→0.24, cat→0.49, sat→0.27, down→0.00. Output for "sat" = 0.24·v_The + 0.49·v_cat + 0.27·v_sat. The largest weight is on "cat", so "sat" picks up subject information.
Given input X ∈ ℝ^(L×d) (L tokens, hidden dim d), one attention head:
Q = X W_Q # (L, d_k) — what each token is looking for
K = X W_K # (L, d_k) — what each token offers as a key
V = X W_V # (L, d_v) — what each token offers as a value
scores = Q Kᵀ / √d_k # (L, L) — how relevant is each key?
scores = scores + causal_mask # add −∞ to future positions
attn = softmax(scores) # (L, L) — rows sum to 1
out = attn @ V # (L, d_v) — blend values by weight
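The steps above can be sketched as runnable code. This is a minimal pure-Python version that assumes Q, K and V have already been projected (the W_Q, W_K, W_V multiplications are skipped) and that works on lists of per-token vectors; `causal_attention` is an illustrative name, not a library function.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(Q, K, V):
    """Single-head causal attention over lists of per-token vectors."""
    d_k = len(K[0])
    out, weights = [], []
    for i, q in enumerate(Q):
        # Scaled dot product against every key; future positions get -inf.
        scores = [
            sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) if j <= i else -math.inf
            for j, k in enumerate(K)
        ]
        w = softmax(scores)
        weights.append(w)
        # Blend the value vectors by attention weight.
        out.append([sum(w[j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))])
    return out, weights

Q = K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = causal_attention(Q, K, V)
```

The first token can only attend to itself, so its output equals its own value vector; every later row is a weighted average over the causal prefix.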
The problem without scaling. Suppose Q and K entries are drawn from a standard normal: mean 0, variance 1. The dot product q · k = Σ_{i=1}^{d_k} q_i k_i. Each term q_i k_i has mean 0 and variance 1. The sum of d_k such terms has variance d_k. Standard deviation = √d_k.
Concrete numbers. d_k=128 → std≈11. So scores are commonly in the range [−20, +20]. Feed that to softmax: exp(20)/exp(−20) ≈ 10^17, so the winner takes essentially 1.0 and all other weights go to 0. You get a one-hot attention pattern. Gradients through a saturated softmax are ≈0 everywhere. Training grinds to a halt.
Fix: divide by √d_k=√128≈11.3. Now scores have std≈1. Softmax is in its useful regime (a smooth distribution, not a one-hot). At d_k=64 (original paper) √64=8; modern heads typically d_k=128.
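The variance argument can be checked empirically. The sketch below draws q and k entries from a standard normal (the same assumption the derivation makes) and measures the spread of the dot product with and without the √d_k scaling; the sample count and seed are arbitrary choices.

```python
import math
import random

def dot_product_std(d_k: int, n_samples: int = 2000, seed: int = 0) -> float:
    """Empirical std of q.k when q and k entries are i.i.d. standard normal."""
    rng = random.Random(seed)
    dots = [
        sum(rng.gauss(0, 1) * rng.gauss(0, 1) for _ in range(d_k))
        for _ in range(n_samples)
    ]
    mean = sum(dots) / n_samples
    return math.sqrt(sum((x - mean) ** 2 for x in dots) / n_samples)

raw = dot_product_std(128)
scaled = raw / math.sqrt(128)     # dividing scores by sqrt(d_k) divides the std too
print(f"unscaled std ~ {raw:.1f}, scaled std ~ {scaled:.2f}")
```

The unscaled value should sit near √128 ≈ 11.3 and the scaled one near 1, up to sampling noise.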
"Skipping √d_k looks fine on small models." True — at d_k=16, std=4, softmax is only mildly saturated. The bug is silent on toy experiments and catastrophic at scale. It is the most common mistake in from-scratch attention implementations. Always include the scaling.
Plain words. An autoregressive LM must predict token t using only tokens 1..t−1. But the raw attention score matrix is (L×L) — every token can see every other. The mask enforces the rule by making future attention scores −∞ before softmax, so they become 0 after.
Implementation. Before softmax, add a (L×L) upper-triangular matrix filled with −∞ (in practice −1e9 or the minimum representable float) to the score matrix. Position (i,j) with j>i (future token) gets −∞ → softmax output 0. Row i sums to 1 over the attended positions {0..i}.
Why add (not multiply)? The mask acts on logits, before softmax. Multiplying a future logit by 0 would leave exp(0) = 1, so that position would still receive weight; adding −∞ gives exp(−∞) = 0. The addition also fuses cleanly into the softmax kernel.
| Quantity | Cost | Why it hurts |
|---|---|---|
| Compute (FLOPs) | O(L²·d) | Doubling context = 4× compute |
| Activation memory | O(L²) | Materialized score matrix blows VRAM |
| KV cache at decode | O(L·d) | Linear but large at long context |
Concrete 128k context calculation. L=128k, d=8k, one head. Score matrix: L² ≈ 16B entries × 2 bytes (fp16) ≈ 32 GB per head per layer. With H=64 heads and 80 layers: 32 GB × 64 × 80 ≈ 164 TB, obviously impossible to materialize. Even for a single head in a single layer, 32 GB exceeds many GPU VRAM budgets.
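The same arithmetic as a small calculator. This sketch takes "128k" as 128 × 1024 tokens and reports binary units (GiB, TiB), so the figures differ slightly from the rounded decimal ones in the text.

```python
def score_matrix_bytes(seq_len: int, bytes_per_value: int = 2) -> int:
    """Bytes needed to materialize one head's (L x L) score matrix."""
    return seq_len * seq_len * bytes_per_value

GIB = 1024 ** 3
TIB = 1024 ** 4
one_head = score_matrix_bytes(128 * 1024)     # 128k tokens, fp16
all_heads_layers = one_head * 64 * 80         # 64 heads, 80 layers
print(one_head / GIB, "GiB per head per layer")
print(all_heads_layers / TIB, "TiB across all heads and layers")
```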
FlashAttention (Dao 2022) sidesteps the memory by tiling the (L×L) score matrix and computing outputs block by block with online softmax (Milakov & Gimelshein 2018). The full score matrix is never written to HBM — only the final output is. Memory drops to O(L). The FLOPs are the same (still O(L²·d)) — you still do all the multiplications; you just don't store the intermediate results. "FlashAttention reduces memory, not compute."
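Online softmax is the piece that makes block-by-block processing possible. The sketch below is a simplified, scalar-valued illustration of the idea (a running max, a running denominator, and a rescale step whenever the max changes); it is not FlashAttention itself, which applies the same bookkeeping to tiles of Q, K and V on the GPU. The function name and block size are illustrative.

```python
import math

def online_softmax_weighted_sum(scores, values, block_size=2):
    """Compute sum_j softmax(scores)_j * values_j one block at a time."""
    running_max = -math.inf   # largest score seen so far
    denom = 0.0               # running sum of exp(score - running_max)
    acc = 0.0                 # running sum of exp(score - running_max) * value
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale everything accumulated under the old max.
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc *= correction
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - new_max)
            denom += p
            acc += p * v
        running_max = new_max
    return acc / denom

scores = [0.2, 0.9, 0.3, 5.0, -1.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
streamed = online_softmax_weighted_sum(scores, values)
```

Only three scalars are carried between blocks, which is why the full row of scores never needs to be stored.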
"FlashAttention is O(1) memory — true or false?" False. FlashAttention is O(L) memory (the output), down from O(L²) for the naive materialized score matrix. Compute is still O(L²·d). An interviewer checking whether you understand the tiling trick vs the FLOPs is a direct test of this distinction.
Trigger: "Why does long context require so much memory?" or "What does FlashAttention actually fix?"
- State the O(L²) memory origin: the (L×L) score matrix, one entry per (query, key) pair.
- Give a concrete number: at L=128k, one head's score matrix = 32 GB in fp16.
- Explain FlashAttention: tiles the matrix, never writes the full (L×L) to HBM, uses online softmax. Memory → O(L). FLOPs unchanged.
- Say what FlashAttention does NOT fix: the O(L²) compute cost — that's why ring attention and other approaches are needed for 1M+ context.
Never: say FlashAttention "makes attention O(L)" without qualifying that it is memory-only and FLOPs remain O(L²).
- Formula: softmax(QKᵀ/√d_k)·V. The √d_k restores unit variance to the dot products so softmax doesn't saturate.
- Causal mask = upper-triangular −∞ added to scores before softmax. After softmax those entries are 0, row still sums to 1.
- Complexity: O(L²·d) FLOPs, O(L²) memory (naive). FlashAttention fixes the memory to O(L); FLOPs still O(L²).
- Every attention variant (GQA, MLA, sparse, linear) is a response to the L² cost and the KV cache.
Attention is softmax(QKᵀ/√d_k)·V. The √d_k prevents softmax saturation by keeping dot-product variance ≈1. The causal mask adds −∞ to future positions before softmax. The O(L²) compute and O(L²) memory costs are why long context is hard and why FlashAttention (which fixes memory but not FLOPs) exists.
Q1. Write the full attention formula and explain every symbol.
Attention(Q,K,V) = softmax(QKᵀ / √d_k) · V. Q: query matrix (what we're looking for), K: key matrix (what each token offers as index), V: value matrix (content). QKᵀ: dot product between all query–key pairs → (L×L) score matrix. √d_k: scale factor; without it, high d_k → large dot products → softmax saturation → vanishing gradients. softmax: normalizes each row to a probability distribution (sums to 1). Final matrix multiplication with V: weighted average of value vectors.
Q2. Why does variance of QK dot products grow with d_k?
Q3. What happens if you apply softmax to scores with std=11 (d_k=128, no scaling)?
Q4. Where exactly is the causal mask applied, and what is its value?
After computing scores = Q@Kᵀ / √d_k (shape L×L) and before softmax, add an upper-triangular matrix of −∞ (in practice often a large negative finite value such as the dtype's minimum, which avoids NaN if a row is fully masked; −1e9 only works in FP32/BF16 because it overflows FP16). Position (i,j) with j>i gets −∞ → exp(−∞) = 0 after softmax. Each row i then sums to 1 over positions {0,1,...,i}, the causal prefix. The mask is pre-computed once and broadcast across all heads and batch elements.
Q5. What is the attention complexity — both FLOPs and memory?
Q6. What exactly does FlashAttention do, and what doesn't it change?
Q7. Why does attention use softmax rather than sigmoid?
Q8. Multi-head attention: why H heads and how are they combined?
Q9. What is online softmax and why does FlashAttention need it?
Q10. What is the KV cache and when is it useful?
The KV cache is the dominant VRAM cost at long context. This chapter traces the four-step evolution that shrinks it from O(H·d_head) per token per layer (MHA) down to ~7% of that (MLA), explaining the mechanism and quality trade-off at each step. Know all four cold: this is a top-5 LLM interview topic.
Plain words. Multi-Head Attention (Vaswani 2017) runs H independent attention operations. Each head h has its own weight matrices W_Q^h, W_K^h, W_V^h projecting the residual stream to dimension d_head = d/H. The H outputs are concatenated and projected by W_O.
Worked example (H=8, d=512, d_head=64). Each head projects to 64-dim K and V. At decode step t, cached K and V for each past token: 8 heads × 64 dims × 2 (K+V) = 1024 floats per token. In bf16 (2 bytes): 2048 bytes per token per layer. For L=80 layers and seq=128k: 2048 × 80 × 128k = 20 GB per request.
Why each head needs its own K,V. Different heads attend to different relationships — one might track syntax, another coreference. If they share K and V they can only extract one type of signal from the same linear subspace. Quality drops, especially on tasks requiring multiple simultaneous relationships.
Plain words. Multi-Query Attention (Shazeer 2019) keeps H query heads but collapses K and V to a single shared projection: one W_K and one W_V for the whole layer. Every query head reads from the same key and value vectors.
Cache reduction. KV cache drops from 2·H·d_head to 2·d_head per token per layer — an H× reduction. For H=8, that is 8× smaller: the 20 GB example above becomes 2.5 GB.
Quality cost. The single K,V must serve all H queries simultaneously. In practice, the quality drop is small on downstream tasks but measurable on harder reasoning benchmarks. PaLM (Google 2022) and Falcon used MQA and found it acceptable.
The intuition for the quality drop. All queries must "ask about" the same key subspace. If head 1 wants syntactic-role keys and head 3 wants semantic-similarity keys, they now get the same K — a compromise projection. The model learns to work around it, but there's a ceiling.
Plain words. Grouped-Query Attention (Ainslie 2023, arXiv 2305.13245) interpolates between MHA and MQA. Split the H query heads into G groups; each group shares one K and one V projection. So there are G KV heads (not 1, not H).
Worked example. H=64 query heads, G=8 KV heads. Each KV head serves 64/8=8 query heads. Cache = 2·8·d_head per token per layer, 8× smaller than MHA. Llama 2 70B and Llama 3 70B use exactly this setting; Mistral 7B also uses 8 KV heads (with 32 query heads), and Qwen 2 uses GQA as well.
Quality. The gap between GQA and MHA is negligible at typical g=8 — the 8× cache win is essentially free. This is why GQA became the universal default.
Naming trap: "g" in GQA(g=8) is the number of KV heads, not the group size of query heads. g=8 means 8 KV heads. Group size = H/g. GQA(g=1) = MQA (extreme sharing). GQA(g=H) = MHA (no sharing). When an interviewer says "GQA with g=8 on a 64-head model," they mean 8 KV heads, 8 queries per KV head, 8× cache reduction.
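The naming convention is easiest to see as an index mapping. The sketch below assumes query heads are grouped contiguously (heads 0..7 share KV head 0, and so on), which is a common layout but an assumption here; `kv_head_for` is an illustrative name.

```python
def kv_head_for(query_head: int, n_query_heads: int, n_kv_heads: int) -> int:
    """Index of the KV head that a given query head reads from."""
    group_size = n_query_heads // n_kv_heads   # query heads per KV head
    return query_head // group_size

H, g = 64, 8
mapping = [kv_head_for(h, H, g) for h in range(H)]
print(mapping[:10], "...", mapping[-1])

# The two extremes from the text:
mqa = {kv_head_for(h, H, 1) for h in range(H)}   # g=1: every query head shares KV head 0
mha = [kv_head_for(h, H, H) for h in range(H)]   # g=H: one KV head per query head
```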
The insight. Multi-head Latent Attention (DeepSeek V2/V3, arXiv 2405.04434) asks a deeper question: rather than grouping existing KV heads, can we compress the joint (K,V) space into a low-rank latent vector and reconstruct K,V on the fly? The answer is yes, and it is nearly lossless.
Plain words. Instead of caching H full K vectors and H full V vectors, cache a single compressed vector c ∈ ℝ^(d_c) (typically d_c≈512, much smaller than 2·H·d_head, which is 32768 at DeepSeek V2's H=128, d_head=128). At attention time, reconstruct K and V from c via up-projection matrices.
The key trick: absorbing up-projections into Q. For one query vector q and one cached latent c, let k = W_K_up · c and v = W_V_up · c. The attention score is qᵀk = qᵀ(W_K_up · c) = (W_K_upᵀ · q)ᵀ c. So W_K_up can be absorbed into the query side: compute q' = W_K_upᵀ · q once per step, then score q' against every cached c directly, without reconstructing the full K. This means at inference you cache only c, not the reconstructed K. V is reconstructed on the fly for the value aggregation step, but that is cheaper than storing it.
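The absorption identity is just associativity of matrix products, which a toy check makes concrete. The sizes (d_head = 3, d_c = 2) and all numbers below are hypothetical; RoPE is ignored here, as in the content path.

```python
def matvec(W, x):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def transpose(W):
    return [list(col) for col in zip(*W)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

W_K_up = [[0.5, -1.0], [2.0, 0.25], [-0.5, 1.5]]   # shape (d_head, d_c)
q = [1.0, -2.0, 0.5]                               # one query vector
latents = [[0.2, 0.4], [-1.0, 0.3], [0.7, -0.6]]   # cached c for three past tokens

# Path 1: reconstruct each key from its latent, then score.
scores_reconstruct = [dot(q, matvec(W_K_up, c)) for c in latents]

# Path 2: absorb W_K_up into the query once, then score against the latents.
q_absorbed = matvec(transpose(W_K_up), q)
scores_absorbed = [dot(q_absorbed, c) for c in latents]
```

Path 2 does one small projection per decode step instead of one up-projection per cached token, and never materializes K.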
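Cache size. MHA: 2·H·d_head = 2·128·128 = 32768 values per token per layer (DeepSeek V2 dims). MLA: d_c ≈ 512, a ratio of 512/32768 ≈ 1.6% before adding the small rotary component. The often-quoted ~7% figure comes from DeepSeek's reported 93.3% KV-cache reduction for V2 relative to its earlier DeepSeek 67B model, so it is a model-to-model comparison rather than this per-layer ratio.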
The problem. RoPE rotates Q and K by position-dependent angles. If K is reconstructed from latent c via W_K_up, you'd need to rotate the reconstruction by RoPE. But RoPE is position-dependent — you can't pre-multiply it into W_K_up and absorb it into Q, because the rotation angle changes for every token position. If you try to absorb it, the matrix you'd need depends on position and can't be precomputed.
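The solution: split Q and K into two paths.
- Content path (Q^C, K^C): no RoPE. K^C is reconstructed from the latent c, so its up-projection can be absorbed into Q.
- Rotary path (Q^R, K^R): a small extra dimension (≈64) that carries RoPE, with K^R shared across heads MQA-style.
Final score: score = Q^C·(K^C)ᵀ + Q^R·(K^R)ᵀ. The cache stores only c (d_c≈512) plus the small K^R (≈64). This is the mechanism that lets MLA be simultaneously compact and MHA-quality.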
| Variant | KV cache / token / layer | Quality | Used by |
|---|---|---|---|
| MHA | 2·H·d_head | baseline | GPT-2/3, original Llama |
| MQA | 2·d_head (H× smaller) | small drop | PaLM, Falcon |
| GQA(g=8) | 2·g·d_head (8× smaller at H=64) | ≈ MHA | Llama 2/3/4, Mistral, Qwen |
| MLA | d_latent + d_R (~7% of MHA) | ≈ MHA | DeepSeek V2/V3 |
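The table's cache formulas can be compared numerically. The sketch below plugs in H=64, d_head=128, 8 KV heads for GQA, and d_latent=512 with d_R=64 for MLA; those dimensions are illustrative assumptions, not the configuration of any one model.

```python
def kv_values_per_token(variant, n_heads, d_head, n_kv_heads=8, d_latent=512, d_rope=64):
    """Cached values per token per layer for each attention variant."""
    if variant == "MHA":
        return 2 * n_heads * d_head
    if variant == "MQA":
        return 2 * d_head
    if variant == "GQA":
        return 2 * n_kv_heads * d_head
    if variant == "MLA":
        return d_latent + d_rope
    raise ValueError(f"unknown variant: {variant}")

sizes = {v: kv_values_per_token(v, n_heads=64, d_head=128) for v in ("MHA", "MQA", "GQA", "MLA")}
for name, size in sizes.items():
    print(f"{name}: {size} values, {size / sizes['MHA']:.1%} of MHA")
```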
Trigger: "Walk me through MHA to MLA" or "Why does MLA need decoupled RoPE?"
- Name the cache cost of MHA: 2·H·d_head per token per layer. Give a number (≈ 320 GB at 128k context for a 70B-class model with H=64, d_head=128, 80 layers in bf16; GQA with 8 KV heads brings that down to ≈ 40 GB).
- MQA: collapse to 1 KV head, H× reduction, small quality drop. GQA: G groups, sweet spot.
- MLA: compress (K,V) into latent c; absorb W_K_up into Q; cache only c. State the cache ratio (~7%).
- Decoupled RoPE: RoPE is position-dependent so it can't be pre-absorbed. Split into content path (no RoPE, absorbable) and rotary path (small dim, MQA-style). Score = content + rotary.
Never: confuse GQA(g=8) with "8 queries per KV group" — g=8 means 8 KV heads.
- MHA → MQA → GQA → MLA: each step removes a layer of KV redundancy.
- GQA(g=k) means k KV heads (not k queries per group). g=1 = MQA, g=H = MHA.
- MLA caches a low-rank latent c (~512 dims) instead of full K,V; absorbs K up-projection into Q.
- Decoupled RoPE is load-bearing: content path no RoPE (absorbable), rotary path has RoPE (small, MQA-style).
Q1. Why does MHA have such a large KV cache? Give a concrete number.
Q2. MQA reduces cache 8×. Why not use it everywhere?
Q3. Explain GQA in one sentence and give the formula for its cache size.
Q4. What does MLA cache, and how is it different from storing K and V?
Q5. Why can't MLA apply RoPE to its compressed K without the decoupled design?
Q6. What is the quality difference between GQA and MLA in practice?
Q7. If you could redesign MHA today with the KV cache problem in mind, what would you do?
Q8. How is MQA's quality drop different from GQA's?
Q9. Llama 3 70B uses GQA(g=8) with H=64 query heads. What is the exact KV cache per token at 128k context?
Q10. In MLA, what exactly is "absorbed" into Q and what remains in the cache?
Self-attention has no built-in sense of order: without positional info, the attention scores for "the dog bit the man" and "the man bit the dog" are built from the same unordered set of tokens. Six families have been tried; in 2026 the winning recipe is RoPE for pretraining and YaRN to extend to long context.
The six families
| Method | Mechanism | Extrapolation | Used by |
|---|---|---|---|
| Sinusoidal | Add sin(pos/10000^(2i/d)) to embeddings | weak in practice | Original Transformer |
| Learned absolute | Lookup table indexed by position | none beyond train length | GPT-2, GPT-3 |
| ALiBi | Add linear bias −m·|i−j| to scores; no embedding | strong | MPT, BLOOM |
| RoPE | Rotate Q,K in 2D subspaces by θ_i·pos | weak alone, strong with YaRN | Llama, Mistral, DeepSeek, Qwen |
| YaRN | Frequency-categorized RoPE rescaling + temperature | extends 4k → 128k | Llama 3, Qwen 2 |
| NoPE | No positional info; causal mask provides position implicitly | surprising on some tasks | research |
RoPE in one paragraph
RoPE (Su 2021, arXiv 2104.09864) groups Q and K dimensions into 2D pairs and rotates each pair by an angle θ_i · pos, with θ_i = base^(−2i/d) (default base 10000). Rotating q by θ_i·pos_q and k by θ_i·pos_k changes their dot product only through the difference of the two angles, so Q·K after rotation depends only on relative position (pos_q − pos_k). This gives you relative-position semantics without learning a relative bias matrix.
# Per 2D pair (x, y), at position p, with base frequency θ:
[x'] [cos(θp) -sin(θp)] [x]
[y'] = [sin(θp) cos(θp)] [y]
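The rotation and its relative-position property can be checked directly. This is a minimal sketch that uses the interleaved pairing (dims (0,1), (2,3), ...) and base 10000; the vectors and positions are arbitrary.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate consecutive pairs (x, y) of vec by angle theta_i * pos."""
    d = len(vec)
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        c, s = math.cos(theta * pos), math.sin(theta * pos)
        x, y = vec[2 * i], vec[2 * i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.9]
near = dot(rope(q, 5), rope(k, 3))        # positions 5 and 3: offset 2
far = dot(rope(q, 105), rope(k, 103))     # same offset, shifted by 100
other = dot(rope(q, 5), rope(k, 0))       # different offset
```

`near` and `far` agree to floating-point precision because only the offset matters, while `other` differs.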
Llama 2 trained at 4k with base=10000. Llama 3 raised the base to 500000 to push the wavelength spectrum out to 8k. To go further (8k → 128k), Llama 3 used YaRN-style rescaling during a long-context fine-tune. The base controls how fast the slowest-frequency dim rotates; longer wavelengths support longer context.
YaRN — extending RoPE to long context
YaRN (Peng 2023) categorizes RoPE frequencies into three bands and rescales them differently:
- High-freq dims (fast oscillation, position-sensitive): minimal interpolation — they extrapolate.
- Low-freq dims (slow, semantic): full positional interpolation (compress positions into the trained range).
- Mid-freq dims: smoothly interpolated between the two.
Plus an attention-temperature factor on the logits that compensates for attention spreading over many more tokens at long context. Result: extend a 4k pretrained model to 128k with a brief continued-training pass on long-context data.
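The three bands can be sketched as a ramp over how many full rotations each dimension completes inside the trained context. Treat this as an illustration under stated assumptions: the thresholds `alpha=1` and `beta=32` and the logit scale `(0.1·ln(s) + 1)²` are the values I recall from the YaRN paper, not something the text above specifies, and real implementations differ in detail.

```python
import math

def yarn_frequencies(d, scale, train_len, base=10000.0, alpha=1.0, beta=32.0):
    """Per-pair RoPE frequencies after band-wise rescaling (assumed ramp form)."""
    freqs = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        rotations = train_len * theta / (2 * math.pi)   # full turns within the trained context
        ramp = min(1.0, max(0.0, (rotations - alpha) / (beta - alpha)))
        # ramp = 1: high frequency, keep theta. ramp = 0: low frequency, interpolate fully.
        freqs.append((1 - ramp) * theta / scale + ramp * theta)
    return freqs

def yarn_logit_scale(scale):
    """Multiplier applied to attention logits (assumed form)."""
    return (0.1 * math.log(scale) + 1.0) ** 2

freqs = yarn_frequencies(d=128, scale=32, train_len=4096)   # 4k -> 128k
```

The fastest dimension is left untouched, the slowest is divided by the full scale factor, and dimensions in between blend the two.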
ALiBi vs RoPE — pick one and defend
ALiBi
- Linear bias added to attention scores — no embeddings to learn.
- Per-head slope m (geometric sequence).
- Excellent length extrapolation out of the box.
- Slightly worse in-distribution quality than RoPE.
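The ALiBi bias is simple enough to write out. The sketch below assumes a power-of-two head count, for which the slopes form the geometric sequence 2^(−8/n), 2^(−16/n), ...; the function names are illustrative.

```python
def alibi_slopes(n_heads):
    """Geometric per-head slopes for a power-of-two head count."""
    start = 2 ** (-8 / n_heads)
    return [start ** (h + 1) for h in range(n_heads)]

def alibi_bias_row(i, slope):
    """Bias added to query i's scores against keys 0..i (causal prefix)."""
    return [-slope * (i - j) for j in range(i + 1)]

slopes = alibi_slopes(8)
row = alibi_bias_row(3, slopes[0])
print(slopes)
print(row)
```

Heads with large slopes focus on nearby tokens and heads with small slopes see further back; no parameters are learned.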
RoPE
- Rotates Q,K — encodes relative position via dot product.
- Better quality at training length; weak naive extrapolation.
- Pairs cleanly with YaRN/NTK rescaling for long context.
- The 2026 default.
RoPE implementations pair dimensions in two layouts: interleaved (adjacent pairs (0,1),(2,3),...) and GPT-NeoX/Llama style (split halves (0, d/2), (1, d/2+1),...). Loading weights across implementations without remapping gives garbage outputs that look almost-right. Always check.
- RoPE rotates Q,K in 2D pairs by θ_i·pos; dot product captures relative position.
- Modern recipe: train at moderate length with RoPE, then YaRN-extend with continued long-context training.
- ALiBi is the simpler alternative with strong extrapolation but weaker peak quality.
During autoregressive decode, K and V for past tokens never change — cache them, recompute only Q for the new token. Cache size scales with 2 · L · n_kv_heads · d_head · n_bytes per token. At long context this dominates VRAM, which is why GQA and MLA exist.
Why a cache exists
For token t, attention reads K and V for all tokens ≤ t. None of those K, V vectors depend on t's query — they were computed when each past token was processed. Recomputing them per step would make decode O(L²) per token; caching makes it O(L) per token (reading the cache).
The size formula
cache_size = 2 · L_layers · n_kv_heads · d_head · seq_len · bytes_per_value
The 2 is for K + V. n_kv_heads is where MHA → MLA shrinks the bill. bytes_per_value is 2 for bf16/fp16, 1 for int8 quantized cache.
Worked example. L=80, n_kv_heads=8 (GQA), d_head=128, seq=128k, bf16. Cache = 2 · 80 · 8 · 128 · 128k · 2 ≈ 40 GB per request. That is half of an 80 GB H100 for a single request, before counting any model weights. Single-prompt long-context inference is largely a KV-cache problem.
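The size formula as a function, with 128k taken as 128 × 1024 tokens and results in GiB. The MHA comparison (64 KV heads) is added here for contrast and is an assumption about a hypothetical non-GQA version of the same model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_value=2):
    """Total KV cache for one request: K and V, every layer, every cached token."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_value

GIB = 1024 ** 3
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, d_head=128, seq_len=128 * 1024)
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, d_head=128, seq_len=128 * 1024)
int8 = kv_cache_bytes(80, 8, 128, 128 * 1024, bytes_per_value=1)
print(gqa / GIB, mha / GIB, int8 / GIB)
```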
For a deeper dive on paged KV, prefix sharing and the prefill vs decode distinction, see LLM inference.
- Cache K,V because they never change after the token is produced.
- Cache size = 2 · L · n_kv_heads · d_head · seq · bytes.
- At long context the cache, not the weights, dominates VRAM — this is why GQA/MLA matter.
The FFN is a per-token MLP, usually 4× wider than the residual stream, that holds most of the model's parameters and (per mech-interp) most of the model's "knowledge." SwiGLU replaced ReLU/GELU around 2022 and is now universal: it projects the input into a gate and a value, applies Swish to the gate, multiplies the two element-wise, and spends a third weight matrix for measurable quality gains.
Standard FFN — the original
y = W_2 · σ(W_1 · x + b_1) + b_2 # σ ∈ {ReLU, GELU}
Hidden dim 4d. Two matrices, one nonlinearity. GPT-2/3, BERT.
SwiGLU — the 2024+ default
SwiGLU (Shazeer 2020) introduces a gate:
y = W_2 · ( Swish(W_1 · x) ⊙ (W_3 · x) )
where Swish(z) = z · sigmoid(z)
Three matrices. To keep parameter count constant relative to a 4d standard FFN, hidden dim is reduced to ~2.67d (= 8d/3). Used by Llama, Mistral, DeepSeek, PaLM, Qwen.
Standard FFN at d=8192: 2 · (4·8192·8192) = 537M params. SwiGLU at hidden = 8/3 · 8192 ≈ 21845: 3 · (21845·8192) ≈ 537M. Same parameter count, gated nonlinearity, ~1% perplexity win on equal-data training.
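The parameter match is quick to verify. The sketch below ignores biases and rounds the SwiGLU hidden size down to an integer; real models often round it to a hardware-friendly multiple instead.

```python
def standard_ffn_params(d: int) -> int:
    """Two matrices: d x 4d and 4d x d."""
    return 2 * d * (4 * d)

def swiglu_params(d: int, hidden=None) -> int:
    """Three matrices (W_1, W_3: d x hidden; W_2: hidden x d)."""
    hidden = (8 * d) // 3 if hidden is None else hidden
    return 3 * d * hidden

d = 8192
print(standard_ffn_params(d), swiglu_params(d))
```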
Why gating works (intuition)
The gate Swish(W_1·x) learns which dimensions of the value W_3·x to suppress per-token. It's a soft, data-dependent feature selector — close in spirit to LSTM gates and to multiplicative interactions. Empirically the gain over plain ReLU/GELU is small but consistent and free, so everyone adopted it.
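A tiny forward pass shows the gate acting as a per-dimension selector. This is a pure-Python sketch with hand-picked toy weights (all hypothetical): one hidden unit has a strongly negative gate pre-activation and is suppressed, the other passes through.

```python
import math

def swish(z):
    """Swish(z) = z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def swiglu_ffn(x, W1, W3, W2):
    gate = [swish(z) for z in matvec(W1, x)]     # data-dependent gate
    value = matvec(W3, x)                        # candidate features
    hidden = [g * v for g, v in zip(gate, value)]
    return matvec(W2, hidden), hidden

x = [1.0]
W1 = [[-20.0], [20.0]]    # gate pre-activations: -20 (closed) and +20 (open)
W3 = [[1.0], [1.0]]       # both value features equal 1
W2 = [[1.0, 1.0]]         # sum the two hidden units
y, hidden = swiglu_ffn(x, W1, W3, W2)
```

The first hidden unit contributes almost nothing even though its value feature is 1, because its gate is near zero.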
- FFN is a per-token MLP at hidden dim ~4d (or 8d/3 for SwiGLU); 2/3 of block params live here.
- SwiGLU = W_2(Swish(W_1·x) ⊙ W_3·x), three matrices, hidden dim 8d/3 to match params.
- Gated activations beat pure activations by a small but free margin — universal in 2026.
Replace each FFN with N expert FFNs and a learned router that picks k experts per token. You get the parameter capacity of a giant model with the FLOPs of a small one. The whole design problem is keeping experts balanced — three balancing tricks (aux loss, router Z-loss, aux-loss-free bias) are all on the interview menu.
The forward pass
Per token x:
- Router: logits = W_r · x over N experts.
- Top-k: pick the k experts with highest logits; softmax their logits → gate weights g_i.
- Expert forward: each picked expert i computes FFN_i(x).
- Combine: y = Σ_i g_i · FFN_i(x).
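The four steps above can be sketched for a single token. In this toy version the router logits are passed in directly (standing in for W_r · x) and each "expert" is a plain Python callable; both are simplifications for illustration.

```python
import math

def top_k_route(router_logits, k):
    """Return {expert_index: gate_weight} for the k highest-logit experts."""
    chosen = sorted(range(len(router_logits)), key=lambda i: router_logits[i], reverse=True)[:k]
    m = max(router_logits[i] for i in chosen)
    exps = {i: math.exp(router_logits[i] - m) for i in chosen}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}   # softmax over the chosen logits only

def moe_forward(x, router_logits, experts, k=2):
    gates = top_k_route(router_logits, k)
    y = [0.0] * len(x)
    for i, g in gates.items():
        out = experts[i](x)                          # only the chosen experts run
        y = [acc + g * o for acc, o in zip(y, out)]
    return y, gates

# Four toy experts that scale the input by 1, 2, 3 and 4.
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0, 4.0)]
y, gates = moe_forward([1.0, -2.0], [0.1, 2.0, -1.0, 2.0], experts, k=2)
```

Experts 1 and 3 tie for the highest logit, so each gets gate weight 0.5 and the output is the average of their results.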
Routing variants
- Top-1 (Switch Transformer, Fedus 2021): one expert per token. Cheapest; quality dip vs top-2.
- Top-2 (Mixtral): two experts. Mixtral-8x7B → 47B total params, ~13B active per token.
- Top-K with shared experts (DeepSeek-MoE, arxiv 2401.06066): some experts always activate, routed experts specialize. DeepSeek V3 has 1 shared + 256 routed, top-8 routed.
- Expert Choice (Zhou 2022): experts pick tokens, not vice versa — guarantees load balance by construction. Tradeoff: tokens not chosen are dropped or sent to a default.
Load balancing — get the formula right
Without intervention the router collapses onto a few favorite experts. Three remedies, in chronological order:
1. Switch Transformer auxiliary loss
L_aux = α · N · Σ_{i=1}^{N} f_i · P_i
- N = number of experts
- f_i = fraction of tokens routed to expert i (per batch)
- P_i = mean router probability for expert i (per batch)
- α = scalar coefficient (typically 0.01)
- Why the N multiplier? Under perfectly uniform routing (f_i = P_i = 1/N), the term N · Σ f_i · P_i equals 1 regardless of N, so the loss is cleanly normalized to α.
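A small numeric check of the formula and its normalization. The sketch assumes top-1 routing, with `assignments` holding the chosen expert per token and `router_probs` holding the full per-token softmax; both argument names are illustrative.

```python
def switch_aux_loss(assignments, router_probs, n_experts, alpha=0.01):
    """L_aux = alpha * N * sum_i f_i * P_i over one batch."""
    n_tokens = len(assignments)
    f = [assignments.count(i) / n_tokens for i in range(n_experts)]               # hard routing fractions
    P = [sum(p[i] for p in router_probs) / n_tokens for i in range(n_experts)]    # mean soft probabilities
    return alpha * n_experts * sum(fi * Pi for fi, Pi in zip(f, P))

uniform = switch_aux_loss([0, 1, 2, 3], [[0.25] * 4] * 4, n_experts=4)
collapsed = switch_aux_loss([0, 0, 0, 0], [[1.0, 0.0, 0.0, 0.0]] * 4, n_experts=4)
print(uniform, collapsed)
```

Uniform routing gives exactly α, and full collapse onto one expert gives N · α, so the loss grows as routing concentrates.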
2. Router Z-loss (memorize alongside)
(ST-MoE / PaLM, also DeepSeek) Penalize the log-partition of router logits to keep them bounded:
L_z = (1/B) · Σ_i ( log Σ_j exp(x_{ij}) )²
Where x_ij is the router logit for expert j on token i. Without it, router logits drift in bf16/fp8 and softmax becomes numerically unstable.
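The Z-loss is the mean squared log-sum-exp of each token's router logits. A minimal sketch, using the usual max-subtraction trick for a stable log-sum-exp; any coefficient in front of the loss is left out.

```python
import math

def router_z_loss(logits_per_token):
    """Mean over tokens of (log sum_j exp(x_ij))^2."""
    total = 0.0
    for logits in logits_per_token:
        m = max(logits)
        lse = m + math.log(sum(math.exp(x - m) for x in logits))   # stable log-sum-exp
        total += lse ** 2
    return total / len(logits_per_token)

small = router_z_loss([[0.0, 0.0, 0.0, 0.0]])
large = router_z_loss([[30.0, 0.0, 0.0, 0.0]])
print(small, large)
```

Logits that drift to large magnitudes are penalized quadratically, which is the pressure that keeps them bounded.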
3. Auxiliary-loss-free balancing (DeepSeek V3)
Drop the auxiliary loss entirely. Maintain a per-expert bias b_i that's added to the router logits at routing time only (not used for the gate weights). When an expert is over-subscribed, decrement its bias; under-subscribed, increment. The bias acts as a simple, hyperparameter-light feedback controller on the routing distribution, with no gradient interference from an auxiliary loss term.
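The mechanism can be sketched in two functions: selection uses score plus bias, gate weights use the unbiased scores, and the bias moves by a fixed step against the load imbalance. The step size, the use of positive affinity scores, and the normalization over selected experts are assumptions made for this illustration.

```python
def route_with_bias(scores, bias, k):
    """Pick top-k experts by score + bias; gate weights come from the unbiased scores."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i] + bias[i], reverse=True)[:k]
    total = sum(scores[i] for i in ranked)
    return {i: scores[i] / total for i in ranked}

def update_bias(bias, tokens_per_expert, step=0.001):
    """Nudge each bias down if its expert is overloaded, up if underloaded."""
    target = sum(tokens_per_expert) / len(tokens_per_expert)
    new_bias = []
    for b, load in zip(bias, tokens_per_expert):
        if load > target:
            new_bias.append(b - step)
        elif load < target:
            new_bias.append(b + step)
        else:
            new_bias.append(b)
    return new_bias

gates = route_with_bias(scores=[0.6, 0.3, 0.1], bias=[0.0, 0.0, 1.0], k=2)
bias = update_bias([0.0, 0.0, 0.0], tokens_per_expert=[10, 2, 0], step=0.1)
```

The bias changes which experts are picked (expert 2 is selected despite its low score) but not how their outputs are weighted.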
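Note on the Switch auxiliary loss above: f_i comes from a hard argmax/top-k assignment and is not itself differentiable, which is why the loss multiplies it by the differentiable P_i.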
When MoE pays off
MoE adds parameter capacity without a matching increase in FLOPs per token, but the full model must fit in memory at inference (all experts must be loaded; only k run per token). On a single GPU MoE is rarely worth it. Across many GPUs with expert parallelism (each GPU holds different experts), MoE wins because compute scales with active params, not total params.
- Router → top-k → gate-weighted sum. Mixtral top-2; DeepSeek top-8 with 1 shared expert.
- Switch aux loss: L_aux = α·N·Σ f_i·P_i; N normalizes uniform routing to 1.
- Router Z-loss penalizes the log-partition for fp8/bf16 stability.
- DeepSeek V3 replaces aux loss with a per-expert bias updated by a simple feedback rule on expert load, so there is no gradient interference.
Pre-norm puts the LayerNorm inside the residual branch — gradients flow through the residual unchanged, which is why pre-norm trains deep models stably. Post-norm gives slightly better quality if you can train it. RMSNorm replaced LayerNorm in most modern stacks; QK-Norm is the new stability trick.
Pre-norm vs post-norm
Pre-norm (modern default)
x_{l+1} = x_l + Sublayer(LN(x_l))
- Identity gradient path through residual.
- Stable for L > 100; requires no warmup tricks.
- Used by GPT-2, Llama, Mistral, DeepSeek, PaLM.
Post-norm (original)
x_{l+1} = LN(x_l + Sublayer(x_l))
- Slightly better quality if it trains.
- Gradient passes through LN's reciprocal-std → harder to train deep.
- Used by original Transformer, BERT.
RMSNorm — LayerNorm minus the mean
RMSNorm (Zhang & Sennrich 2019) drops the mean-subtraction step of LayerNorm:
RMSNorm(x) = x / sqrt(mean(x²) + ε) · γ
Faster (no mean), no bias parameter, slightly better empirically. Used in Llama, Mistral, DeepSeek. LayerNorm is now legacy in LLMs.
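The RMSNorm formula translates directly to code. A minimal sketch on a single vector; the `eps` default is an arbitrary small constant.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma, with no mean subtraction."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gamma, x)]

out = rms_norm([3.0, 4.0], gamma=[1.0, 1.0])
constant = rms_norm([1.0, 1.0], gamma=[1.0, 1.0])
print(out, constant)
```

A constant vector survives RMSNorm almost unchanged, whereas LayerNorm's mean subtraction would map it to zeros; that is the practical difference between the two.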
Sandwich norm and QK-Norm
- Sandwich norm / double norm (Gemma 2, OLMo): LN both before and after a sub-layer. Adds stability at scale.
- QK-Norm (OLMo, some Gemma variants): apply LN to Q and K before the dot product. Bounds attention score magnitudes, prevents softmax saturation, key for training very large/wide models.
Why residuals are non-negotiable
Without residuals, the gradient through L sub-layers is a product of L Jacobians — vanishes or explodes. With residuals, ∂(x + f(x))/∂x = I + ∂f/∂x, preserving an identity component. This is the architectural enabler that lets transformers go to L=80, 100, or beyond.
The original Transformer (6 encoder + 6 decoder layers, post-norm) needed careful learning-rate warmup or it diverged. GPT-2 swapped to pre-norm and trained 48 layers stably without exotic tricks. Nearly every major LLM since has been pre-norm; the cost (a small quality hit) is outweighed by the gain (you can scale depth).
- Pre-norm = stable scaling; post-norm = slightly better quality if it trains. Use pre-norm.
- RMSNorm has replaced LayerNorm in modern stacks (no mean, no bias, faster).
- QK-Norm bounds attention scores — increasingly common at scale.
- Residuals give the identity gradient path that makes L=80 trainable.
Tokenization is the un-glamorous step that quietly governs cost, multilingual fairness, and arithmetic. The 2026 default is byte-level BPE (OpenAI tiktoken) or SentencePiece-BPE (Llama/Mistral) with byte fallback for robustness. Larger vocabs (200k+) reduce tokens-per-text but bloat the embedding table.
The four families
| Algorithm | Idea | Used by |
|---|---|---|
| BPE | Start from chars; iteratively merge most frequent pair. | GPT-2/3, Llama (via SentencePiece-BPE) |
| Byte-level BPE | BPE on UTF-8 bytes; vocab includes 256 byte tokens; handles any input. | GPT-2, tiktoken |
| SentencePiece (BPE or Unigram) | Operates on raw text — no whitespace pretokenization. | Llama, Mistral, T5 |
| Unigram LM | Start with large vocab; iteratively prune tokens that least decrease likelihood. Probabilistic segmentation. | SentencePiece-Unigram, mBART |
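The BPE row of the table ("iteratively merge most frequent pair") fits in a few lines. This is a toy trainer over a word-frequency dictionary, with a lexicographic tie-break added so the result is deterministic; production tokenizers add pretokenization, byte handling, and far faster data structures.

```python
from collections import Counter

def train_bpe(word_counts, num_merges):
    """word_counts: dict word -> frequency. Returns the learned merges in order."""
    corpus = {tuple(word): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        # Count every adjacent symbol pair, weighted by word frequency.
        pairs = Counter()
        for symbols, count in corpus.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += count
        if not pairs:
            break
        # Most frequent pair; ties broken lexicographically (an arbitrary but fixed rule).
        best = max(pairs, key=lambda p: (pairs[p], p))
        merges.append(best)
        # Rewrite the corpus with the chosen pair fused into one symbol.
        new_corpus = {}
        for symbols, count in corpus.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            key = tuple(out)
            new_corpus[key] = new_corpus.get(key, 0) + count
        corpus = new_corpus
    return merges

merges = train_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, num_merges=3)
print(merges)  # [('s', 't'), ('e', 'st'), ('o', 'w')]
```

The learned merge list is the tokenizer: encoding new text replays the merges in the same order.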
Byte fallback — the robustness trick
When the tokenizer encounters a character it can't segment (rare emoji, unusual script), byte fallback emits the raw UTF-8 bytes as tokens instead of an <unk>. Critical for production: the model can always generate something, never crashes on input.
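A toy version of the fallback path, assuming a greedy longest-match tokenizer and SentencePiece-style `<0xNN>` byte tokens. The matching strategy and the tiny vocabulary are illustrative assumptions; only the fallback behaviour is the point.

```python
def encode_with_byte_fallback(text, vocab):
    """Greedy longest-match over `vocab`; unknown characters become <0xNN> byte tokens."""
    tokens, i = [], 0
    max_len = max(len(t) for t in vocab)
    while i < len(text):
        for n in range(min(max_len, len(text) - i), 0, -1):
            piece = text[i:i + n]
            if piece in vocab:
                tokens.append(piece)
                i += n
                break
        else:
            # No vocab entry matches here: emit the character's raw UTF-8 bytes instead of <unk>.
            tokens.extend(f"<0x{b:02X}>" for b in text[i].encode("utf-8"))
            i += 1
    return tokens

def decode(tokens):
    # Toy decoder: assumes no ordinary vocab entry looks like "<0xNN>".
    out = bytearray()
    for t in tokens:
        if len(t) == 6 and t.startswith("<0x") and t.endswith(">"):
            out.append(int(t[3:5], 16))
        else:
            out.extend(t.encode("utf-8"))
    return out.decode("utf-8")

vocab = {"hello", " ", "h"}
tokens = encode_with_byte_fallback("hé", vocab)
print(tokens)  # ['h', '<0xC3>', '<0xA9>']
print(decode(encode_with_byte_fallback("hello 🦀", vocab)))
```

Because the byte tokens cover all 256 values, every input round-trips exactly, at the price of several tokens per unseen character.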
Vocab size — the tradeoff
Larger vocab → fewer tokens per text → faster inference per character, but a larger embedding table and more compute in the output softmax. tiktoken cl100k_base (GPT-3.5/4) uses ~100k vocab; o200k_base (GPT-4o) uses ~200k. Llama 3 jumped to 128k from Llama 2's 32k for the same reason.
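The table-size side of the tradeoff is simple arithmetic. The sketch below assumes a hidden size of 4096 and bf16 weights, and counts both the input embedding and the output projection unless they are tied; the helper name and the chosen sizes are illustrative.

```python
def embedding_params(vocab_size, d_model, tied=False):
    """Parameters in the input embedding plus the output (unembedding) matrix."""
    per_matrix = vocab_size * d_model
    return per_matrix if tied else 2 * per_matrix

d_model = 4096          # assumed hidden size, typical of a 7-8B model
bytes_per_param = 2     # bf16
for label, vocab in [("32k", 32_000), ("128k", 128_256), ("200k", 200_000)]:
    params = embedding_params(vocab, d_model)
    print(f"{label}: {params / 1e6:.0f}M params, {params * bytes_per_param / 2**30:.2f} GiB")
```

At this hidden size, going from a 32k to a 128k vocabulary adds roughly 0.8B untied parameters, which is noticeable on a 7-8B model and negligible on a 70B one.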
- Numbers: GPT-2 splits "1234" oddly. Llama 1/2 split numbers into single digits, which helps arithmetic; tiktoken-style tokenizers (including Llama 3) chunk digits in groups of up to three.
- Spaces: " hello" ≠ "hello". The leading space is part of the token. Many prompt bugs trace to this.
- Non-English asymmetry: a sentence in Thai may use 4× more tokens than the same sentence in English → 4× the API cost. Larger multilingual vocabs and byte fallback partly fix this.
- Glitch tokens: rare-vocab tokens that the model never trained on can produce hallucinations or repetition (SolidGoldMagikarp).
BPE vs Unigram — when each
BPE is deterministic — given a tokenizer, every text has one segmentation. Unigram supports subword regularization: at training time, sample different segmentations of the same text per epoch as data augmentation. Used in some multilingual setups; BPE is more common.
- Byte-level BPE (tiktoken) and SentencePiece-BPE (Llama) dominate; byte-level BPE covers any input natively, SentencePiece via byte fallback.
- Vocab size is a cost / table-size tradeoff; 100k–200k is the modern range.
- Unigram LM enables probabilistic segmentation (subword regularization); BPE is deterministic.
- Spaces are part of tokens; non-English costs more tokens — both real-world traps.
Tensor parallelism shards weights but leaves the full sequence on every GPU — at long context the activations overflow VRAM. Sequence parallelism shards the sequence in LN/dropout regions; context parallelism shards it inside attention via ring exchange. Together they make >1M context training feasible.
The problem TP alone doesn't solve
With tensor parallelism (TP), the attention and FFN projections are sharded across GPUs, but LayerNorm, residuals and dropout all see the full hidden dim and the full sequence on each GPU. At long context the activations from these regions blow VRAM. These replicated regions between the sharded matmuls are what TP can't help.
Sequence parallelism
Sequence parallelism (Korthikanti 2022, arxiv 2205.05198) shards the sequence dimension during the LN/dropout regions where TP can't help:
- TP region (attention/FFN matmul): full sequence, sharded hidden.
- SP region (LN, dropout, residual): sharded sequence, full hidden.
- Boundaries: all-gather (going into TP region) and reduce-scatter (going out).
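Net effect: dramatic activation memory reduction at no extra communication volume, because the all-gather + reduce-scatter pair replaces the all-reduce that TP already performs at the same boundaries. Combined with TP it gives much higher effective context per GPU.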
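The boundary collectives can be simulated on one process with lists standing in for ranks. The sketch assumes a row-parallel matmul has left each rank with a partial sum of the layer output, then compares the TP-only path (all-reduce, LayerNorm on the full sequence everywhere) with the TP + SP path (reduce-scatter, LayerNorm on a sequence shard, all-gather). The collective helpers are plain NumPy stand-ins, not a real communication library.

```python
import numpy as np

def all_reduce(parts):
    total = np.sum(parts, axis=0)
    return [total.copy() for _ in parts]          # every rank gets the full sum

def reduce_scatter(parts, axis=0):
    # Sum across ranks, then hand each rank one slice along the sequence axis.
    return np.split(np.sum(parts, axis=0), len(parts), axis=axis)

def all_gather(shards, axis=0):
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]          # every rank gets the full sequence

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
P, T, d = 4, 8, 6                                  # ranks, sequence length, hidden dim
partials = [rng.normal(size=(T, d)) for _ in range(P)]

# TP only: every rank normalises the full (T, d) activation.
tp_out = [layer_norm(y) for y in all_reduce(partials)]

# TP + SP: each rank normalises only its (T/P, d) shard, then gathers for the next matmul.
shards = [layer_norm(s) for s in reduce_scatter(partials)]
sp_out = all_gather(shards)

print(shards[0].shape, np.allclose(sp_out[0], tp_out[0]))
```

Because LayerNorm, dropout and the residual add act on each token independently, sharding the sequence through that region changes nothing numerically; it only divides the activation held per rank by P.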
Context parallelism — ring attention
Context parallelism shards the sequence across GPUs inside attention itself. Each GPU holds a chunk of the sequence; K and V chunks rotate around a ring (Ring Attention, Liu 2023, arxiv 2310.01889) so every Q chunk eventually sees every K,V chunk. With FlashAttention-style online softmax, you accumulate the partial outputs without materializing the full score matrix anywhere. This is what enables >1M context training and inference.
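The ring rotation plus online softmax can be simulated on a single process. This is a non-causal sketch in which list indices stand in for devices and the "rotation" is an index shift. Causal masking, load balancing and actual communication are omitted, and the function names are illustrative.

```python
import numpy as np

def full_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def ring_attention(q, k, v, n_dev):
    """Each simulated device keeps one Q chunk; K/V chunks visit every device in turn."""
    d = q.shape[-1]
    q_chunks = np.split(q, n_dev)
    kv_chunks = list(zip(np.split(k, n_dev), np.split(v, n_dev)))
    m = [np.full((qc.shape[0], 1), -np.inf) for qc in q_chunks]        # running row max
    l = [np.zeros((qc.shape[0], 1)) for qc in q_chunks]                # running softmax denominator
    o = [np.zeros((qc.shape[0], v.shape[-1])) for qc in q_chunks]      # running unnormalised output
    for step in range(n_dev):
        for dev in range(n_dev):
            kc, vc = kv_chunks[(dev + step) % n_dev]   # K/V block held by `dev` at this step
            s = q_chunks[dev] @ kc.T / np.sqrt(d)
            m_new = np.maximum(m[dev], s.max(axis=-1, keepdims=True))
            rescale = np.exp(m[dev] - m_new)           # correct earlier partial sums to the new max
            p = np.exp(s - m_new)
            l[dev] = l[dev] * rescale + p.sum(axis=-1, keepdims=True)
            o[dev] = o[dev] * rescale + p @ vc
            m[dev] = m_new
    return np.concatenate([o_i / l_i for o_i, l_i in zip(o, l)])

rng = np.random.default_rng(0)
T, d = 16, 8
q, k, v = (rng.normal(size=(T, d)) for _ in range(3))
matches = np.allclose(ring_attention(q, k, v, n_dev=4), full_attention(q, k, v))
print(matches)
```

Each device only ever holds a (T/n_dev)×(T/n_dev) score block, yet the result matches full attention. That is why sequence length can scale with the number of devices.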
32k context: TP alone is fine. 128k context: TP + SP; SP cuts the activation memory of the LN regions enough to fit. 1M+ context: TP + SP + context parallelism (ring attention), which is what keeps any single GPU from having to hold the whole sequence's Q, K and V.
- TP shards weights; SP shards activations in LN/dropout/residual regions.
- SP boundaries: all-gather in, reduce-scatter out.
- Context parallelism (ring attention) shards the sequence inside attention itself — enables >1M context.
Common interview questions
- "Why divide attention scores by √d_k?" → Without scaling, dot product variance grows with d_k → softmax saturates → vanishing gradients.
- "Walk through MHA → MQA → GQA → MLA." → Reduces KV cache size step-by-step. MLA goes furthest by storing only a low-rank latent and reconstructing K,V on the fly.
- "Why RoPE over learned positional?" → Encodes relative position naturally via rotation in 2D subspaces; extrapolates better with PI/YaRN tricks.
- "Pre-norm vs post-norm — pick one and defend." → Pre-norm: gradients flow through residual cleanly, stable for depth, easier to train at scale. Post-norm: slightly better quality if stable, used in original Transformer, harder to train deep.
- "Walk through one MoE forward pass." → Token x → router → softmax over E experts → top-k selection (and gating weights) → forward through k experts → weighted sum → output. Plus aux load-balance loss during training.
- "Why does DeepSeek V3 use shared + routed experts?" → Shared experts capture common knowledge across all tokens; routed experts specialize on niches. Reduces redundant routed-expert capacity.
- "How does QK-Norm help?" → LN on Q and K before dot product bounds attention score magnitudes; prevents softmax saturation at large scales; improves training stability for very deep / wide models.
- "What changes if you switch from BPE to Unigram LM?" → Unigram supports probabilistic segmentation (sample different segmentations per training pass) → subword regularization. BPE is deterministic.
0 → hero reading path for transformer internals
- foundation Karpathy — Let's build GPT from scratch (2 hours; the canonical lecture)
- foundation The Annotated Transformer (Harvard NLP; line-by-line walkthrough of the original paper)
- foundation The Illustrated Transformer (Jay Alammar)
- build nanoGPT — read it, train it, modify it
- build Lilian Weng — The Transformer Family v2
- depth Attention Is All You Need (Vaswani 2017)
- depth RoPE (Su 2021)
- depth GQA (Ainslie 2023)
- depth DeepSeek-V2 / MLA (Liu 2024)
- depth DeepSeek-MoE
- depth SwiGLU / GLU variants (Shazeer 2020)
- depth YaRN (Peng 2023) for long context extension
- depth Ring Attention (Liu 2023)
Transformer quiz — readiness check
- Walk through one self-attention forward pass with shapes.
x: (B, T, d). Q = xW_Q, K = xW_K, V = xW_V → each (B, T, d_k). scores = Q K^T / √d_k → (B, T, T). Apply causal mask (set upper triangle to -∞). attn = softmax(scores, dim=-1). out = attn V → (B, T, d_v).
- Memory complexity of attention?
O(T²) for the attention matrix. FLOPs O(T² d). FlashAttention reduces memory to O(T) via tiling + online softmax — never materializes the full attn matrix in HBM.
- Difference between MHA, MQA, GQA, MLA in KV cache size.
Per token per layer: MHA: 2·n_heads·d_head. MQA: 2·d_head (8× smaller for 8-head MHA). GQA(g): 2·g·d_head (intermediate). MLA: d_latent + d_rope, i.e. one joint latent shared by K and V plus the small decoupled RoPE key (DeepSeek-V2: 512 + 64 = 576 values vs 2·128·128 = 32,768 for the equivalent MHA).
- Why does MLA need decoupled RoPE?
RoPE inserts a position-dependent rotation between the query and key projections, so the key up-projection can no longer be folded into Q's projection as one fixed matrix. So MLA splits Q,K into content (low-rank, no RoPE, absorbable) and rotary (small dim, RoPE-applied, MQA-style shared) parts. Final score = Q^C K^C + Q^R K^R.
- Sinusoidal vs RoPE vs ALiBi — when each?
Sinusoidal: original; can extrapolate weakly. Learned absolute: GPT-2/3; bad extrapolation. ALiBi: linear attention bias; good extrapolation; no positional embedding. RoPE: rotation in 2D subspaces; relative position via dot product; current default. RoPE + YaRN: best for long-context extension.
- Why scale FFN hidden dim 4d?
Empirically optimal width-to-depth ratio at fixed parameters. 4d gives the FFN ~2/3 of the block's parameters; the remaining 1/3 are attention. SwiGLU uses ~2.67d × 3 matrices to keep param count constant.
- Switch Transformer aux loss formula?
L_aux = α · N · Σ f_i · P_i where N = num experts, f_i = fraction of tokens routed to expert i, P_i = mean router prob for expert i, α ≈ 0.01. The N multiplier normalizes so that uniform routing gives N · Σ f_i · P_i = 1, i.e. L_aux = α.
- What is auxiliary-loss-free balancing (DeepSeek V3)?
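Drop the aux loss as the main balancing mechanism; instead maintain a per-expert bias added to the router scores for top-k selection only (gate weights still use the unbiased scores). After each step, decrement the bias for over-subscribed experts and increment it for under-subscribed ones by a fixed update speed γ, the method's one hyperparameter. It acts as a simple feedback controller and avoids the gradient interference aux losses introduce.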
- What does router Z-loss do?
L_z = (1/B) Σ (log Σ exp(x_ij))² where x_ij are router logits. Penalizes the partition function — keeps router logits bounded so softmax stays stable in bf16/fp8. Used by ST-MoE / PaLM / DeepSeek.
- Why softmax not sigmoid in attention?
Each row should sum to 1 — attention is a distribution over keys. Sigmoid would let multiple keys be "fully" attended. Some recent work explores sigmoid attention for length generalization but softmax is the default.
- Why do tokenizers matter for non-English?
BPE merges trained on English-heavy corpora produce many tokens per word for non-English languages → cost asymmetry (more tokens to encode same content) and worse inductive bias for those languages. Fix: multilingual training, byte fallback, larger vocab.
- What's the difference between byte-level BPE and SentencePiece?
Byte-level BPE (GPT-2): operates on UTF-8 bytes; vocab includes all 256 bytes; handles any input. SentencePiece (Kudo): operates on raw text (no whitespace pretokenization); supports BPE or Unigram; standard for Llama / Mistral.
- How does Mixtral 8x7B work?
Each FFN block has 8 expert FFNs; router selects top-2 per token; output is gate-weighted sum of 2 experts. Total params 47B; active per token ~13B (because only 2/8 experts run + the shared attention). Inference is faster than 47B dense; quality close to it.
- Why does the residual connection matter for deep transformers?
Without residuals, gradient through L sublayers is a product of L Jacobians → vanishes or explodes. Residual ∂(x + f(x))/∂x = I + ∂f/∂x preserves a clean identity gradient path. Lets you train 100+ layer networks.
- What's QK-Norm and when is it used?
LayerNorm applied to Q and K before the dot product. Bounds attention score magnitudes → prevents softmax saturation → improves training stability at very large scale. Used in OLMo and some recent Gemma variants.
- Why is GPT-style decoder-only the dominant architecture in 2026?
(1) Generative tasks need autoregressive decoding. (2) Encoder-only (BERT-style infilling) and encoder-decoder designs split understanding from generation; decoder-only with causal attention handles both as next-token prediction. (3) Scaling has worked exceptionally well; switching costs aren't worth it. (4) Inference can use KV cache straightforwardly.
- What does sequence parallelism shard, and what's the comm pattern?
Shards the sequence dim during LN/dropout/residual regions (where TP can't help). All-gather + reduce-scatter on TP/SP boundaries. Cuts activation memory significantly. Combined with TP, gives much higher effective context per GPU.
- What is "expert choice" routing vs token-choice?
Token-choice (standard): each token picks top-k experts. Can produce imbalance. Expert-choice (Zhou 2022): each expert picks top-k tokens (capacity per expert is fixed). Guarantees load balance — no aux loss needed. Tradeoff: tokens not picked are dropped or routed to a default.
- Why pre-norm enables deeper networks?
Pre-norm: x + Sublayer(LN(x)) — residual is unnormalized, so gradients flow through it cleanly via the identity path. Post-norm: LN(x + Sublayer(x)) — gradient passes through the LN's reciprocal-std term, which can shrink or amplify; harder to train deep without careful warmup.
- How does YaRN extend RoPE to long context?
Frequency-categorized RoPE rescaling: high-freq RoPE dims (position-sensitive, fast oscillation) get no interpolation; low-freq dims (slow oscillation) get full interpolation; the dims in between are blended. Plus an attention temperature factor. Allows extending pretrained 4k models to 128k with brief continued training.
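The per-dimension rescaling in the last answer can be written down directly. This sketch follows the ramp formulation from the YaRN paper with its commonly used thresholds α = 1 and β = 32 (assumed here, along with the RoPE base of 10000); the function name is illustrative and not a library API.

```python
import numpy as np

def yarn_frequencies(head_dim, orig_ctx, scale, base=10000.0, alpha=1.0, beta=32.0):
    """Return (rescaled RoPE frequencies, attention scaling factor) for a context extension."""
    theta = base ** (-np.arange(0, head_dim, 2) / head_dim)   # standard RoPE frequencies
    rotations = orig_ctx * theta / (2 * np.pi)                # full turns within the original context
    # ramp = 1: leave the dimension alone; ramp = 0: fully interpolate (divide frequency by scale).
    ramp = np.clip((rotations - alpha) / (beta - alpha), 0.0, 1.0)
    new_theta = (1 - ramp) * theta / scale + ramp * theta
    # Temperature term: q and k are each multiplied by this factor before the dot product.
    attn_factor = 0.1 * np.log(scale) + 1.0
    return new_theta, attn_factor

head_dim, orig_ctx, scale = 128, 4096, 32.0                   # e.g. 4k -> 128k
theta = 10000.0 ** (-np.arange(0, head_dim, 2) / head_dim)
new_theta, attn_factor = yarn_frequencies(head_dim, orig_ctx, scale)
print(new_theta[0] / theta[0], new_theta[-1] / theta[-1], attn_factor)
```

The fastest-rotating dimension keeps its frequency, the slowest is divided by the full scale factor, and everything in between is blended linearly by the ramp.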