ML coding — from scratch

At Anthropic / OpenAI / Thinking Machines / Cohere / Mistral, ML coding rounds replace half the LeetCode rounds. Implement attention, sampling, BPE, k-means, etc. from memory, in 30-45 min, in NumPy or PyTorch.

The shortlist (drill all of these)

| Frequency | Problem | Companies that ask it |
|---|---|---|
| ★★★★★ | Single-head self-attention from scratch (NumPy) | Anthropic, OpenAI, Cohere, Mistral, DeepMind, xAI |
| ★★★★★ | Multi-head attention | same as above + Apple, Nvidia |
| ★★★★ | Causal mask + KV cache | Anthropic, OpenAI, Together, vLLM-adjacent companies |
| ★★★★ | Sampling: top-k, top-p, temperature, beam | OpenAI, Cohere, Adept, Cursor |
| ★★★★ | Softmax (numerically stable + cross-entropy) | everywhere |
| ★★★★ | Layer norm / RMSNorm + grad | Anthropic, OpenAI |
| ★★★★ | Tokenizer: BPE training + encode | Anthropic, OpenAI, Cohere |
| ★★★ | RoPE rotary embeddings | Mistral, Cohere |
| ★★★ | K-means | Pinterest, Snap, recsys-heavy |
| ★★★ | Logistic regression with SGD | recsys, ranking, ad teams |
| ★★★ | Gradient descent + manual backprop on a 2-layer net | generalist labs |
| ★★★ | Mini DataLoader / sharded batching | training infra teams |
| ★★ | Speculative decoding (toy) | vLLM, Together, Anthropic infra |
| ★★ | Convolution (forward + backward) | Tesla, Waymo, Aurora, Apple |
| ★★ | Train a tiny transformer on toy data | any pretraining-team loop |

1. Single-head self-attention (NumPy)

Drill until you can do this in 6 minutes, no reference. Be ready to extend to multi-head and causal mask on the spot.

import numpy as np

def softmax(x, axis=-1):
    """Softmax along `axis`, shifted by the per-slice max so exp() cannot overflow.

    The shift cancels in the ratio, so the result is mathematically unchanged.
    """
    exps = np.exp(x - x.max(axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v, mask=None):
    """
    Single-head scaled dot-product self-attention.

    x:        (..., T, d_model) — any number of leading batch dims, including none
    w_q, w_k: (d_model, d_k)
    w_v:      (d_model, d_v)
    mask:     broadcastable to (..., T, T); True = mask out. Masked scores are set
              to -1e9 (not -inf), so a fully-masked row gets uniform weights
              instead of NaN.
    returns:  (..., T, d_v)
    """
    Q = x @ w_q                       # (..., T, d_k)
    K = x @ w_k                       # (..., T, d_k)
    V = x @ w_v                       # (..., T, d_v)
    d_k = Q.shape[-1]
    # swapaxes(-1, -2) rather than transpose(0, 2, 1): works for any batch rank
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)   # (..., T, T)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    attn = softmax(scores, axis=-1)   # (..., T, T)
    return attn @ V                   # (..., T, d_v)

# Causal mask: True strictly above the diagonal (query i must not see key j > i).
# np.tri is the lower triangle incl. the diagonal; its complement is the strict upper.
T = 8
causal = ~np.tri(T, dtype=bool)

Common follow-ups

2. Multi-head attention

def multi_head_attention(x, W_q, W_k, W_v, W_o, num_heads, mask=None):
    """
    Multi-head self-attention.

    x:    (B, T, d_model)
    W_*:  (d_model, d_model)
    mask: True = mask out. Accepted layouts:
          (T, T)        shared across batch and heads (e.g. causal),
          (B, T, T)     per-example (e.g. padding) — a head axis is inserted,
          (B|1, H|1, T, T) anything already broadcastable to the scores.
          A per-head mask shared across the batch must be passed as (1, H, T, T).
    returns: (B, T, d_model)
    Raises ValueError if d_model is not divisible by num_heads.
    """
    B, T, d_model = x.shape
    if d_model % num_heads:
        raise ValueError(f"d_model={d_model} is not divisible by num_heads={num_heads}")
    d_head = d_model // num_heads

    def split_heads(m):
        # (B, T, d_model) -> (B, T, H, d_head) -> (B, H, T, d_head)
        return m.reshape(B, T, num_heads, d_head).transpose(0, 2, 1, 3)

    Q = split_heads(x @ W_q)
    K = split_heads(x @ W_k)
    V = split_heads(x @ W_v)

    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)   # (B, H, T, T)
    if mask is not None:
        mask = np.asarray(mask)
        if mask.ndim == 3:
            # A (B, T, T) mask would broadcast as (1, B, T, T), lining batch up
            # with heads (an error when B != H, silently wrong when B == H).
            mask = mask[:, None, :, :]                        # (B, 1, T, T)
        scores = np.where(mask, -1e9, scores)
    attn = softmax(scores, axis=-1)
    out = attn @ V                                            # (B, H, T, d_head)

    out = out.transpose(0, 2, 1, 3).reshape(B, T, d_model)    # merge heads
    return out @ W_o

3. KV cache (autoregressive decode)

class KVCache:
    """One layer's KV cache for autoregressive generation.

    Keys and values live in preallocated (max_len, num_heads, d_head) buffers;
    `pos` is the number of filled positions.
    """
    def __init__(self, max_len, num_heads, d_head, dtype=np.float32):
        self.K = np.zeros((max_len, num_heads, d_head), dtype=dtype)
        self.V = np.zeros((max_len, num_heads, d_head), dtype=dtype)
        self.pos = 0

    def append(self, k_new, v_new):
        """Write new keys/values at the current position and advance it.

        k_new, v_new: (num_heads, d_head) for one token, or
                      (n, num_heads, d_head) for a block of n tokens (prefill).
        Raises IndexError if the cache would grow past max_len.
        """
        k_new = np.asarray(k_new)
        v_new = np.asarray(v_new)
        is_block = k_new.ndim == 3
        n = k_new.shape[0] if is_block else 1
        end = self.pos + n
        max_len = self.K.shape[0]
        if end > max_len:
            raise IndexError(f"KVCache full: pos {self.pos} + {n} new > max_len {max_len}")
        if is_block:
            self.K[self.pos:end] = k_new
            self.V[self.pos:end] = v_new
        else:
            self.K[self.pos] = k_new
            self.V[self.pos] = v_new
        self.pos = end

    def get(self):
        """Return (K, V) for the filled prefix, each (pos, num_heads, d_head).

        These are views into the cache buffers, not copies.
        """
        return self.K[:self.pos], self.V[:self.pos]

def decode_step(x_new, w_q, w_k, w_v, w_o, cache, num_heads):
    """Attend one new token against the KV cache.

    x_new: (d_model,). The token's key/value are appended to `cache` before
    attending, so it attends to itself plus every previously cached position.
    The concatenated head outputs are projected by `w_o` and returned.
    """
    d_model = x_new.shape[0]
    d_head = d_model // num_heads
    heads = (num_heads, d_head)

    q = (x_new @ w_q).reshape(heads)                        # (H, d_head)
    cache.append((x_new @ w_k).reshape(heads), (x_new @ w_v).reshape(heads))
    keys, values = cache.get()                              # (T, H, d_head)

    logits = np.einsum('hd,thd->ht', q, keys) / np.sqrt(d_head)   # (H, T)
    weights = softmax(logits, axis=-1)
    mixed = np.einsum('ht,thd->hd', weights, values)        # (H, d_head)
    return mixed.reshape(d_model) @ w_o

4. Sampling (greedy / top-k / top-p / temperature)

def sample(logits, temperature=1.0, top_k=None, top_p=None, rng=None):
    """Sample one token id from (V,) logits.

    temperature: logits are divided by max(temperature, 1e-6); values near 0
                 approach greedy decoding.
    top_k:       keep only the k largest logits (clamped to V).
    top_p:       nucleus sampling — keep the smallest set of most-likely tokens
                 whose cumulative probability reaches top_p.
    rng:         optional np.random.Generator for reproducible draws; defaults
                 to the global np.random state.
    """
    logits = logits / max(temperature, 1e-6)
    if top_k is not None:
        k = min(top_k, logits.shape[-1])    # argpartition raises if k > V
        idx = np.argpartition(logits, -k)[-k:]
        masked = np.full_like(logits, -1e9)
        masked[idx] = logits[idx]
        logits = masked
    probs = softmax(logits)
    if top_p is not None:
        # stable descending order: among equal probs the lower token id comes first
        order = np.argsort(-probs, kind='stable')
        cum = np.cumsum(probs[order])
        cutoff = np.searchsorted(cum, top_p) + 1   # smallest prefix with cum >= top_p
        keep = order[:cutoff]
        kept = np.zeros_like(probs)
        kept[keep] = probs[keep]
        probs = kept / kept.sum()
    chooser = np.random if rng is None else rng
    return chooser.choice(len(probs), p=probs)

5. Numerically-stable softmax + cross-entropy

def log_softmax(x, axis=-1):
    """log(softmax(x)) along `axis`, computed stably via the log-sum-exp trick."""
    shifted = x - x.max(axis=axis, keepdims=True)
    log_norm = np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    return shifted - log_norm

def cross_entropy(logits, labels):
    """Mean negative log-likelihood of integer `labels` (B,) under `logits` (B, V)."""
    rows = np.arange(len(labels))
    log_probs = log_softmax(logits, axis=-1)
    return (-log_probs[rows, labels]).mean()

6. LayerNorm and RMSNorm

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the last axis, then apply a learned scale and shift.

    x: (..., d); gamma, beta: (d,). Uses the population (biased) variance
    from ndarray.var; eps keeps the denominator away from zero.
    """
    centered = x - x.mean(-1, keepdims=True)
    std = np.sqrt(x.var(-1, keepdims=True) + eps)
    return gamma * centered / std + beta

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm (Llama-style): divide by the root-mean-square of the last axis.

    No mean subtraction and no bias term; gamma is the learned per-dim scale.
    """
    mean_square = np.square(x).mean(-1, keepdims=True)
    rms = np.sqrt(mean_square + eps)
    return gamma * x / rms

7. RoPE (rotary positional embedding)

Convention gotcha — interleave vs half-rotate
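The original RoPE paper rotates adjacent pairs (x[..., 0::2], x[..., 1::2]); this interleaved convention is what's implemented below. Hugging Face's Llama implementation instead pairs dim i with dim i + d/2 (x[..., :d//2], x[..., d//2:], the "rotate_half" form). The two are equivalent up to a fixed permutation of the head dims, but you must match the convention the checkpoint was trained or converted with, or you get garbage output. Frontier-lab interview probe: "which convention does Llama use?" Answer: Hugging Face's Llama uses half-rotate. Meta's original Llama code rotates interleaved pairs via complex multiplication, and the HF conversion script permutes the q/k projection weights to compensate.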
def rope(x, base=10000, positions=None):
    """
    Apply rotary positional embedding (interleaved-pair convention).

    x: (..., T, d) with d even. Dims (2i, 2i+1) are rotated as a 2-D pair by
       angle pos * base**(-2i/d).
    positions: optional (T,) absolute positions; defaults to 0..T-1. Pass e.g.
       [t] when rotating a single decode-step token so it receives the same
       rotation it would have at position t of the full sequence.
    Returns rotated x with the same shape and dtype.
    Raises ValueError if d is odd or positions does not have shape (T,).
    """
    T, d = x.shape[-2:]
    if d % 2 != 0:
        # explicit raise rather than assert: asserts vanish under `python -O`
        raise ValueError(f"rope requires an even last dimension, got d={d}")
    pos = np.arange(T) if positions is None else np.asarray(positions)
    if pos.shape != (T,):
        raise ValueError(f"positions must have shape ({T},), got {pos.shape}")

    freqs = 1.0 / (base ** (np.arange(0, d, 2) / d))   # (d/2,) one frequency per pair
    angles = np.outer(pos, freqs)                       # (T, d/2)
    cos, sin = np.cos(angles), np.sin(angles)           # (T, d/2)

    x1, x2 = x[..., 0::2], x[..., 1::2]                 # (..., T, d/2)
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

8. BPE tokenizer (training + encode)

from collections import Counter

def bpe_train(corpus, num_merges):
    """Learn up to `num_merges` BPE merges from `corpus` (a list of words).

    Each distinct word starts as its characters plus an end-of-word marker
    '</w>'. Every round counts adjacent-symbol pairs weighted by word frequency,
    records the most frequent pair (ties go to the pair counted first), and
    fuses it in every word. Stops early once no word has two symbols left.
    Returns the merges as (left, right) tuples in the order they were learned.
    """
    counts = Counter(corpus)
    segmentation = {word: tuple(word) + ('</w>',) for word in counts}
    learned = []

    def fuse(symbols, pair):
        # left-to-right, non-overlapping replacement of `pair` by its concatenation
        fused, j = [], 0
        while j < len(symbols):
            if j + 1 < len(symbols) and (symbols[j], symbols[j + 1]) == pair:
                fused.append(symbols[j] + symbols[j + 1])
                j += 2
            else:
                fused.append(symbols[j])
                j += 1
        return tuple(fused)

    for _ in range(num_merges):
        pair_counts = Counter()
        for word, freq in counts.items():
            symbols = segmentation[word]
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += freq
        if not pair_counts:
            break
        top = max(pair_counts, key=pair_counts.get)
        learned.append(top)
        segmentation = {word: fuse(symbols, top) for word, symbols in segmentation.items()}
    return learned

def bpe_encode(word, merges):
    """Segment `word` by replaying learned `merges` in order.

    Starts from the characters plus '</w>' and, for each (left, right) merge,
    fuses every left-to-right non-overlapping occurrence of left followed by
    right. Returns the final list of symbols.
    """
    symbols = [*word, '</w>']
    for left, right in merges:
        fused = []
        pos = 0
        n = len(symbols)
        while pos < n:
            if pos + 1 < n and symbols[pos] == left and symbols[pos + 1] == right:
                fused.append(left + right)
                pos += 2
            else:
                fused.append(symbols[pos])
                pos += 1
        symbols = fused
    return symbols

9. K-means

def kmeans(X, k, num_iters=20, seed=0):
    """Lloyd's k-means.

    X: (N, d). Centroids start at k distinct random rows of X (seeded).
    Each iteration assigns every point to its nearest centroid (squared
    Euclidean distance) and moves each centroid to the mean of its points; an
    empty cluster keeps its previous centroid. Stops early once the centroids
    stop moving (np.allclose).
    Returns (centroids (k, d), labels (N,)), where labels are the assignment
    with respect to the returned centroids.
    Raises ValueError if k is not in [1, N].
    """
    N = len(X)
    if not 1 <= k <= N:
        raise ValueError(f"k must be in [1, {N}], got {k}")
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(N, k, replace=False)]

    def assign(c):
        # nearest-centroid index per point
        d2 = ((X[:, None, :] - c[None, :, :]) ** 2).sum(-1)   # (N, k)
        return d2.argmin(1)

    for _ in range(num_iters):
        labels = assign(centroids)
        new_c = np.stack([X[labels == j].mean(0) if (labels == j).any() else centroids[j]
                          for j in range(k)])
        if np.allclose(new_c, centroids):
            break
        centroids = new_c
    # Re-assign against the final centroids: after the last update the in-loop
    # labels would be stale (and they are undefined when num_iters == 0).
    return centroids, assign(centroids)

Follow-ups they actually ask

10. Logistic regression with SGD

def logreg_sgd(X, y, lr=0.01, epochs=10, l2=1e-4, seed=None):
    """Binary logistic regression trained with per-example SGD.

    X: (N, d), y: (N,) in {0, 1}. Each epoch visits the examples in a fresh
    random order; every w-update includes a weight-decay term l2 * w (the bias
    is not decayed).
    seed: if given, shuffles with np.random.default_rng(seed) for reproducible
          runs; if None, uses the global np.random state.
    Returns (w, b).
    """
    N, d = X.shape
    w = np.zeros(d)
    b = 0.0
    rng = np.random if seed is None else np.random.default_rng(seed)
    for _ in range(epochs):
        for i in rng.permutation(N):
            z = X[i] @ w + b
            # Numerically stable sigmoid: never exponentiate a large positive number.
            if z >= 0:
                p = 1.0 / (1.0 + np.exp(-z))
            else:
                ez = np.exp(z)
                p = ez / (1.0 + ez)
            err = p - y[i]                 # d(log-loss)/dz
            w -= lr * (err * X[i] + l2 * w)
            b -= lr * err
    return w, b

11. Manual backprop on a 2-layer MLP

def mlp_train_step(x, y, W1, b1, W2, b2, lr=1e-2):
    """One SGD step on a 2-layer ReLU MLP with softmax cross-entropy.

    x: (B, d_in); y: (B,) integer class labels.
    W1, b1, W2, b2 are updated with `-=` (in place for ndarray parameters).
    Returns the mean cross-entropy loss from the forward pass (pre-update).
    """
    n = len(y)
    rows = np.arange(n)

    # Forward
    pre = x @ W1 + b1                   # (B, h)
    act = np.maximum(pre, 0)            # ReLU
    logits = act @ W2 + b2              # (B, V)
    probs = softmax(logits, axis=-1)
    # 1e-12 guards log(0) when a label's probability underflows
    loss = -np.log(probs[rows, y] + 1e-12).mean()

    # Backward. For softmax + mean CE, dL/dlogits = (probs - onehot(y)) / B.
    g_logits = probs.copy()
    g_logits[rows, y] -= 1
    g_logits /= n                       # (B, V)
    g_W2 = act.T @ g_logits             # (h, V)
    g_b2 = g_logits.sum(0)
    g_pre = (g_logits @ W2.T) * (pre > 0)   # ReLU' = 1 where the pre-activation is > 0
    g_W1 = x.T @ g_pre
    g_b1 = g_pre.sum(0)

    # SGD
    W1 -= lr * g_W1
    b1 -= lr * g_b1
    W2 -= lr * g_W2
    b2 -= lr * g_b2
    return loss

12. Mini DataLoader / sharded batching

class ShardedDataLoader:
    """Distributes batches across N workers; each worker sees a disjoint subset.

    Every rank must use the same `seed` so all ranks shuffle identically; rank r
    then takes every world_size-th index starting at r. Call set_epoch(e) before
    each epoch to get a fresh (but still rank-consistent) shuffle — without it,
    every epoch replays the same order.
    Note: when len(dataset) is not a multiple of world_size * batch_size, ranks
    can yield different numbers of batches and the last batch can be short.
    """
    def __init__(self, dataset, batch_size, world_size, rank, shuffle=True, seed=0):
        if not 0 <= rank < world_size:
            # an out-of-range rank would silently get an empty or wrong slice
            raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
        self.ds, self.B, self.W, self.r = dataset, batch_size, world_size, rank
        self.shuffle, self.seed = shuffle, seed
        self.epoch = 0

    def set_epoch(self, epoch):
        """Select the shuffle for `epoch` (seeded with seed + epoch on every rank)."""
        self.epoch = epoch

    def _rank_indices(self):
        """This rank's dataset indices for the current epoch, in iteration order."""
        idx = np.arange(len(self.ds))
        if self.shuffle:
            # epoch 0 uses `seed` unchanged, matching the original behavior
            seed = self.seed + self.epoch if self.epoch else self.seed
            np.random.default_rng(seed).shuffle(idx)
        return idx[self.r::self.W]

    def __iter__(self):
        idx = self._rank_indices()
        for i in range(0, len(idx), self.B):
            yield [self.ds[j] for j in idx[i:i + self.B]]

    def __len__(self):
        """Number of batches this rank yields per epoch (the last may be short)."""
        n_rank = len(range(self.r, len(self.ds), self.W))
        return -(-n_rank // self.B)

13. Speculative decoding (toy)

def sample_from_residual(p_target, p_draft):
    """Draw one token from the normalized positive part of (p_target - p_draft).

    Both arguments are (V,) probability vectors. If the residual mass is
    numerically zero (< 1e-12), sample from p_target directly instead.
    """
    mass = np.maximum(0.0, p_target - p_draft)
    total = mass.sum()
    probs = p_target if total < 1e-12 else mass / total
    return np.random.choice(len(p_target), p=probs)

def speculative_step(prefix, draft_model, target_model, k=4):
    """
    One round of speculative decoding: draft k tokens, verify with one target pass.

    draft_model(state) -> (V,) next-token distribution for the token list `state`.
    target_model.batch_next_token_dist(prefix, drafts) -> (k+1, V): the target's
        next-token distribution after prefix[-1], drafts[0], ..., drafts[k-1].
    Returns a new token list: prefix + accepted drafts + exactly one more token
    (the corrected token at the first rejection, or a bonus token from the
    target if all k drafts are accepted).

    Draft tokens are SAMPLED from the draft distribution. The acceptance test
    min(1, p_t[t] / p_d[t]) plus residual resampling reproduces the target
    distribution only when t ~ p_d; greedy (argmax) drafting would require
    replacing p_d with a one-hot at the argmax.
    """
    out = list(prefix)        # list(): tuple/ndarray prefixes must not be added element-wise
    state = out.copy()

    # 1. Draft k tokens autoregressively, recording each draft distribution.
    drafts = []               # list of (token, p_draft over V)
    for _ in range(k):
        p_d = draft_model(state)
        t = int(np.random.choice(len(p_d), p=p_d))
        drafts.append((t, p_d))
        state.append(t)

    # 2. ONE target forward pass over prefix + drafts -> (k+1, V) distributions.
    target_probs = target_model.batch_next_token_dist(prefix, [d[0] for d in drafts])

    # 3. Modified rejection sampling, position by position.
    accepted = []
    for i, (t, p_d) in enumerate(drafts):
        p_t = target_probs[i]
        accept_prob = min(1.0, p_t[t] / max(p_d[t], 1e-12))
        if np.random.rand() < accept_prob:
            accepted.append(t)
        else:
            # Reject: replace this token with one drawn from (p_t - p_d)_+
            corrected = sample_from_residual(p_t, p_d)
            return out + accepted + [corrected]
    # All k accepted: one bonus token from the target's last distribution.
    fresh = np.random.choice(len(target_probs[-1]), p=target_probs[-1])
    return out + accepted + [fresh]

14. 2D convolution (forward + backward)

def conv2d_forward(x, w, stride=1, pad=0):
    """2-D cross-correlation (what DL frameworks call "convolution"), no bias.

    x: (B, C_in, H, W), w: (C_out, C_in, kH, kW). Zero-pads H and W by `pad` on
    each side, then slides the kernel with step `stride`.
    Returns a float64 array (B, C_out, H_out, W_out) with
    H_out = (H + 2*pad - kH) // stride + 1 (likewise for W).
    """
    if pad:
        x = np.pad(x, [(0, 0), (0, 0), (pad, pad), (pad, pad)])
    n_batch, _, height, width = x.shape
    n_out, _, kh, kw = w.shape
    out_h = (height - kh) // stride + 1
    out_w = (width - kw) // stride + 1
    result = np.zeros((n_batch, n_out, out_h, out_w))
    for row in range(out_h):
        top = row * stride
        for col in range(out_w):
            left = col * stride
            window = x[:, :, top:top + kh, left:left + kw]       # (B, C_in, kH, kW)
            result[:, :, row, col] = np.einsum('bchw,ochw->bo', window, w)
    return result

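For backward: dW is the cross-correlation of the (padded) input with dY, i.e. for each output position, multiply the input patch by dY and sum over batch and positions. dX is the full cross-correlation of dY with the kernel rotated 180° (equivalently, a true full convolution of dY with W, which is a transposed convolution). For stride > 1, insert stride − 1 zeros between the dY entries first. Practice writing both. Real implementations use im2col + GEMM.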

15. Train a tiny transformer end-to-end (PyTorch)

import torch, torch.nn as nn, torch.nn.functional as F

class TinyTransformer(nn.Module):
    """Tiny causal LM: token + learned position embeddings, L `Block`s under a
    causal mask, a final LayerNorm, and an output head tied to the token embedding.

    V: vocab size; d: width; h: attention heads; L: number of blocks;
    T: maximum context length (size of the position table).
    """
    def __init__(self, V, d=128, h=4, L=4, T=64):
        super().__init__()
        self.tok = nn.Embedding(V, d)
        self.pos = nn.Embedding(T, d)
        self.blocks = nn.ModuleList([Block(d, h) for _ in range(L)])
        self.norm = nn.LayerNorm(d)
        self.head = nn.Linear(d, V, bias=False)
        self.head.weight = self.tok.weight       # tied: head and embedding share one (V, d) matrix
        self.T = T

    def forward(self, idx):
        """idx: (B, T) token ids with T <= self.T -> logits (B, T, V)."""
        _, T = idx.shape
        if T > self.T:
            # Otherwise the position lookup fails with an opaque embedding index
            # error (a device-side assert on CUDA).
            raise IndexError(f"sequence length {T} exceeds max context T={self.T}")
        positions = torch.arange(T, device=idx.device)
        x = self.tok(idx) + self.pos(positions)                    # (B, T, d)
        # True above the diagonal = future keys that a query may not attend to
        mask = torch.triu(torch.ones(T, T, device=idx.device), diagonal=1).bool()
        for blk in self.blocks:
            x = blk(x, mask)
        return self.head(self.norm(x))

class Block(nn.Module):
    """Pre-LN transformer block: x + Attn(LN1(x)), then x + MLP(LN2(x)).

    d: model width; h: number of attention heads (nn.MultiheadAttention
    requires d divisible by h). The MLP expands to 4*d with GELU.
    """
    def __init__(self, d, h):
        super().__init__()
        self.ln1 = nn.LayerNorm(d); self.attn = nn.MultiheadAttention(d, h, batch_first=True)
        self.ln2 = nn.LayerNorm(d); self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))

    def forward(self, x, mask):
        """x: (B, T, d); mask: bool attn_mask, True = position may NOT be attended."""
        # Normalize once and reuse the same tensor as query, key and value:
        # saves two LayerNorm passes and lets MultiheadAttention use its
        # self-attention (packed projection) path.
        a = self.ln1(x)
        h, _ = self.attn(a, a, a, attn_mask=mask, need_weights=False)
        x = x + h
        return x + self.ff(self.ln2(x))

# Train loop (skeleton)
def train_step(model, batch, opt):
    """Run one optimizer step on token-level cross-entropy; return the loss as a float.

    batch: (inputs, targets), both (B, T). model(inputs) must return logits whose
    last dim is the vocabulary; logits and targets are flattened to (N, V) / (N,).
    """
    inputs, targets = batch
    logits = model(inputs)
    vocab = logits.size(-1)
    loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

Common interview gotchas

Numerical stability
Always do x - x.max() before exp in softmax. Always add + eps in division (especially in normalization layers and log). Interviewers will throw you a sequence of all-zeros to test this.
Shape bugs
State the shape of every tensor in a comment. Especially when you reshape for multi-head: (B, T, d) → (B, T, H, d_head) → (B, H, T, d_head). The transpose order matters.
Mask broadcasting
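The causal mask is (T, T) and attention scores are (B, H, T, T), so they broadcast. Per-example masks do not broadcast correctly as-is. A (B, T, T) mask would line its batch axis up with the head axis, so insert a head axis to make it (B, 1, T, T). A (B, T) key-padding mask becomes (B, 1, 1, T).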
Train/eval mode
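Dropout and BatchNorm behave differently in training vs. evaluation. LayerNorm and RMSNorm do not, since they use per-example statistics in both modes. Call model.eval() before inference. It matters most for KV-cache decoding, where keys/values computed with dropout still active would be reused for every later token.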