{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Two-Tower Retrieval\n",
    "**Goal**: Build a two-tower neural retrieval system with sampled-softmax training,\n",
    "logQ correction, recall@K eval, and approximate nearest-neighbour (LSH) retrieval.\n",
    "\n",
    "**Concepts exercised**: embedding lookup, in-batch negative sampling, temperature scaling,\n",
    "popularity bias correction, brute-force vs LSH retrieval timing.\n",
    "\n",
    "**Estimated time**: 2-4 hours  |  **Difficulty**: Medium-Hard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to work through this notebook\n",
    "\n",
    "1. Work through the TODO stages in order.\n",
    "2. After filling in a TODO, run the corresponding `check_stageN()` function.\n",
    "   All asserts must pass before you move on.\n",
    "3. Only look at the SOLUTIONS section at the bottom after you have made a genuine attempt."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import time\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "RNG = np.random.default_rng(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# --- Synthetic data ---\n",
    "N_USERS = 800\n",
    "N_ITEMS = 400\n",
    "N_CLUSTERS = 8\n",
    "EMBED_DIM = 32\n",
    "LATENT_DIM = 16\n",
    "\n",
    "# Users and items belong to latent taste clusters\n",
    "user_cluster = RNG.integers(0, N_CLUSTERS, size=N_USERS)\n",
    "item_cluster = RNG.integers(0, N_CLUSTERS, size=N_ITEMS)\n",
    "\n",
    "# Raw features: cluster one-hot + noise\n",
    "def make_features(cluster_ids, n_clusters, noise=0.3):\n",
    "    n = len(cluster_ids)\n",
    "    feats = np.zeros((n, n_clusters), dtype=np.float32)\n",
    "    feats[np.arange(n), cluster_ids] = 1.0\n",
    "    feats += RNG.standard_normal(feats.shape).astype(np.float32) * noise\n",
    "    return feats\n",
    "\n",
    "user_feats = make_features(user_cluster, N_CLUSTERS)   # (N_USERS, N_CLUSTERS)\n",
    "item_feats = make_features(item_cluster, N_CLUSTERS)   # (N_ITEMS, N_CLUSTERS)\n",
    "\n",
    "# Ground-truth relevance: user and item in same cluster => relevant\n",
    "def build_relevance(u_clusters, i_clusters):\n",
    "    return (u_clusters[:, None] == i_clusters[None, :]).astype(np.float32)\n",
    "\n",
    "relevance = build_relevance(user_cluster, item_cluster)  # (N_USERS, N_ITEMS)\n",
    "\n",
    "# Item popularities (power-law)\n",
    "item_pop = RNG.exponential(1.0, size=N_ITEMS).astype(np.float32)\n",
    "item_pop = item_pop / item_pop.sum()\n",
    "\n",
    "# Training pairs: for each user, sample one positive item from same cluster\n",
    "def make_training_pairs(u_clusters, i_clusters, n_items, rng):\n",
    "    pairs = []\n",
    "    for u, uc in enumerate(u_clusters):\n",
    "        pos_items = np.where(i_clusters == uc)[0]\n",
    "        if len(pos_items) == 0:\n",
    "            continue\n",
    "        i = rng.choice(pos_items)\n",
    "        pairs.append((u, int(i)))\n",
    "    return pairs\n",
    "\n",
    "train_pairs = make_training_pairs(user_cluster, item_cluster, N_ITEMS, RNG)\n",
    "print(f\"Users={N_USERS}, Items={N_ITEMS}, Train pairs={len(train_pairs)}, Clusters={N_CLUSTERS}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1 \u2014 Tower Encoders\n",
    "\n",
    "A two-tower model has two identical (but separately parameterised) MLPs: one for users,\n",
    "one for items. Both map raw features into a shared embedding space of dimension `EMBED_DIM`.\n",
    "During retrieval we embed all items offline and store the result. At query time we embed\n",
    "the user and find nearest items via dot product (or cosine similarity).\n",
    "\n",
    "The key design choice is L2-normalising the output so that dot product equals cosine\n",
    "similarity \u2014 this simplifies the loss and ensures the temperature parameter has a\n",
    "predictable scale.\n",
    "\n",
    "Architecture: Linear(in, hidden) -> ReLU -> Linear(hidden, EMBED_DIM) -> L2-normalise.\n",
    "Use `hidden=64`. Both towers share the same architecture class but have separate weights."
   ]
  },
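  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a short worked identity for the normalisation step described above (the symbols $h$, $\\hat{h}$, $u$, $v$ are notation introduced for this note, not names from the code): let $h$ be the raw output of `self.net`, and let $u$, $v$ be the raw outputs of the user and item towers. Then\n",
    "\n",
    "$$\n",
    "\\hat{h} = \\frac{h}{\\lVert h \\rVert_2}, \\qquad \\hat{u}^\\top \\hat{v} = \\frac{u^\\top v}{\\lVert u \\rVert_2 \\, \\lVert v \\rVert_2} = \\cos\\theta_{uv} \\in [-1, 1]\n",
    "$$\n",
    "\n",
    "Every score is therefore bounded, and after dividing by a temperature $\\tau$ the logits lie in $[-1/\\tau, 1/\\tau]$. This is one way to see why $\\tau$ has a predictable scale."
   ]
  },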
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "class Tower(nn.Module):\n",
    "    def __init__(self, in_dim, hidden, out_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(in_dim, hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden, out_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # TODO(you): pass x through self.net, then L2-normalise along dim=-1\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage1():\n",
    "    t = Tower(N_CLUSTERS, 64, EMBED_DIM)\n",
    "    x = torch.from_numpy(user_feats[:10])\n",
    "    out = t(x)\n",
    "    assert out.shape == (10, EMBED_DIM), f\"Expected ({10}, {EMBED_DIM}), got {out.shape}\"\n",
    "    norms = out.norm(dim=-1)\n",
    "    assert torch.allclose(norms, torch.ones(10), atol=1e-5), \"Output must be L2-normalised\"\n",
    "    print(\"Stage 1 passed.\")\n",
    "\n",
    "# check_stage1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 2 \u2014 In-Batch Sampled-Softmax Loss\n",
    "\n",
    "Given a batch of (user, positive_item) pairs, the in-batch negatives trick treats every\n",
    "OTHER item in the batch as a negative for each user. This is efficient: one forward pass\n",
    "yields a full (batch x batch) similarity matrix.\n",
    "\n",
    "The loss is cross-entropy where the target for user i is column i (its own positive item).\n",
    "A temperature parameter tau scales the logits: higher tau -> softer distribution (less\n",
    "confident), lower tau -> sharper. Good default: tau=0.07.\n",
    "\n",
    "Numerically: logits = (user_embs @ item_embs.T) / tau, then F.cross_entropy with\n",
    "targets = torch.arange(batch_size)."
   ]
  },
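  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written as a formula, the loss described above can be sketched as follows. The symbols are notation chosen for this note: $u_i$ and $v_j$ are the normalised user and item embeddings in a batch of size $B$, and the $1/B$ factor assumes the default mean reduction of `F.cross_entropy`.\n",
    "\n",
    "$$\n",
    "\\mathcal{L} = -\\frac{1}{B} \\sum_{i=1}^{B} \\log \\frac{\\exp(u_i^\\top v_i / \\tau)}{\\sum_{j=1}^{B} \\exp(u_i^\\top v_j / \\tau)}\n",
    "$$\n",
    "\n",
    "Two quick sanity checks: if every logit in a row is equal, the softmax is uniform and the loss is $\\log B$ (about $2.77$ for $B = 16$); and with $\\tau = 0.07$ the logits of unit-norm embeddings span roughly $\\pm 14.3$, since $1/0.07 \\approx 14.29$."
   ]
  },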
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "def sampled_softmax_loss(user_embs, item_embs, tau=0.07):\n",
    "    \"\"\"\n",
    "    user_embs: (B, D) L2-normalised\n",
    "    item_embs: (B, D) L2-normalised\n",
    "    Returns scalar loss.\n",
    "\n",
    "    # TODO(you): compute (B,B) dot-product matrix / tau, then cross_entropy\n",
    "    # with target = arange(B).\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage2():\n",
    "    B = 16\n",
    "    u = F.normalize(torch.randn(B, EMBED_DIM), dim=-1)\n",
    "    v = F.normalize(torch.randn(B, EMBED_DIM), dim=-1)\n",
    "    loss = sampled_softmax_loss(u, v)\n",
    "    assert loss.shape == (), \"Loss must be a scalar\"\n",
    "    # Perfect alignment: each user embedding == its item embedding\n",
    "    loss_perfect = sampled_softmax_loss(u, u.clone())\n",
    "    loss_random = sampled_softmax_loss(u, v)\n",
    "    assert loss_perfect < loss_random, \"Perfect alignment must give lower loss than random\"\n",
    "    print(f\"Stage 2 passed. random_loss={loss_random:.3f}, perfect_loss={loss_perfect:.3f}\")\n",
    "\n",
    "# check_stage2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 3 \u2014 Training Loop\n",
    "\n",
    "Now wire towers + loss into a training loop. Use Adam with lr=1e-3 and batch size 64.\n",
    "Shuffle the train_pairs each epoch. Run for 20 epochs.\n",
    "\n",
    "Important: pass item features through the item tower each batch (not cached during\n",
    "training). Log the mean loss per epoch so you can verify it decreases."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "user_tower = Tower(N_CLUSTERS, 64, EMBED_DIM)\n",
    "item_tower = Tower(N_CLUSTERS, 64, EMBED_DIM)\n",
    "optimizer = optim.Adam(list(user_tower.parameters()) + list(item_tower.parameters()), lr=1e-3)\n",
    "\n",
    "def train_one_epoch(pairs, u_feats, i_feats, batch_size=64):\n",
    "    \"\"\"\n",
    "    pairs: list of (user_idx, item_idx)\n",
    "    Returns mean loss over all batches.\n",
    "\n",
    "    # TODO(you): shuffle pairs, iterate in batches, call both towers,\n",
    "    # call sampled_softmax_loss, backprop, step optimizer.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage3():\n",
    "    # Just checks the function runs and returns a float\n",
    "    loss_e0 = train_one_epoch(train_pairs, user_feats, item_feats)\n",
    "    assert isinstance(loss_e0, float), \"Should return a float\"\n",
    "    loss_e1 = train_one_epoch(train_pairs, user_feats, item_feats)\n",
    "    print(f\"Stage 3 passed. epoch0_loss={loss_e0:.4f}, epoch1_loss={loss_e1:.4f}\")\n",
    "\n",
    "# check_stage3()"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Full training (run after stage 3 passes)\n",
    "def run_training(n_epochs=20):\n",
    "    losses = []\n",
    "    for ep in range(n_epochs):\n",
    "        l = train_one_epoch(train_pairs, user_feats, item_feats)\n",
    "        losses.append(l)\n",
    "        if (ep + 1) % 5 == 0:\n",
    "            print(f\"  epoch {ep+1:3d}  loss={l:.4f}\")\n",
    "    return losses\n",
    "\n",
    "# epoch_losses = run_training()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 4 \u2014 Recall@K Evaluation\n",
    "\n",
    "After training, embed ALL users and ALL items, then for each user retrieve the top-K\n",
    "items by dot product. Recall@K = fraction of users for whom at least one relevant item\n",
    "appears in the top K (relevant = same cluster).\n",
    "\n",
    "Baseline: popularity-based retrieval \u2014 for every user return the top-K most popular items.\n",
    "A good model should clearly beat this baseline.\n",
    "\n",
    "Steps:\n",
    "1. Embed all users: user_tower(torch.from_numpy(user_feats)).detach().numpy()\n",
    "2. Embed all items similarly.\n",
    "3. Compute scores = user_embs @ item_embs.T  (N_USERS, N_ITEMS)\n",
    "4. For each user, argsort descending, take top K, check if any are relevant."
   ]
  },
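  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The metric used in this stage can be stated compactly. In the notation below (introduced for this note), $U$ is the set of users, $\\mathrm{top}_K(u)$ is the set of the $K$ highest-scoring items for user $u$, and $r_{ui}$ is the 0/1 entry of `relevance` for user $u$ and item $i$.\n",
    "\n",
    "$$\n",
    "\\mathrm{Recall@}K = \\frac{1}{|U|} \\sum_{u \\in U} \\mathbb{1}\\Big[ \\sum_{i \\in \\mathrm{top}_K(u)} r_{ui} > 0 \\Big]\n",
    "$$\n",
    "\n",
    "For example, with 4 users of whom 3 have at least one relevant item among their top $K$, the metric is $3/4 = 0.75$, regardless of how many relevant items each of those 3 users retrieved."
   ]
  },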
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def recall_at_k(scores, relevance, k=10):\n",
    "    \"\"\"\n",
    "    scores: (N_USERS, N_ITEMS) numpy array\n",
    "    relevance: (N_USERS, N_ITEMS) binary numpy array\n",
    "    Returns scalar recall@k.\n",
    "\n",
    "    # TODO(you): for each user, take top-k indices from scores,\n",
    "    # check if sum of relevance at those indices > 0, average over users.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "def popularity_scores(item_pop, n_users):\n",
    "    \"\"\"Return (n_users, N_ITEMS) score matrix where every row is item_pop.\"\"\"\n",
    "    # TODO(you): one line\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage4():\n",
    "    n_u, n_i = 20, 50\n",
    "    sc = np.zeros((n_u, n_i))\n",
    "    # Plant one known relevant item per user at column = user_index (mod n_i)\n",
    "    rel = np.zeros((n_u, n_i))\n",
    "    for u in range(n_u):\n",
    "        col = u % n_i\n",
    "        sc[u, col] = 100.0\n",
    "        rel[u, col] = 1.0\n",
    "    r = recall_at_k(sc, rel, k=5)\n",
    "    assert abs(r - 1.0) < 1e-6, f\"Expected recall 1.0 with planted scores, got {r}\"\n",
    "    pop_sc = popularity_scores(item_pop, n_u)\n",
    "    assert pop_sc.shape == (n_u, N_ITEMS), \"Wrong shape from popularity_scores\"\n",
    "    assert np.all(pop_sc[0] == pop_sc[1]), \"All rows should be identical\"\n",
    "    print(\"Stage 4 passed.\")\n",
    "\n",
    "# check_stage4()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 5 \u2014 logQ Correction for Popularity Bias\n",
    "\n",
    "When using in-batch negatives the item sampling distribution is NOT uniform \u2014 items\n",
    "that appear more frequently in training get sampled as negatives more often, making\n",
    "them look harder to the model than they really are. This inflates scores for popular\n",
    "items.\n",
    "\n",
    "logQ correction (Yoshua Bengio et al. / Google two-tower paper): subtract\n",
    "log(q_i) from the item logit before computing the loss, where q_i is the probability\n",
    "that item i appears in a batch (approximated by its training frequency / total pairs).\n",
    "\n",
    "Corrected logit for pair (u, i): s(u,i)/tau - log(q_i)\n",
    "\n",
    "Implement `sampled_softmax_loss_logq` that accepts an extra `log_q` vector of shape (B,)."
   ]
  },
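  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the corrected loss looks like in full (same notation assumptions as a standard in-batch softmax: $u_i$, $v_j$ are normalised embeddings, $B$ is the batch size, and $q_j$ is the estimated sampling probability of the item in column $j$):\n",
    "\n",
    "$$\n",
    "\\mathcal{L}_{\\log Q} = -\\frac{1}{B} \\sum_{i=1}^{B} \\log \\frac{\\exp(u_i^\\top v_i / \\tau - \\log q_i)}{\\sum_{j=1}^{B} \\exp(u_i^\\top v_j / \\tau - \\log q_j)}\n",
    "$$\n",
    "\n",
    "Each column's softmax weight is proportional to $\\exp(u_i^\\top v_j / \\tau) / q_j$, so a frequently sampled item contributes less as a negative than a rare one with the same score. If every $q_j$ equals the same constant, the $-\\log q_j$ terms cancel between numerator and denominator and the expression reduces to the uncorrected loss. As a numeric illustration, an item with $q = 0.01$ receives an offset of $-\\log 0.01 \\approx 4.61$, while one with $q = 0.001$ receives $\\approx 6.91$."
   ]
  },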
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def sampled_softmax_loss_logq(user_embs, item_embs, log_q, tau=0.07):\n",
    "    \"\"\"\n",
    "    Same as stage 2 but subtract log_q (shape B,) from each column of the logit matrix.\n",
    "    # TODO(you): logits = (u @ v.T) / tau - log_q[None, :]\n",
    "    # then cross_entropy as before.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "def estimate_log_q(pairs, n_items, smoothing=1e-8):\n",
    "    \"\"\"\n",
    "    Count how often each item appears as a positive, normalise to get q_i.\n",
    "    Return log(q_i + smoothing) as a numpy array of shape (n_items,).\n",
    "    \"\"\"\n",
    "    # TODO(you)\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage5():\n",
    "    B = 16\n",
    "    u = F.normalize(torch.randn(B, EMBED_DIM), dim=-1)\n",
    "    v = F.normalize(torch.randn(B, EMBED_DIM), dim=-1)\n",
    "    log_q = torch.zeros(B)  # uniform => no correction\n",
    "    loss_plain = sampled_softmax_loss(u, v)\n",
    "    loss_corr = sampled_softmax_loss_logq(u, v, log_q)\n",
    "    assert torch.allclose(loss_plain, loss_corr, atol=1e-5), \\\n",
    "        \"With uniform log_q=0 both losses must be equal\"\n",
    "    counts = estimate_log_q(train_pairs, N_ITEMS)\n",
    "    assert counts.shape == (N_ITEMS,), \"Wrong shape\"\n",
    "    assert np.all(np.isfinite(counts)), \"log_q must be finite\"\n",
    "    print(\"Stage 5 passed.\")\n",
    "\n",
    "# check_stage5()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 6 \u2014 Brute Force vs LSH Retrieval Timing\n",
    "\n",
    "At serving time the system must find the top-K items for a query embedding quickly.\n",
    "Brute force: compute all N_ITEMS dot products. Cost: O(N_ITEMS * D).\n",
    "Random hyperplane LSH: hash each item embedding to a binary code using B random\n",
    "hyperplanes (sign of dot product). At query time hash the query and retrieve only\n",
    "items with the same (or similar) hash bucket.\n",
    "\n",
    "This stage measures the wall-clock speed difference and the recall vs brute force\n",
    "that LSH achieves at varying numbers of hash bits."
   ]
  },
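  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A standard result for random-hyperplane hashing helps explain the recall trade-off this stage measures. The sketch below assumes hyperplane normals drawn from an isotropic Gaussian and independent across bits; $\\theta$ (the angle between two vectors $x$ and $y$) and $h$ (a single-bit hash) are notation introduced for this note.\n",
    "\n",
    "$$\n",
    "\\Pr[h(x) = h(y)] = 1 - \\frac{\\theta}{\\pi}, \\qquad \\Pr[\\text{all } n_{\\text{bits}} \\text{ bits agree}] = \\Big(1 - \\frac{\\theta}{\\pi}\\Big)^{n_{\\text{bits}}}\n",
    "$$\n",
    "\n",
    "For instance, two vectors at $\\theta = \\pi/4$ (cosine similarity about $0.71$) agree on a single bit with probability $0.75$, on all 8 bits with probability $0.75^8 \\approx 0.10$, and on all 16 bits with probability $0.75^{16} \\approx 0.01$. More bits give smaller buckets and faster lookups, at the cost of missing more true neighbours."
   ]
  },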
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def brute_force_topk(query_emb, item_embs, k=10):\n",
    "    \"\"\"\n",
    "    query_emb: (D,) numpy\n",
    "    item_embs: (N, D) numpy\n",
    "    Returns indices of top-k items by dot product.\n",
    "    # TODO(you)\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "def lsh_index(item_embs, n_bits=16, rng=None):\n",
    "    \"\"\"\n",
    "    Build LSH index.\n",
    "    Returns (hyperplanes, codes):\n",
    "      hyperplanes: (n_bits, D) random unit vectors\n",
    "      codes: (N_ITEMS,) integer hash codes (pack n_bits into one int)\n",
    "    # TODO(you): sample n_bits random directions, compute sign(item_embs @ planes.T),\n",
    "    # pack bits into integer codes.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "def lsh_query(query_emb, hyperplanes, item_codes, item_embs, k=10):\n",
    "    \"\"\"\n",
    "    Hash query, retrieve items with matching code, brute-force within bucket.\n",
    "    Falls back to top-k over all items if bucket is empty.\n",
    "    # TODO(you)\n",
    "    \"\"\"\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def check_stage6():\n",
    "    dummy_embs = RNG.standard_normal((N_ITEMS, EMBED_DIM)).astype(np.float32)\n",
    "    dummy_embs = dummy_embs / np.linalg.norm(dummy_embs, axis=1, keepdims=True)\n",
    "    q = dummy_embs[0]\n",
    "    idx = brute_force_topk(q, dummy_embs, k=10)\n",
    "    assert len(idx) == 10, \"Must return exactly k indices\"\n",
    "    assert 0 in idx, \"Query itself should be top-1 (identical embedding)\"\n",
    "    planes, codes = lsh_index(dummy_embs, n_bits=8, rng=RNG)\n",
    "    assert planes.shape == (8, EMBED_DIM)\n",
    "    assert codes.shape == (N_ITEMS,)\n",
    "    print(\"Stage 6 passed.\")\n",
    "\n",
    "# check_stage6()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stretch Goals\n",
    "\n",
    "1. Add hard-negative mining: for each user, find the highest-scoring NON-relevant item\n",
    "   and include it explicitly in the batch.\n",
    "2. Replace L2 normalisation with temperature-scaled dot product (no normalisation) and\n",
    "   compare training stability.\n",
    "3. Implement multi-vector retrieval: each user/item gets K embedding vectors; score is\n",
    "   the max (or sum) of pairwise dot products (ColBERT-style).\n",
    "4. Extend LSH to multi-probe: also check the L nearest hash buckets. Plot recall vs L.\n",
    "5. Export item embeddings and build a tiny HNSW graph by hand (greedy search on a\n",
    "   2-level hierarchy)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# SOLUTIONS -- no peeking until your attempt\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 1 \u2014 Tower forward\n",
    "def solution_tower_forward(self_net, x):\n",
    "    h = self_net(x)\n",
    "    return F.normalize(h, dim=-1)\n",
    "\n",
    "class SolutionTower(nn.Module):\n",
    "    def __init__(self, in_dim, hidden, out_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))\n",
    "    def forward(self, x):\n",
    "        return F.normalize(self.net(x), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 2 \u2014 sampled softmax loss\n",
    "def solution_sampled_softmax_loss(user_embs, item_embs, tau=0.07):\n",
    "    logits = (user_embs @ item_embs.T) / tau          # (B, B)\n",
    "    targets = torch.arange(logits.size(0), device=logits.device)\n",
    "    return F.cross_entropy(logits, targets)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 3 \u2014 training loop\n",
    "def solution_train_one_epoch(pairs, u_feats, i_feats, batch_size=64):\n",
    "    indices = np.random.permutation(len(pairs))\n",
    "    total_loss, n_batches = 0.0, 0\n",
    "    for start in range(0, len(indices), batch_size):\n",
    "        batch_idx = indices[start:start + batch_size]\n",
    "        batch_pairs = [pairs[i] for i in batch_idx]\n",
    "        u_idx = torch.tensor([p[0] for p in batch_pairs])\n",
    "        i_idx = torch.tensor([p[1] for p in batch_pairs])\n",
    "        ub = torch.from_numpy(u_feats[u_idx.numpy()])\n",
    "        ib = torch.from_numpy(i_feats[i_idx.numpy()])\n",
    "        ue = user_tower(ub)\n",
    "        ie = item_tower(ib)\n",
    "        loss = solution_sampled_softmax_loss(ue, ie)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "        n_batches += 1\n",
    "    return total_loss / max(n_batches, 1)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 4 \u2014 recall@k and popularity scores\n",
    "def solution_recall_at_k(scores, relevance, k=10):\n",
    "    n_users = scores.shape[0]\n",
    "    hits = 0\n",
    "    for u in range(n_users):\n",
    "        top_k = np.argsort(scores[u])[::-1][:k]\n",
    "        if relevance[u, top_k].sum() > 0:\n",
    "            hits += 1\n",
    "    return hits / n_users\n",
    "\n",
    "def solution_popularity_scores(item_pop, n_users):\n",
    "    return np.tile(item_pop, (n_users, 1))"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 5 \u2014 logQ correction\n",
    "def solution_sampled_softmax_loss_logq(user_embs, item_embs, log_q, tau=0.07):\n",
    "    logits = (user_embs @ item_embs.T) / tau - log_q[None, :]\n",
    "    targets = torch.arange(logits.size(0), device=logits.device)\n",
    "    return F.cross_entropy(logits, targets)\n",
    "\n",
    "def solution_estimate_log_q(pairs, n_items, smoothing=1e-8):\n",
    "    counts = np.zeros(n_items, dtype=np.float64)\n",
    "    for _, i in pairs:\n",
    "        counts[i] += 1\n",
    "    q = (counts + smoothing) / (counts.sum() + smoothing * n_items)\n",
    "    return np.log(q)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Solution: Stage 6 \u2014 brute force and LSH\n",
    "def solution_brute_force_topk(query_emb, item_embs, k=10):\n",
    "    scores = item_embs @ query_emb\n",
    "    return np.argsort(scores)[::-1][:k]\n",
    "\n",
    "def solution_lsh_index(item_embs, n_bits=16, rng=None):\n",
    "    rng = rng or np.random.default_rng(0)\n",
    "    D = item_embs.shape[1]\n",
    "    planes = rng.standard_normal((n_bits, D)).astype(np.float32)\n",
    "    planes = planes / np.linalg.norm(planes, axis=1, keepdims=True)\n",
    "    bits = (item_embs @ planes.T) > 0   # (N, n_bits) bool\n",
    "    codes = bits.astype(np.uint64) @ (2 ** np.arange(n_bits, dtype=np.uint64))\n",
    "    return planes, codes\n",
    "\n",
    "def solution_lsh_query(query_emb, hyperplanes, item_codes, item_embs, k=10):\n",
    "    bits = (hyperplanes @ query_emb) > 0\n",
    "    q_code = int(bits.astype(np.uint64) @ (2 ** np.arange(len(bits), dtype=np.uint64)))\n",
    "    bucket = np.where(item_codes == q_code)[0]\n",
    "    if len(bucket) < k:\n",
    "        bucket = np.arange(len(item_embs))\n",
    "    scores = item_embs[bucket] @ query_emb\n",
    "    local_top = np.argsort(scores)[::-1][:k]\n",
    "    return bucket[local_top]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}