{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mini-project: the KV-cache lab \u2014 measure why inference needs it\n",
    "**Goal:** build a tiny GPT-style decoder in NumPy, generate text the naive way (recompute\n",
    "everything per token), then with a KV cache \u2014 assert identical outputs, then MEASURE the\n",
    "O(T^2) vs O(T) growth in total positions processed, and the cache's memory bill.\n",
    "**Concepts:** autoregressive decoding, attention complexity, why vLLM exists. **Time:** ~2-3h.\n",
    "\n",
    "**How to work:** do the TODOs in order, and uncomment each `check_stageN()` once that stage is implemented.\n",
    "Solutions at the bottom. The point is the MEASUREMENT at the end \u2014 get there."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "np.seterr(all=\"ignore\")  # macOS Accelerate emits spurious matmul warnings; results are finite (asserted in checks)\n",
    "D_MODEL, N_LAYERS, VOCAB = 64, 2, 50\n",
    "\n",
    "class TinyDecoder:\n",
    "    \"\"\"2-layer decoder: per layer, single-head causal attention + a 2-layer MLP.\n",
    "    Weights are random \u2014 we generate gibberish; the SYSTEMS behavior is the lesson.\"\"\"\n",
    "    def __init__(self):\n",
    "        scale = 0.1\n",
    "        self.embed = rng.normal(0, scale, (VOCAB, D_MODEL))\n",
    "        self.layers = []\n",
    "        for _ in range(N_LAYERS):\n",
    "            self.layers.append({\n",
    "                \"wq\": rng.normal(0, scale, (D_MODEL, D_MODEL)),\n",
    "                \"wk\": rng.normal(0, scale, (D_MODEL, D_MODEL)),\n",
    "                \"wv\": rng.normal(0, scale, (D_MODEL, D_MODEL)),\n",
    "                \"w1\": rng.normal(0, scale, (D_MODEL, 2 * D_MODEL)),\n",
    "                \"w2\": rng.normal(0, scale, (2 * D_MODEL, D_MODEL)),\n",
    "            })\n",
    "        self.unembed = rng.normal(0, scale, (D_MODEL, VOCAB))\n",
    "\n",
    "MODEL = TinyDecoder()\n",
    "\n",
    "def softmax_rows(x):\n",
    "    x = x - x.max(axis=-1, keepdims=True)\n",
    "    e = np.exp(x)\n",
    "    return e / e.sum(axis=-1, keepdims=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1 \u2014 the forward pass over a whole sequence\n",
    "Standard causal attention: for a sequence of n token ids, embed, then per layer compute\n",
    "Q, K, V (n x d each), scores QK^T/sqrt(d) with the causal mask (position i sees <= i),\n",
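    "softmax, weighted sum of V added back to the input as a residual, then the tanh MLP, also\n",
    "with a residual: `h = x + attn(x)`, `x_next = h + tanh(h @ w1) @ w2`. After the last layer,\n",
    "project with `unembed` and return logits for EVERY position.\n",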
    "This is what training and \"naive generation\" run."
   ]
  },
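  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of only the masking step, on random scores rather than the model so the exercise stays yours\n",
    "(`stable_softmax`, `scores` and `weights` are illustrative names). The boolean mask is True strictly above the\n",
    "diagonal, i.e. wherever column j is a future position for row i; those scores are overwritten with a large\n",
    "negative number before the softmax, so each row's weights beyond its own position come out as exactly zero."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def stable_softmax(x):\n",
    "    # same recipe as softmax_rows: subtract the row max before exponentiating\n",
    "    e = np.exp(x - x.max(axis=-1, keepdims=True))\n",
    "    return e / e.sum(axis=-1, keepdims=True)\n",
    "\n",
    "n = 4\n",
    "scores = np.random.default_rng(1).normal(size=(n, n))\n",
    "mask = np.triu(np.ones((n, n), dtype=bool), k=1)  # True where column j > row i, i.e. the future\n",
    "scores[mask] = -1e9                               # exp of this underflows to exactly 0.0\n",
    "weights = stable_softmax(scores)\n",
    "print(np.round(weights, 3))  # lower-triangular, each row sums to 1"
   ]
  },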
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def forward_full(model, token_ids):\n",
    "    \"\"\"token_ids: list[int]. Return logits, shape (n, VOCAB).\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage1():\n",
    "    logits = forward_full(MODEL, [1, 2, 3])\n",
    "    assert logits.shape == (3, VOCAB)\n",
    "    logits2 = forward_full(MODEL, [1, 2, 3, 4])\n",
    "    assert np.allclose(logits[0], logits2[0]), \"causality: past must not depend on future\"\n",
    "    print(\"stage 1 ok\")\n",
    "# check_stage1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 2 \u2014 naive generation\n",
    "Greedy decode T tokens the expensive way: each step, run forward_full over the ENTIRE\n",
    "sequence so far and take argmax of the LAST position's logits. Count, per step, how many\n",
    "token-positions you processed \u2014 that's the work meter we'll plot."
   ]
  },
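  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The work meter has a closed form to check your counts against. A sketch, assuming a prompt of P tokens and\n",
    "T generated tokens (P, T and the function names are illustrative): naive step i reprocesses P + i positions,\n",
    "so the total is `T*P + T*(T-1)/2`, quadratic in T; a cached decoder processes the prompt once and then one\n",
    "position per later step, `P + (T - 1)` in total, linear in T."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def naive_total_positions(P, T):\n",
    "    # naive step i runs the full forward pass over the P + i tokens seen so far\n",
    "    return sum(P + i for i in range(T))\n",
    "\n",
    "def naive_total_closed_form(P, T):\n",
    "    # arithmetic series: T steps of P, plus 0 + 1 + ... + (T - 1)\n",
    "    return T * P + T * (T - 1) // 2\n",
    "\n",
    "def cached_total_positions(P, T):\n",
    "    # prefill P positions once, then 1 position for each of the remaining T - 1 steps (assumes T >= 1)\n",
    "    return P + (T - 1)\n",
    "\n",
    "for T in (10, 100, 200):\n",
    "    print(f'T={T}: naive {naive_total_positions(4, T)} vs cached {cached_total_positions(4, T)} positions')"
   ]
  },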
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def generate_naive(model, prompt_ids, num_new):\n",
    "    \"\"\"Return (all_token_ids, work_per_step), where work_per_step[i] = positions processed at step i.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage2():\n",
    "    ids, work = generate_naive(MODEL, [1, 2], 5)\n",
    "    assert len(ids) == 7 and len(work) == 5\n",
    "    assert work[-1] > work[0], \"later steps reprocess longer sequences\"\n",
    "    print(\"stage 2 ok, work per step:\", work)\n",
    "# check_stage2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 3 \u2014 the KV cache\n",
    "The fix: K and V for past tokens never change (causality!), so store them per layer.\n",
    "Each new token: embed ONE token, compute its q,k,v per layer, APPEND k,v to the cache,\n",
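    "attend q against all cached K,V. Work per step = 1 position (its attention still reads all n\n",
    "cached rows, so per-step time grows with n, but only linearly). Outputs must be IDENTICAL."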
   ]
  },
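  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One storage detail worth knowing before you write the cache: appending with `np.vstack` (or `np.concatenate`)\n",
    "builds a new array every step, so the copying alone grows with sequence length. A minimal sketch of a common\n",
    "alternative, assuming the maximum length is known up front (`KVBuffer` and `max_len` are hypothetical names,\n",
    "not part of the exercise interface): preallocate once, write new rows in place, and hand out views of the\n",
    "filled prefix."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class KVBuffer:\n",
    "    # Hypothetical per-layer cache: preallocated K and V arrays plus a fill pointer.\n",
    "    def __init__(self, max_len, d_model):\n",
    "        self.k = np.zeros((max_len, d_model))\n",
    "        self.v = np.zeros((max_len, d_model))\n",
    "        self.length = 0\n",
    "\n",
    "    def append(self, k_new, v_new):\n",
    "        # k_new, v_new: (n_new, d_model) rows, written in place with no reallocation\n",
    "        end = self.length + k_new.shape[0]\n",
    "        if end > self.k.shape[0]:\n",
    "            raise ValueError('cache full: raise max_len or evict old tokens')\n",
    "        self.k[self.length:end] = k_new\n",
    "        self.v[self.length:end] = v_new\n",
    "        self.length = end\n",
    "\n",
    "    def view(self):\n",
    "        # basic slices are views, not copies, of the rows written so far\n",
    "        return self.k[:self.length], self.v[:self.length]\n",
    "\n",
    "buf = KVBuffer(max_len=8, d_model=4)\n",
    "buf.append(np.ones((3, 4)), 2 * np.ones((3, 4)))  # prefill: 3 positions at once\n",
    "buf.append(np.zeros((1, 4)), np.zeros((1, 4)))    # decode: 1 position\n",
    "k, v = buf.view()\n",
    "print(k.shape, v.shape)  # (4, 4) (4, 4)"
   ]
  },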
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def generate_cached(model, prompt_ids, num_new):\n",
    "    \"\"\"Same interface as generate_naive, but O(1) positions per step after prefill.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage3():\n",
    "    ids_naive, _ = generate_naive(MODEL, [1, 2, 3], 8)\n",
    "    ids_cached, work = generate_cached(MODEL, [1, 2, 3], 8)\n",
    "    assert ids_naive == ids_cached, \"cache must not change outputs!\"\n",
    "    assert all(w == 1 for w in work[1:]), \"each decode step should process one position\"\n",
    "    print(\"stage 3 ok \u2014 identical outputs\")\n",
    "# check_stage3()"
   ]
  },
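  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stage 3's claim that past K and V never change holds at every layer, not just the first, because of the causal\n",
    "mask: a layer's output at position i depends only on inputs at positions <= i (the MLP acts on each position\n",
    "separately), so appending a token leaves every earlier output, and therefore the next layer's earlier K and V\n",
    "rows, untouched. A self-contained check on one random attention layer, where `causal_block` and its weights\n",
    "are illustrative stand-ins rather than the toy model:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def causal_block(x, wq, wk, wv):\n",
    "    # single-head causal attention plus residual, like the attention half of one toy layer\n",
    "    q, k, v = x @ wq, x @ wk, x @ wv\n",
    "    n, d = x.shape\n",
    "    scores = q @ k.T / np.sqrt(d)\n",
    "    scores[np.triu(np.ones((n, n), dtype=bool), k=1)] = -1e9  # hide future positions\n",
    "    e = np.exp(scores - scores.max(axis=-1, keepdims=True))\n",
    "    return x + (e / e.sum(axis=-1, keepdims=True)) @ v\n",
    "\n",
    "demo_rng = np.random.default_rng(3)\n",
    "d = 16\n",
    "wq, wk, wv = (demo_rng.normal(0, 0.1, (d, d)) for _ in range(3))\n",
    "x_short = demo_rng.normal(size=(5, d))\n",
    "x_long = np.vstack([x_short, demo_rng.normal(size=(1, d))])  # same 5-token prefix plus one new token\n",
    "\n",
    "out_short = causal_block(x_short, wq, wk, wv)\n",
    "out_long = causal_block(x_long, wq, wk, wv)\n",
    "# later layers compute K, V row by row from these outputs, so equal past outputs mean equal past K, V\n",
    "print(np.allclose(out_long[:5], out_short))  # True"
   ]
  },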
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 4 \u2014 measure it\n",
    "Time both at T=200 new tokens, plot (or print) cumulative positions processed \u2014\n",
    "the naive curve is quadratic, the cached one linear. Then compute the cache's memory:\n",
    "2 (K,V) x layers x seq_len x d_model x 8 bytes (float64 here). Scale the FORMULA to a\n",
    "7B model (32 layers, 32 heads x 128 dim, fp16) at 4k context and feel the gigabytes.\n",
    "\n",
    "**Stretch:** sliding-window cache (keep the last W tokens) \u2014 measure the quality/memory\n",
    "trade-off (with random weights, quality means agreement with the full-cache tokens); add a batch\n",
    "dimension; paged allocation (fixed-size blocks + a block table)."
   ]
  },
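  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Stage 4 formula as a function, sketched with one generalization that is an assumption beyond the text:\n",
    "`n_kv_heads * head_dim` takes the place of d_model, which also covers models where several query heads share\n",
    "one K,V head (grouped-query attention) and `n_kv_heads` is smaller than the query-head count. In the 7B example\n",
    "every head keeps its own K and V, so `n_kv_heads=32`, `head_dim=128`. The function name is illustrative."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def kv_cache_bytes(n_layers, seq_len, n_kv_heads, head_dim, bytes_per_elem, batch=1):\n",
    "    # 2 tensors (K and V) per layer, one row of n_kv_heads * head_dim values per cached token\n",
    "    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem * batch\n",
    "\n",
    "# the toy model: 2 layers, one 64-dim head, float64, 4 prompt + 200 new tokens\n",
    "toy = kv_cache_bytes(n_layers=2, seq_len=204, n_kv_heads=1, head_dim=64, bytes_per_elem=8)\n",
    "# the 7B-class example: 32 layers, 32 heads x 128 dims, fp16, 4k context\n",
    "big = kv_cache_bytes(n_layers=32, seq_len=4096, n_kv_heads=32, head_dim=128, bytes_per_elem=2)\n",
    "print(f'toy: {toy / 1024:.0f} KB, 7B @4k: {big / 1e9:.2f} GB per sequence')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dividing `big` by the 4096 positions gives 524,288 bytes: half a MiB of cache for every token of every\n",
    "sequence, which is the figure that then multiplies with batch size."
   ]
  },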
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# SOLUTIONS \u2014 no peeking until your attempt"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def _attn_block(layer, x_seq):\n",
    "    q = x_seq @ layer[\"wq\"]; k = x_seq @ layer[\"wk\"]; v = x_seq @ layer[\"wv\"]\n",
    "    n = x_seq.shape[0]\n",
    "    scores = q @ k.T / np.sqrt(D_MODEL)\n",
    "    mask = np.triu(np.ones((n, n), dtype=bool), k=1)\n",
    "    scores[mask] = -1e9\n",
    "    attended = softmax_rows(scores) @ v\n",
    "    hidden = x_seq + attended\n",
    "    return hidden + np.tanh(hidden @ layer[\"w1\"]) @ layer[\"w2\"]\n",
    "\n",
    "def solution_forward_full(model, token_ids):\n",
    "    x = model.embed[np.array(token_ids)]\n",
    "    for layer in model.layers:\n",
    "        x = _attn_block(layer, x)\n",
    "    return x @ model.unembed\n",
    "\n",
    "def solution_generate_naive(model, prompt_ids, num_new):\n",
    "    ids = list(prompt_ids); work = []\n",
    "    for _ in range(num_new):\n",
    "        logits = solution_forward_full(model, ids)\n",
    "        ids.append(int(np.argmax(logits[-1])))\n",
    "        work.append(len(ids) - 1)\n",
    "    return ids, work\n",
    "\n",
    "def solution_generate_cached(model, prompt_ids, num_new):\n",
    "    caches = [{\"k\": np.zeros((0, D_MODEL)), \"v\": np.zeros((0, D_MODEL))}\n",
    "              for _ in model.layers]\n",
    "    ids = list(prompt_ids); work = []\n",
    "    last_logits = None\n",
    "\n",
    "    def run_tokens(token_ids):\n",
    "        nonlocal last_logits\n",
    "        x = model.embed[np.array(token_ids)]\n",
    "        for layer, cache in zip(model.layers, caches):\n",
    "            q = x @ layer[\"wq\"]; k = x @ layer[\"wk\"]; v = x @ layer[\"wv\"]\n",
    "            cache[\"k\"] = np.vstack([cache[\"k\"], k])\n",
    "            cache[\"v\"] = np.vstack([cache[\"v\"], v])\n",
    "            n_new, n_total = x.shape[0], cache[\"k\"].shape[0]\n",
    "            scores = q @ cache[\"k\"].T / np.sqrt(D_MODEL)\n",
    "            # causal mask within the new block only (past is all visible)\n",
    "            offset = n_total - n_new\n",
    "            for i in range(n_new):\n",
    "                scores[i, offset + i + 1:] = -1e9\n",
    "            attended = softmax_rows(scores) @ cache[\"v\"]\n",
    "            hidden = x + attended\n",
    "            x = hidden + np.tanh(hidden @ layer[\"w1\"]) @ layer[\"w2\"]\n",
    "        last_logits = (x @ model.unembed)[-1]\n",
    "\n",
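    "    run_tokens(ids)            # prefill: whole prompt in one parallel pass\n",
    "    work.append(len(ids))      # step 0 processed the prompt, as naive's first step does\n",
    "    for step in range(num_new):\n",
    "        ids.append(int(np.argmax(last_logits)))\n",
    "        if step < num_new - 1:  # the final token's K,V would never be read, so skip that pass\n",
    "            run_tokens([ids[-1]])  # decode: ONE token\n",
    "            work.append(1)\n",
    "    return ids, work[:num_new]  # one entry per generated token, like generate_naive"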
   ]
  },
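  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch for the sliding-window stretch goal, written for a single attention head and a single query so it\n",
    "runs on its own (`SlidingWindowKV`, `full_attend` and `W` are hypothetical names). Only the last W keys and\n",
    "values are kept, in a ring buffer, so memory stays at W rows however long generation runs; the price is that\n",
    "queries can no longer see anything older. Once the buffer wraps, rows sit out of time order, which is harmless:\n",
    "a softmax-weighted sum over (k, v) pairs does not depend on the order the pairs are stored in."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class SlidingWindowKV:\n",
    "    # Hypothetical single-head cache that keeps only the most recent W (k, v) rows.\n",
    "    def __init__(self, W, d):\n",
    "        self.k = np.zeros((W, d))\n",
    "        self.v = np.zeros((W, d))\n",
    "        self.W, self.count = W, 0\n",
    "\n",
    "    def append(self, k_row, v_row):\n",
    "        slot = self.count % self.W  # ring buffer: once full, overwrite the oldest row\n",
    "        self.k[slot] = k_row\n",
    "        self.v[slot] = v_row\n",
    "        self.count += 1\n",
    "\n",
    "    def attend(self, q):\n",
    "        n = min(self.count, self.W)  # rows actually filled so far\n",
    "        return full_attend(q, self.k[:n], self.v[:n])\n",
    "\n",
    "def full_attend(q, K, V):\n",
    "    # single-query softmax attention over every row of K and V\n",
    "    scores = K @ q / np.sqrt(q.shape[0])\n",
    "    w = np.exp(scores - scores.max())\n",
    "    return (w / w.sum()) @ V\n",
    "\n",
    "demo_rng = np.random.default_rng(42)\n",
    "d, W, steps = 8, 4, 10\n",
    "K = demo_rng.normal(size=(steps, d))\n",
    "V = demo_rng.normal(size=(steps, d))\n",
    "q = demo_rng.normal(size=d)\n",
    "cache = SlidingWindowKV(W, d)\n",
    "for t in range(steps):\n",
    "    cache.append(K[t], V[t])\n",
    "windowed = cache.attend(q)\n",
    "recent_only = full_attend(q, K[-W:], V[-W:])  # exact attention over the last W tokens\n",
    "everything = full_attend(q, K, V)             # what an unbounded cache would give\n",
    "print(np.allclose(windowed, recent_only), np.abs(windowed - everything).max())"
   ]
  },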
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    prompt, T = [1, 2, 3, 4], 200\n",
    "    t0 = time.perf_counter(); solution_generate_naive(MODEL, prompt, T)\n",
    "    naive_s = time.perf_counter() - t0\n",
    "    t0 = time.perf_counter(); solution_generate_cached(MODEL, prompt, T)\n",
    "    cached_s = time.perf_counter() - t0\n",
    "    print(f\"naive {naive_s*1000:.0f}ms vs cached {cached_s*1000:.0f}ms \"\n",
    "          f\"-> speedup {naive_s/cached_s:.1f}x at T={T}\")\n",
    "    toy_bytes = 2 * N_LAYERS * (len(prompt) + T) * D_MODEL * 8\n",
    "    print(f\"toy cache: {toy_bytes/1024:.0f} KB\")\n",
    "    real = 2 * 32 * 4096 * (32 * 128) * 2\n",
    "    print(f\"7B-class model @4k ctx, fp16: {real/1e9:.1f} GB per sequence -> \"\n",
    "          f\"x64 batch = {64*real/1e9:.0f} GB. That is why KV memory rules serving.\")"
   ]
  }
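  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch for the paged-allocation stretch goal, again for one layer and one head (`PagedKVPool`, `block_size`\n",
    "and this block-table layout are illustrative assumptions, loosely in the spirit of vLLM's PagedAttention rather\n",
    "than its actual code). K and V for all sequences live in one shared pool of fixed-size blocks; each sequence\n",
    "owns a block table listing, in order, which physical blocks hold its tokens. Memory is claimed one block at a\n",
    "time as a sequence grows, instead of reserving room for the longest possible sequence up front, and a finished\n",
    "sequence returns its blocks to the free list."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class PagedKVPool:\n",
    "    # Hypothetical shared pool of num_blocks blocks, each holding block_size rows of K and of V.\n",
    "    def __init__(self, num_blocks, block_size, d):\n",
    "        self.k = np.zeros((num_blocks, block_size, d))\n",
    "        self.v = np.zeros((num_blocks, block_size, d))\n",
    "        self.block_size = block_size\n",
    "        self.free = list(range(num_blocks))  # physical block ids not owned by any sequence\n",
    "        self.tables = {}                     # seq_id -> its block table (physical ids, in order)\n",
    "        self.lengths = {}                    # seq_id -> number of tokens stored\n",
    "\n",
    "    def append(self, seq_id, k_row, v_row):\n",
    "        table = self.tables.setdefault(seq_id, [])\n",
    "        n = self.lengths.get(seq_id, 0)\n",
    "        if n % self.block_size == 0:  # last block full, or no block yet: claim one\n",
    "            if not self.free:\n",
    "                raise RuntimeError('KV pool exhausted')\n",
    "            table.append(self.free.pop())\n",
    "        block, offset = table[n // self.block_size], n % self.block_size\n",
    "        self.k[block, offset] = k_row\n",
    "        self.v[block, offset] = v_row\n",
    "        self.lengths[seq_id] = n + 1\n",
    "\n",
    "    def gather(self, seq_id):\n",
    "        # copy this sequence's rows out of its blocks, in logical token order\n",
    "        table, n = self.tables[seq_id], self.lengths[seq_id]\n",
    "        d = self.k.shape[-1]\n",
    "        return (self.k[table].reshape(-1, d)[:n],\n",
    "                self.v[table].reshape(-1, d)[:n])\n",
    "\n",
    "    def release(self, seq_id):\n",
    "        # a finished sequence hands its blocks back to the free list\n",
    "        self.free.extend(self.tables.pop(seq_id))\n",
    "        del self.lengths[seq_id]\n",
    "\n",
    "pool = PagedKVPool(num_blocks=6, block_size=4, d=3)\n",
    "demo_rng = np.random.default_rng(7)\n",
    "rows_a = demo_rng.normal(size=(10, 3))  # sequence a: 10 tokens -> 3 blocks\n",
    "rows_b = demo_rng.normal(size=(5, 3))   # sequence b: 5 tokens -> 2 blocks\n",
    "for r in rows_a:\n",
    "    pool.append('a', r, 2 * r)\n",
    "for r in rows_b:\n",
    "    pool.append('b', r, 2 * r)\n",
    "print(pool.tables, len(pool.free))  # {'a': [5, 4, 3], 'b': [2, 1]} 1"
   ]
  }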
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}