{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mini-project: BPE tokenizer from scratch\n",
    "**Goal:** build byte-pair encoding end to end: count pairs, merge, build a vocab, encode/decode.\n",
    "**Concepts:** tokenization, why vocab size matters, compression vs coverage. **Time:** ~2h. **Difficulty:** medium.\n",
    "\n",
    "**How to work:** do the TODOs in order. Each stage has a `check_stageN()` with asserts \u2014\n",
    "un-comment its call after implementing. Solutions live at the bottom behind the divider;\n",
    "no peeking until you've genuinely attempted each stage."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import re\n",
    "from collections import Counter\n",
    "\n",
    "CORPUS = (\n",
    "    \"the transformer model processes tokens in parallel. the tokenizer splits text into tokens. \"\n",
    "    \"low frequency words split into subword units while frequent words stay whole. \"\n",
    "    \"the lower the merge count the smaller the vocabulary and the longer the token sequences. \"\n",
    "    \"byte pair encoding repeatedly merges the most frequent adjacent pair of symbols. \"\n",
    "    \"the best tokenizer balances vocabulary size against sequence length for the training corpus. \"\n",
    ") * 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1 \u2014 words to symbol tuples\n",
    "BPE starts from words split into characters, with a special end-of-word marker `</w>` so\n",
    "\"low\" and \"lower\" don't share their final symbol. Build the initial word frequency table:\n",
    "each word maps to a tuple of symbols, and we count how often each word occurs."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def build_word_table(corpus):\n",
    "    \"\"\"Return Counter mapping symbol-tuples to counts.\n",
    "    'the' -> ('t','h','e','</w>') with its frequency.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage1():\n",
    "    table = build_word_table(\"the the cat\")\n",
    "    assert table[(\"t\", \"h\", \"e\", \"</w>\")] == 2\n",
    "    assert table[(\"c\", \"a\", \"t\", \"</w>\")] == 1\n",
    "    print(\"stage 1 ok\")\n",
    "# check_stage1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 2 \u2014 count adjacent pairs\n",
    "The heart of BPE: across the whole table, count every ADJACENT symbol pair, weighted by\n",
    "word frequency. The most frequent pair is the next merge."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def count_pairs(word_table):\n",
    "    \"\"\"Return Counter of adjacent symbol pairs weighted by word counts.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage2():\n",
    "    table = Counter({(\"a\", \"b\", \"c\"): 3, (\"b\", \"c\"): 2})\n",
    "    pairs = count_pairs(table)\n",
    "    assert pairs[(\"b\", \"c\")] == 5 and pairs[(\"a\", \"b\")] == 3\n",
    "    print(\"stage 2 ok\")\n",
    "# check_stage2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 3 \u2014 apply a merge\n",
    "Replace every occurrence of the chosen pair (x, y) with the fused symbol x+y in every word.\n",
    "('l','o','w','</w>') under merge ('l','o') becomes ('lo','w','</w>')."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def apply_merge(word_table, pair):\n",
    "    \"\"\"Return a NEW word table with `pair` fused everywhere.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage3():\n",
    "    table = Counter({(\"l\", \"o\", \"w\", \"</w>\"): 2})\n",
    "    merged = apply_merge(table, (\"l\", \"o\"))\n",
    "    assert merged[(\"lo\", \"w\", \"</w>\")] == 2\n",
    "    print(\"stage 3 ok\")\n",
    "# check_stage3()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 4 \u2014 the training loop + encode/decode\n",
    "Train: repeat (count pairs -> take argmax -> merge), recording the merge ORDER \u2014 that order\n",
    "IS the tokenizer. Encode a new word by replaying merges in order; decode by concatenating\n",
    "and stripping `</w>`. Round-trip must be lossless."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def train_bpe(corpus, num_merges):\n",
    "    \"\"\"Return the ordered list of merges.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def encode_word(word, merges):\n",
    "    \"\"\"Apply merges in training order to one word; return list of tokens.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def decode(tokens):\n",
    "    return \"\".join(tokens).replace(\"</w>\", \" \").strip()\n",
    "\n",
    "def check_stage4():\n",
    "    merges = train_bpe(CORPUS, 50)\n",
    "    tokens = encode_word(\"tokenizer\", merges)\n",
    "    assert decode(tokens) == \"tokenizer\"\n",
    "    assert len(tokens) < len(\"tokenizer\") + 1, \"merges should compress a frequent-ish word\"\n",
    "    print(\"stage 4 ok:\", tokens)\n",
    "# check_stage4()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 5 \u2014 measure the tradeoff\n",
    "Sweep num_merges in [0, 25, 50, 100, 200]: vocabulary size grows, average tokens-per-word\n",
    "falls. Print the table and notice the diminishing returns \u2014 this is THE tokenizer design\n",
    "tradeoff (bigger vocab = shorter sequences = more embedding parameters).\n",
    "\n",
    "**Stretch goals:** GPT-2-style regex pre-splitting (`re.findall(r\"\\w+|\\S\", text)`) so\n",
    "punctuation never fuses across words; byte-level fallback so ANY string round-trips."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# SOLUTIONS \u2014 no peeking until your attempt\n",
    "Re-write yours from memory tomorrow; recognition is not recall."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def solution_build_word_table(corpus):\n",
    "    table = Counter()\n",
    "    for word in corpus.split():\n",
    "        table[tuple(word) + (\"</w>\",)] += 1\n",
    "    return table\n",
    "\n",
    "def solution_count_pairs(word_table):\n",
    "    pairs = Counter()\n",
    "    for word, count in word_table.items():\n",
    "        for left, right in zip(word, word[1:]):\n",
    "            pairs[(left, right)] += count\n",
    "    return pairs\n",
    "\n",
    "def solution_apply_merge(word_table, pair):\n",
    "    fused = pair[0] + pair[1]\n",
    "    new_table = Counter()\n",
    "    for word, count in word_table.items():\n",
    "        symbols = []\n",
    "        i = 0\n",
    "        while i < len(word):\n",
    "            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:\n",
    "                symbols.append(fused)\n",
    "                i += 2\n",
    "            else:\n",
    "                symbols.append(word[i])\n",
    "                i += 1\n",
    "        new_table[tuple(symbols)] += count\n",
    "    return new_table\n",
    "\n",
    "def solution_train_bpe(corpus, num_merges):\n",
    "    table = solution_build_word_table(corpus)\n",
    "    merges = []\n",
    "    for _ in range(num_merges):\n",
    "        pairs = solution_count_pairs(table)\n",
    "        if not pairs:\n",
    "            break\n",
    "        best = max(pairs, key=pairs.get)\n",
    "        merges.append(best)\n",
    "        table = solution_apply_merge(table, best)\n",
    "    return merges\n",
    "\n",
    "def solution_encode_word(word, merges):\n",
    "    symbols = list(word) + [\"</w>\"]\n",
    "    for pair in merges:                    # replay in training order\n",
    "        i, out = 0, []\n",
    "        while i < len(symbols):\n",
    "            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:\n",
    "                out.append(symbols[i] + symbols[i + 1]); i += 2\n",
    "            else:\n",
    "                out.append(symbols[i]); i += 1\n",
    "        symbols = out\n",
    "    return symbols"
   ]
  },
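  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way the Stage 5 sweep could look, using the `solution_*` helpers above. Treat it as a sketch\n",
    "under stated assumptions: `sweep_tradeoff` is a name introduced here, and vocabulary size is counted\n",
    "as the base symbols (the corpus characters plus `</w>`) plus one fused symbol per learned merge.\n",
    "Other counting conventions are equally defensible."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def sweep_tradeoff(corpus, merge_counts=(0, 25, 50, 100, 200)):\n",
    "    \"\"\"Return (merges_learned, vocab_size, avg_tokens_per_word) rows. Hypothetical helper.\"\"\"\n",
    "    words = corpus.split()\n",
    "    # Assumption: base vocab = every character in the corpus plus the end-of-word marker.\n",
    "    base_symbols = {ch for word in words for ch in word} | {\"</w>\"}\n",
    "    rows = []\n",
    "    for n in merge_counts:\n",
    "        merges = solution_train_bpe(corpus, n)  # may stop early if no pairs remain\n",
    "        vocab = base_symbols | {left + right for left, right in merges}\n",
    "        total_tokens = sum(len(solution_encode_word(w, merges)) for w in words)\n",
    "        rows.append((len(merges), len(vocab), total_tokens / len(words)))\n",
    "    return rows\n",
    "\n",
    "for learned, vocab_size, avg_tokens in sweep_tradeoff(CORPUS):\n",
    "    print(f\"{learned:>4} merges | vocab {vocab_size:>4} | {avg_tokens:.2f} tokens/word\")"
   ]
  },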
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    merges = solution_train_bpe(CORPUS, 100)\n",
    "    sample = \"the tokenizer processes lower frequency tokens\"\n",
    "    total_chars = sum(len(w) for w in sample.split())\n",
    "    total_tokens = sum(len(solution_encode_word(w, merges)) for w in sample.split())\n",
    "    print(f\"chars {total_chars} -> tokens {total_tokens} \"\n",
    "          f\"(compression {total_chars / total_tokens:.2f} chars/token)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}