{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transformer Block from Scratch\n",
    "**Goal**: Implement scaled dot-product attention in NumPy, add causal masking,\n",
    "multi-head reshaping, LayerNorm + residual + MLP, then train a tiny char-LM in PyTorch.\n",
    "\n",
    "**Concepts exercised**: attention mechanism, causal masking, multi-head attention,\n",
    "layer normalisation, residual connections, character-level language modelling.\n",
    "\n",
    "**Estimated time**: 2-4 hours  |  **Difficulty**: Medium"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to work through this notebook\n",
    "\n",
    "1. Work through the TODO stages in order.\n",
    "2. After each TODO, uncomment and run the corresponding `check_stageN()` call at the bottom of its check cell.\n",
    "3. Only consult the SOLUTIONS section after a genuine attempt."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import math\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "RNG = np.random.default_rng(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# --- Tiny text corpus embedded as a constant ---\n",
    "CORPUS = (\n",
    "    \"the quick brown fox jumps over the lazy dog \"\n",
    "    \"pack my box with five dozen liquor jugs \"\n",
    "    \"how vexingly quick daft zebras jump \"\n",
    "    \"the five boxing wizards jump quickly \"\n",
    "    \"sphinx of black quartz judge my vow \"\n",
    "    \"two driven jocks help fax my big quiz \"\n",
    "    \"five quacking zephyrs jolt my wax bed \"\n",
    "    \"the jay pig fox zebra and my wolves quack \"\n",
    "    \"blowzy red vixens fight for a quick jump \"\n",
    "    \"glib jocks quiz nymph to vex dwarf \"\n",
    "    \"jackdaws love my big sphinx of quartz \"\n",
    "    \"the quick onyx goblin jumps over the lazy dwarf \"\n",
    "    \"how quickly daft jumping zebras vex \"\n",
    "    \"bright vixens jump dozy fowl quack \"\n",
    "    \"quick wafting zephyrs vex bold jim \"\n",
    "    \"quick zephyrs blow vexing daft jim \"\n",
    "    \"sex prof blew my junk tv quiz \"\n",
    "    \"bawds jog flick quartz vex nymph \"\n",
    "    \"waltz nymph for quick jigs vex bud \"\n",
    "    \"mr jock tv quiz phd bags few lynx \"\n",
    ")\n",
    "\n",
    "chars = sorted(set(CORPUS))\n",
    "vocab_size = len(chars)\n",
    "ch2idx = {c: i for i, c in enumerate(chars)}\n",
    "idx2ch = {i: c for c, i in ch2idx.items()}\n",
    "data = np.array([ch2idx[c] for c in CORPUS], dtype=np.int64)\n",
    "\n",
    "SEQ_LEN = 16\n",
    "D_MODEL = 32\n",
    "N_HEADS = 4\n",
    "D_HEAD = D_MODEL // N_HEADS\n",
    "\n",
    "print(f\"Vocab size: {vocab_size}, Corpus length: {len(data)}, D_MODEL: {D_MODEL}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1 \u2014 Scaled Dot-Product Attention (NumPy)\n",
    "\n",
    "Attention computes a weighted average of values, where the weights come from the\n",
    "compatibility between queries and keys. The formula is:\n",
    "\n",
    "`Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V`\n",
    "\n",
    "where Q, K, V are (seq_len, d_k) matrices. The scaling by sqrt(d_k) keeps the\n",
    "dot products from growing in magnitude as d_k grows, which would push softmax into\n",
    "saturated regions with very small gradients.\n",
    "\n",
    "Shape walkthrough on a 4-token example with d_k=8:\n",
    "\n",
    "- `Q`: (4, 8), `K`: (4, 8), `V`: (4, 8)\n",
    "- `Q @ K.T`: (4, 4), one score per (query, key) pair\n",
    "- softmax per row: (4, 4), attention weights, each row sums to 1\n",
    "- `weights @ V`: (4, 8), weighted sum of values"
   ]
  },
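  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick numerical check of the scaling claim, offered as a sketch rather than as part of\n",
    "the exercise. It assumes the entries of each query and key are independent standard\n",
    "normals, in which case an unscaled dot product has standard deviation close to sqrt(d_k).\n",
    "The helper name `dot_product_std` is made up for this illustration. The raw standard\n",
    "deviation should grow with d_k while the scaled one stays near 1."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def dot_product_std(d_k, n_samples=5000, seed=0):\n",
    "    # Empirical std of q.k before and after dividing by sqrt(d_k).\n",
    "    # Assumption: q and k have independent standard-normal entries.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    q = rng.standard_normal((n_samples, d_k))\n",
    "    k = rng.standard_normal((n_samples, d_k))\n",
    "    raw = (q * k).sum(axis=-1)        # one dot product per sample\n",
    "    scaled = raw / np.sqrt(d_k)\n",
    "    return raw.std(), scaled.std()\n",
    "\n",
    "for d_k in (8, 64, 512):\n",
    "    raw_std, scaled_std = dot_product_std(d_k)\n",
    "    print(f\"d_k={d_k:4d}  raw std={raw_std:6.2f}  scaled std={scaled_std:5.2f}\")"
   ]
  },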
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def scaled_dot_product_attention(Q, K, V, mask=None):\n",
    "    \"\"\"\n",
    "    Q: (T, d_k) or (B, H, T, d_k)\n",
    "    K: same shape as Q\n",
    "    V: same shape as Q\n",
    "    mask: optional boolean array, True = BLOCK this position (score set to -1e9 before softmax)\n",
    "    Returns: (output, weights). output has the same shape as Q; weights has last two dims T x T.\n",
    "\n",
    "    # TODO(you):\n",
    "    # 1. Compute raw scores = Q @ K^T / sqrt(d_k), transposing only the last two axes of K\n",
    "    #    (K.T reverses ALL axes and breaks 4-D input; K.swapaxes(-2, -1) works for both shapes)\n",
    "    # 2. If mask provided, set masked positions to -1e9\n",
    "    # 3. Softmax over last axis\n",
    "    # 4. Output = weights @ V\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage1():\n",
    "    T, dk = 4, 8\n",
    "    rng = np.random.default_rng(0)\n",
    "    Q = rng.standard_normal((T, dk)).astype(np.float32)\n",
    "    K = rng.standard_normal((T, dk)).astype(np.float32)\n",
    "    V = rng.standard_normal((T, dk)).astype(np.float32)\n",
    "    out, weights = scaled_dot_product_attention(Q, K, V)\n",
    "    assert out.shape == (T, dk), f\"Output shape wrong: {out.shape}\"\n",
    "    assert weights.shape == (T, T), f\"Weights shape wrong: {weights.shape}\"\n",
    "    assert np.allclose(weights.sum(axis=-1), 1.0, atol=1e-5), \"Attention weights must sum to 1 per row\"\n",
    "    # Each output row must be the attention-weighted average of the rows of V\n",
    "    assert np.allclose(out, weights @ V, atol=1e-5), \"Output must equal weights @ V\"\n",
    "    print(\"Stage 1 passed.\")\n",
    "\n",
    "# check_stage1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 2 \u2014 Causal Masking\n",
    "\n",
    "In a decoder (auto-regressive) model, position i may only attend to positions <= i.\n",
    "We enforce this with a causal mask: every score above the diagonal (j > i) is set to\n",
    "-infinity (in practice a large negative number such as -1e9) before softmax, so the\n",
    "allowed positions form a lower triangle.\n",
    "\n",
    "Why does this prevent leakage? After softmax, the blocked positions get weight 0 (the\n",
    "exponential underflows), so the output at position i carries no information from\n",
    "future tokens.\n",
    "\n",
    "An assert can verify this: run attention twice, the second time with the VALUES at\n",
    "positions 3 and later changed drastically. With the causal mask, the outputs at\n",
    "positions 0-2 should NOT change."
   ]
  },
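  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a large negative score is enough, the sketch below applies a row-wise softmax\n",
    "to a hand-written 3x3 score matrix. The scores, the `blocked` array and the helper\n",
    "`softmax_rows` are illustrative assumptions, and -1e9 stands in for -infinity. Row 0 ends\n",
    "up with all of its weight on position 0, and every blocked entry is exactly 0."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax_rows(x):\n",
    "    # Subtract the row max first so np.exp never overflows\n",
    "    x = x - x.max(axis=-1, keepdims=True)\n",
    "    e = np.exp(x)\n",
    "    return e / e.sum(axis=-1, keepdims=True)\n",
    "\n",
    "scores = np.array([[1.0, 2.0, 3.0],\n",
    "                   [1.0, 2.0, 3.0],\n",
    "                   [1.0, 2.0, 3.0]])\n",
    "blocked = np.array([[False, True,  True],\n",
    "                    [False, False, True],\n",
    "                    [False, False, False]])   # True = future position\n",
    "\n",
    "masked_scores = np.where(blocked, -1e9, scores)\n",
    "weights = softmax_rows(masked_scores)\n",
    "print(np.round(weights, 3))"
   ]
  },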
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def make_causal_mask(T):\n",
    "    \"\"\"\n",
    "    Return (T, T) boolean mask where mask[i, j] = True means position j is BLOCKED\n",
    "    for query position i (i.e. j > i).\n",
    "    # TODO(you): np.triu with k=1 gives upper triangle excluding diagonal.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage2():\n",
    "    T = 6\n",
    "    mask = make_causal_mask(T)\n",
    "    assert mask.shape == (T, T)\n",
    "    # Lower triangle (including diagonal) must be False (allowed)\n",
    "    for i in range(T):\n",
    "        for j in range(T):\n",
    "            expected = j > i\n",
    "            assert mask[i, j] == expected, f\"mask[{i},{j}] should be {expected}\"\n",
    "    # Verify no information leak from future\n",
    "    rng = np.random.default_rng(1)\n",
    "    dk = 8\n",
    "    Q = rng.standard_normal((T, dk)).astype(np.float32)\n",
    "    K = rng.standard_normal((T, dk)).astype(np.float32)\n",
    "    V = rng.standard_normal((T, dk)).astype(np.float32)\n",
    "    out_base, _ = scaled_dot_product_attention(Q, K, V, mask=mask)\n",
    "    V2 = V.copy()\n",
    "    V2[3:] += 999.0   # drastically change the values at positions 3 and later\n",
    "    out_mod, _ = scaled_dot_product_attention(Q, K, V2, mask=mask)\n",
    "    # Positions 0-2 can only see positions <= 2, so their outputs must be unchanged\n",
    "    assert np.allclose(out_base[:3], out_mod[:3], atol=1e-4), \\\n",
    "        \"Outputs before position 3 must not change when later values change (causal mask leak!)\"\n",
    "    print(\"Stage 2 passed. Causal mask verified, no information leak.\")\n",
    "\n",
    "# check_stage2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 3 \u2014 Multi-Head Attention Reshaping\n",
    "\n",
    "Multi-head attention runs H independent attention heads in parallel. Each head operates\n",
    "on a d_head = d_model / H dimensional subspace. The trick is that we project Q, K, V\n",
    "once with large weight matrices (d_model x d_model), then RESHAPE the result to\n",
    "separate the heads \u2014 no separate projection per head is needed.\n",
    "\n",
    "Reshape: (B, T, d_model) -> (B, T, H, d_head) by splitting the last dimension, then\n",
    "transpose to (B, H, T, d_head) so that each head gets its own (T, d_head) matrix.\n",
    "After attention: transpose back and merge H and d_head, (B, H, T, d_head) -> (B, T, d_model).\n",
    "This reshape-then-attend trick is equivalent to H separate attention operations."
   ]
  },
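  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below illustrates what splitting the last dimension means, using toy sizes\n",
    "chosen only for this demo. It checks that head h of the reshaped array is simply a\n",
    "contiguous slice of d_head feature columns, which is why one large projection followed\n",
    "by a reshape behaves like H small projections. It deliberately stops short of the\n",
    "transpose needed for `split_heads`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "B, T, H, d_head = 2, 3, 4, 5          # toy sizes, assumptions for this demo only\n",
    "d_model = H * d_head\n",
    "x = np.arange(B * T * d_model, dtype=np.float32).reshape(B, T, d_model)\n",
    "\n",
    "# Split the feature axis into (H, d_head); the element order is unchanged\n",
    "x_heads = x.reshape(B, T, H, d_head)\n",
    "\n",
    "for h in range(H):\n",
    "    column_slice = x[:, :, h * d_head:(h + 1) * d_head]   # (B, T, d_head)\n",
    "    assert np.array_equal(x_heads[:, :, h, :], column_slice)\n",
    "\n",
    "print(\"head h == feature columns [h*d_head, (h+1)*d_head)\")"
   ]
  },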
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def split_heads(x, n_heads):\n",
    "    \"\"\"\n",
    "    x: (B, T, d_model) numpy array\n",
    "    Returns: (B, n_heads, T, d_model // n_heads)\n",
    "    # TODO(you): reshape last dim, then transpose axes so heads dim is second.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "def merge_heads(x):\n",
    "    \"\"\"\n",
    "    x: (B, n_heads, T, d_head) numpy array\n",
    "    Returns: (B, T, n_heads * d_head)\n",
    "    # TODO(you): transpose back, then reshape to merge last two dims.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage3():\n",
    "    B, T, H, d_model = 2, 6, 4, 32\n",
    "    x = RNG.standard_normal((B, T, d_model)).astype(np.float32)\n",
    "    split = split_heads(x, H)\n",
    "    assert split.shape == (B, H, T, d_model // H), f\"split_heads shape wrong: {split.shape}\"\n",
    "    merged = merge_heads(split)\n",
    "    assert merged.shape == (B, T, d_model), f\"merge_heads shape wrong: {merged.shape}\"\n",
    "    assert np.allclose(x, merged, atol=1e-6), \"split then merge must recover original\"\n",
    "    print(\"Stage 3 passed.\")\n",
    "\n",
    "# check_stage3()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 4 \u2014 LayerNorm + Residual + MLP Block\n",
    "\n",
    "A transformer block wraps attention and the MLP in two pre-norm residual steps:\n",
    "\n",
    "1. `x = x + Attention(LayerNorm(x))`\n",
    "2. `x = x + MLP(LayerNorm(x))`\n",
    "\n",
    "LayerNorm normalises across the feature dimension (not the batch) so that each\n",
    "position's embedding has mean 0 and variance 1, then applies a learnable scale (gamma)\n",
    "and shift (beta). Because the statistics are computed per position, the result does not\n",
    "depend on batch size, and it keeps activation scales stable during training.\n",
    "\n",
    "The MLP is two linear layers with a GELU (or ReLU) in between, expanding to 4*d_model\n",
    "then back. Residual connections let gradients flow directly to earlier layers.\n",
    "\n",
    "Implement LayerNorm in NumPy (no PyTorch), then write the full block in PyTorch."
   ]
  },
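  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A worked example on a single 4-dimensional vector, using PyTorch's `F.layer_norm` as a\n",
    "reference so that it does not give away the NumPy TODO. The input values are arbitrary.\n",
    "By hand, the mean is 2.5 and the biased variance (divide by d, not d - 1) is 1.25, which\n",
    "matches NumPy's default `var`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])     # one position, d_model = 4\n",
    "\n",
    "# No weight/bias passed, so this is the plain normalisation (gamma=1, beta=0)\n",
    "out = F.layer_norm(x, normalized_shape=(4,), eps=1e-5)\n",
    "\n",
    "print(out)\n",
    "print(\"mean:\", out.mean(dim=-1).item())\n",
    "print(\"var :\", out.var(dim=-1, unbiased=False).item())"
   ]
  },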
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def layer_norm_np(x, gamma, beta, eps=1e-5):\n",
    "    \"\"\"\n",
    "    x: (..., d_model) numpy\n",
    "    gamma, beta: (d_model,) numpy\n",
    "    Normalise last dim, scale and shift.\n",
    "    # TODO(you): mean and var over last axis (keepdims=True), then normalise.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage4():\n",
    "    d = 16\n",
    "    x = RNG.standard_normal((3, 5, d)).astype(np.float32)\n",
    "    gamma = np.ones(d, dtype=np.float32)\n",
    "    beta = np.zeros(d, dtype=np.float32)\n",
    "    out = layer_norm_np(x, gamma, beta)\n",
    "    assert out.shape == x.shape\n",
    "    means = out.mean(axis=-1)\n",
    "    vars_ = out.var(axis=-1)\n",
    "    assert np.allclose(means, 0.0, atol=1e-5), f\"Mean not 0: {means.max()}\"\n",
    "    assert np.allclose(vars_, 1.0, atol=1e-4), f\"Var not 1: {vars_.min()}\"\n",
    "    # Non-trivial gamma/beta\n",
    "    gamma2 = np.full(d, 2.0, dtype=np.float32)\n",
    "    beta2 = np.full(d, -1.0, dtype=np.float32)\n",
    "    out2 = layer_norm_np(x, gamma2, beta2)\n",
    "    assert np.allclose(out2.mean(axis=-1), -1.0, atol=1e-4)\n",
    "    print(\"Stage 4 passed.\")\n",
    "\n",
    "# check_stage4()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 5 \u2014 Tiny Char-LM in PyTorch\n",
    "\n",
    "Now tie everything together. Build a single-block transformer language model over\n",
    "characters using PyTorch's built-in MultiheadAttention (or your own if you prefer).\n",
    "Architecture:\n",
    "\n",
    "1. `Embedding(vocab_size, D_MODEL)` plus a learned positional embedding of shape `SEQ_LEN x D_MODEL`\n",
    "2. One transformer block (`N_HEADS` heads, causal mask)\n",
    "3. `Linear(D_MODEL, vocab_size)`\n",
    "\n",
    "Train with cross-entropy, AdamW lr=3e-3, 300 steps, batch=32.\n",
    "After training, implement a greedy sampler and generate 80 characters."
   ]
  },
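  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sanity check that can be handy before training: a predictor that assigns equal\n",
    "probability to every character has cross-entropy ln(vocab_size). The sketch assumes a\n",
    "vocabulary of 27 symbols purely for illustration (`V_demo` is a made-up name); substitute\n",
    "the notebook's `vocab_size`. An untrained model should start somewhere in the\n",
    "neighbourhood of this value, and the training loss should then fall below it."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "V_demo = 27                                   # assumed vocabulary size for this demo\n",
    "N = 64                                        # number of (position, target) pairs\n",
    "logits = torch.zeros(N, V_demo)               # all-equal logits = uniform prediction\n",
    "targets = torch.randint(0, V_demo, (N,))\n",
    "\n",
    "uniform_loss = F.cross_entropy(logits, targets).item()\n",
    "print(f\"uniform loss = {uniform_loss:.4f}, ln(V) = {math.log(V_demo):.4f}\")"
   ]
  },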
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "class TinyCharLM(nn.Module):\n",
    "    def __init__(self, vocab_size, d_model, n_heads, seq_len):\n",
    "        super().__init__()\n",
    "        self.tok_emb = nn.Embedding(vocab_size, d_model)\n",
    "        self.pos_emb = nn.Embedding(seq_len, d_model)\n",
    "        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)\n",
    "        self.ln1 = nn.LayerNorm(d_model)\n",
    "        self.ln2 = nn.LayerNorm(d_model)\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(d_model, 4 * d_model),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(4 * d_model, d_model),\n",
    "        )\n",
    "        self.head = nn.Linear(d_model, vocab_size)\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "    def forward(self, idx):\n",
    "        # idx: (B, T)\n",
    "        # TODO(you):\n",
    "        # 1. tok_emb(idx) + pos_emb(arange(T))\n",
    "        # 2. Build causal mask for nn.MultiheadAttention (attn_mask, additive, shape T x T)\n",
    "        # 3. LayerNorm -> self-attention with causal mask -> residual\n",
    "        # 4. LayerNorm -> MLP -> residual\n",
    "        # 5. Return self.head(x)  -- (B, T, vocab_size)\n",
    "        raise NotImplementedError\n",
    "\n",
    "def make_batch(data, seq_len, batch_size, rng):\n",
    "    \"\"\"Sample random (x, y) pairs where y is x shifted by 1.\"\"\"\n",
    "    starts = rng.integers(0, len(data) - seq_len - 1, size=batch_size)\n",
    "    x = np.stack([data[s:s + seq_len] for s in starts])\n",
    "    y = np.stack([data[s + 1:s + seq_len + 1] for s in starts])\n",
    "    return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage5():\n",
    "    model = TinyCharLM(vocab_size, D_MODEL, N_HEADS, SEQ_LEN)\n",
    "    x, _ = make_batch(data, SEQ_LEN, batch_size=4, rng=RNG)\n",
    "    logits = model(x)\n",
    "    assert logits.shape == (4, SEQ_LEN, vocab_size), f\"Bad logits shape: {logits.shape}\"\n",
    "    print(\"Stage 5 passed.\")\n",
    "\n",
    "# check_stage5()"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Training cell (run after stage 5 passes)\n",
    "def train_char_lm(n_steps=300, batch_size=32, lr=3e-3):\n",
    "    model = TinyCharLM(vocab_size, D_MODEL, N_HEADS, SEQ_LEN)\n",
    "    opt = optim.AdamW(model.parameters(), lr=lr)\n",
    "    rng = np.random.default_rng(7)\n",
    "    losses = []\n",
    "    for step in range(n_steps):\n",
    "        x, y = make_batch(data, SEQ_LEN, batch_size, rng)\n",
    "        logits = model(x)   # (B, T, V)\n",
    "        loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        losses.append(loss.item())\n",
    "        if (step + 1) % 100 == 0:\n",
    "            print(f\"  step {step+1:4d}  loss={loss.item():.4f}\")\n",
    "    return model, losses\n",
    "\n",
    "# model, losses = train_char_lm()\n",
    "\n",
    "def greedy_sample(model, prompt, n_chars=80):\n",
    "    model.eval()\n",
    "    idx = [ch2idx.get(c, 0) for c in prompt]\n",
    "    result = list(prompt)\n",
    "    with torch.no_grad():\n",
    "        for _ in range(n_chars):\n",
    "            inp = torch.tensor([idx[-SEQ_LEN:]], dtype=torch.long)\n",
    "            logits = model(inp)   # (1, T, V)\n",
    "            next_id = logits[0, -1].argmax().item()\n",
    "            result.append(idx2ch[next_id])\n",
    "            idx.append(next_id)\n",
    "    return \"\".join(result)\n",
    "\n",
    "# print(greedy_sample(model, \"the \"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stretch Goals\n",
    "\n",
    "1. Replace greedy sampling with temperature sampling and nucleus (top-p) sampling.\n",
    "   Compare diversity vs coherence at different temperatures.\n",
    "2. Replace the learned positional embedding with fixed sinusoidal encodings and\n",
    "   compare training speed and final loss.\n",
    "3. Stack two transformer blocks and observe whether loss improves. Add dropout.\n",
    "4. Visualise the attention weights on a sample sequence as a heatmap.\n",
    "5. Implement rotary position embeddings (RoPE), a relative positional scheme, and try\n",
    "   them out with the NumPy attention from stages 1-2."
   ]
  },
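  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For stretch goal 1, the sketch below shows only the effect of temperature on a\n",
    "distribution, not a full sampler. The logits and the helper name `temperature_probs`\n",
    "are made up for illustration. Dividing the logits by a temperature below 1 sharpens the\n",
    "distribution towards the argmax, while a temperature above 1 flattens it."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def temperature_probs(logits, temperature):\n",
    "    # Softmax of logits / temperature, computed stably\n",
    "    z = np.asarray(logits, dtype=np.float64) / temperature\n",
    "    z = z - z.max()\n",
    "    p = np.exp(z)\n",
    "    return p / p.sum()\n",
    "\n",
    "demo_logits = [2.0, 1.0, 0.0]                 # illustrative values\n",
    "for t in (0.5, 1.0, 2.0):\n",
    "    print(t, np.round(temperature_probs(demo_logits, t), 3))\n",
    "\n",
    "# Drawing one token id from the tempered distribution\n",
    "rng = np.random.default_rng(0)\n",
    "token_id = rng.choice(len(demo_logits), p=temperature_probs(demo_logits, 1.0))"
   ]
  },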
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# SOLUTIONS -- no peeking until your attempt\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 1 \u2014 scaled dot-product attention\n",
    "def solution_scaled_dot_product_attention(Q, K, V, mask=None):\n",
    "    d_k = Q.shape[-1]\n",
    "    scores = Q @ K.swapaxes(-2, -1) / math.sqrt(d_k)   # (..., T, T)\n",
    "    if mask is not None:\n",
    "        scores = np.where(mask, -1e9, scores)\n",
    "    def softmax(x):\n",
    "        x = x - x.max(axis=-1, keepdims=True)\n",
    "        e = np.exp(x)\n",
    "        return e / e.sum(axis=-1, keepdims=True)\n",
    "    weights = softmax(scores)\n",
    "    return weights @ V, weights"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 2 \u2014 causal mask\n",
    "def solution_make_causal_mask(T):\n",
    "    return np.triu(np.ones((T, T), dtype=bool), k=1)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 3 \u2014 split / merge heads\n",
    "def solution_split_heads(x, n_heads):\n",
    "    B, T, d = x.shape\n",
    "    d_head = d // n_heads\n",
    "    x = x.reshape(B, T, n_heads, d_head)\n",
    "    return x.transpose(0, 2, 1, 3)  # (B, H, T, d_head)\n",
    "\n",
    "def solution_merge_heads(x):\n",
    "    B, H, T, d_head = x.shape\n",
    "    x = x.transpose(0, 2, 1, 3)     # (B, T, H, d_head)\n",
    "    return x.reshape(B, T, H * d_head)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 4 \u2014 LayerNorm in NumPy\n",
    "def solution_layer_norm_np(x, gamma, beta, eps=1e-5):\n",
    "    mean = x.mean(axis=-1, keepdims=True)\n",
    "    var = x.var(axis=-1, keepdims=True)\n",
    "    x_norm = (x - mean) / np.sqrt(var + eps)\n",
    "    return gamma * x_norm + beta"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 5 \u2014 TinyCharLM forward\n",
    "class SolutionTinyCharLM(nn.Module):\n",
    "    def __init__(self, vocab_size, d_model, n_heads, seq_len):\n",
    "        super().__init__()\n",
    "        self.tok_emb = nn.Embedding(vocab_size, d_model)\n",
    "        self.pos_emb = nn.Embedding(seq_len, d_model)\n",
    "        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)\n",
    "        self.ln1 = nn.LayerNorm(d_model)\n",
    "        self.ln2 = nn.LayerNorm(d_model)\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(d_model, 4 * d_model),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(4 * d_model, d_model),\n",
    "        )\n",
    "        self.head = nn.Linear(d_model, vocab_size)\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "    def forward(self, idx):\n",
    "        B, T = idx.shape\n",
    "        pos = torch.arange(T, device=idx.device)\n",
    "        x = self.tok_emb(idx) + self.pos_emb(pos)\n",
    "        # additive causal mask for nn.MultiheadAttention: shape (T, T), upper-tri = -inf\n",
    "        causal = torch.triu(torch.full((T, T), float('-inf'), device=idx.device), diagonal=1)\n",
    "        x2 = self.ln1(x)\n",
    "        attn_out, _ = self.attn(x2, x2, x2, attn_mask=causal)\n",
    "        x = x + attn_out\n",
    "        x = x + self.mlp(self.ln2(x))\n",
    "        return self.head(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}