{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mini-project: k-means + an LSH approximate-nearest-neighbor index\n",
    "**Goal:** cluster from scratch (memorable style), see why k-means++ init matters, then build\n",
    "a random-hyperplane LSH index and measure the recall-vs-speed tradeoff that every vector DB lives on.\n",
    "**Concepts:** clustering, init sensitivity, ANN search. **Time:** ~2h.\n",
    "\n",
    "**How to work:** TODOs in order; un-comment each check after implementing. Solutions at the bottom."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "np.seterr(all=\"ignore\")  # macOS Accelerate emits spurious matmul warnings\n",
    "\n",
    "def make_blobs(num_clusters=5, points_per=200, dim=16, spread=0.6):\n",
    "    centers = rng.normal(0, 4, (num_clusters, dim))\n",
    "    points, labels = [], []\n",
    "    for idx, center in enumerate(centers):\n",
    "        points.append(center + rng.normal(0, spread, (points_per, dim)))\n",
    "        labels += [idx] * points_per\n",
    "    return np.vstack(points), np.array(labels), centers\n",
    "\n",
    "X, TRUE_LABELS, TRUE_CENTERS = make_blobs()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1 \u2014 k-means, the memorable way\n",
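    "Write `kmeans(X, k, max_iters, init_centroids)` around two named helpers, using plain loops where they read better:\n",
    "`nearest_centroid(point, centroids)` (stubbed below) and `recompute_centroids(X, assignments, k)` (yours to add).\n",
    "Loop: assign every point to its nearest centroid, recompute each centroid as the mean of its points (keep the\n",
    "old centroid if a cluster goes empty), and stop when assignments stop changing. Return the final inertia: the\n",
    "sum of squared distances from each point to its assigned centroid."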
   ]
  },
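  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the inertia you return can be written with $a_i$ as the cluster assigned to point $x_i$ and $\\mu_c$ as the centroid of cluster $c$ (symbols introduced here for convenience):\n",
    "\n",
    "$$J = \\sum_{i=1}^{n} \\lVert x_i - \\mu_{a_i} \\rVert^2 .$$\n",
    "\n",
    "Each `make_blobs` point is its true center plus $\\mathcal{N}(0, \\sigma^2)$ noise in each of $d$ dimensions, so a correct clustering should give roughly\n",
    "$J/n \\approx d\\,\\sigma^2 = 16 \\times 0.6^2 = 5.76$, slightly less in practice because each fitted centroid is pulled toward its own points.\n",
    "A bad local optimum that merges two blobs adds far more than that, since the blob centers are spread with standard deviation 4 per dimension."
   ]
  },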
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def nearest_centroid(point, centroids):\n",
    "    \"\"\"Index of the closest centroid (Euclidean).\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def kmeans(X, k, max_iters=100, init_centroids=None):\n",
    "    \"\"\"Return (centroids, assignments, inertia).\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage1():\n",
    "    cents, assign, inertia = kmeans(X, 5)\n",
    "    assert cents.shape == (5, X.shape[1]) and len(assign) == len(X)\n",
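    "    # Each point's noise adds ~dim * spread^2 = 16 * 0.6**2 = 5.76 to inertia; a merged blob adds far more.\n",
    "    assert inertia < 2 * X.shape[1] * 0.6 ** 2 * len(X), \"inertia/pt should be near 5.8; rerun if random init got stuck (Stage 2)\"\n",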
    "    print(\"stage 1 ok, inertia/pt:\", round(inertia / len(X), 2))\n",
    "# check_stage1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 2 \u2014 k-means++ init\n",
    "Random init sometimes drops two centroids in one blob (bad local optimum). k-means++:\n",
    "pick the first centroid uniformly from the data, then sample each next centroid with probability proportional\n",
    "to the SQUARED distance from each point to its nearest already-chosen centroid. Run 10 restarts of each init\n",
    "and compare worst-case inertia: the ++ worst case is typically far lower than random init's."
   ]
  },
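  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny worked example of the $D^2$ weighting (numbers chosen purely for illustration): take 1-D points $\\{0, 1, 10\\}$ and suppose the first centroid landed on $0$.\n",
    "The squared distances to the nearest chosen centroid are $D^2 = (0, 1, 100)$, so the next centroid is drawn with probabilities\n",
    "\n",
    "$$p = \\left(\\tfrac{0}{101},\\ \\tfrac{1}{101},\\ \\tfrac{100}{101}\\right) \\approx (0,\\ 0.0099,\\ 0.9901).$$\n",
    "\n",
    "The far-away point is picked about 99% of the time, which is how ++ avoids stacking two centroids in one blob; a point that is already a centroid has $D^2 = 0$ and is never picked twice."
   ]
  },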
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def kmeans_pp_init(X, k):\n",
    "    \"\"\"Return k initial centroids chosen by the D^2 weighting.\"\"\"\n",
    "    raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def check_stage2():\n",
    "    worst_random = max(kmeans(X, 5, init_centroids=X[rng.choice(len(X), 5, replace=False)])[2]\n",
    "                       for _ in range(10))\n",
    "    worst_pp = max(kmeans(X, 5, init_centroids=kmeans_pp_init(X, 5))[2] for _ in range(10))\n",
    "    print(f\"worst inertia \u2014 random: {worst_random:.0f}, ++: {worst_pp:.0f}\")\n",
    "    assert worst_pp <= worst_random * 1.05\n",
    "# check_stage2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 3 \u2014 random-hyperplane LSH\n",
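    "ANN idea: hash each vector by the SIGN of its dot product with b random hyperplanes through the origin,\n",
    "so vectors separated by a small angle land in the same bucket with high probability. Query: hash the query,\n",
    "then scan only its bucket. More bits = smaller buckets = faster but lower recall. Sign hashing tracks angle\n",
    "(cosine similarity), whereas the exact baseline `brute_force_topk` ranks by Euclidean distance."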
   ]
  },
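  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why more bits cost recall, as a worked formula (the standard random-hyperplane result, which assumes isotropic Gaussian hyperplane normals like the `rng.normal` planes in `LSHIndex`):\n",
    "one random hyperplane separates two vectors at angle $\\theta$ with probability $\\theta/\\pi$, so with $b$ independent planes\n",
    "\n",
    "$$\\Pr[h(u) = h(v)] = \\left(1 - \\frac{\\theta}{\\pi}\\right)^{b}.$$\n",
    "\n",
    "For $\\theta = 30^\\circ$ the per-bit agreement is $5/6$, giving about $0.69$ at $b = 2$, $0.33$ at $b = 6$, and $0.11$ at $b = 12$:\n",
    "every extra bit multiplies the chance that a true neighbor shares the query's bucket by the same factor."
   ]
  },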
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "class LSHIndex:\n",
    "    def __init__(self, X, num_bits):\n",
    "        self.X = X\n",
    "        self.planes = rng.normal(0, 1, (num_bits, X.shape[1]))\n",
    "        self.buckets = {}\n",
    "        for idx in range(len(X)):\n",
    "            self.buckets.setdefault(self._hash(X[idx]), []).append(idx)\n",
    "\n",
    "    def _hash(self, vec):\n",
    "        \"\"\"Tuple of 0/1 signs against each hyperplane.\"\"\"\n",
    "        raise NotImplementedError  # TODO(you)\n",
    "\n",
    "    def query(self, vec, top_k=10):\n",
    "        \"\"\"Brute-force ONLY within the query's bucket; return up to top_k indices.\"\"\"\n",
    "        raise NotImplementedError  # TODO(you)\n",
    "\n",
    "def brute_force_topk(X, vec, top_k=10):\n",
    "    dists = ((X - vec) ** 2).sum(axis=1)\n",
    "    return list(np.argsort(dists)[:top_k])\n",
    "\n",
    "def check_stage3():\n",
    "    index = LSHIndex(X, num_bits=6)\n",
    "    q = X[7] + rng.normal(0, 0.05, X.shape[1])\n",
    "    approx = set(index.query(q, 10)); exact = set(brute_force_topk(X, q, 10))\n",
    "    recall = len(approx & exact) / 10\n",
    "    print(\"stage 3 ok, recall@10:\", recall)\n",
    "    assert recall >= 0.3\n",
    "# check_stage3()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 4 \u2014 the recall/speed curve\n",
    "Sweep num_bits in [2, 4, 6, 8, 10, 12] over 50 queries: measure mean recall@10 and mean\n",
    "candidates scanned (the speed proxy). Print the table \u2014 THE tradeoff every ANN system tunes.\n",
    "**Stretch:** multi-table LSH (OR of several indexes) to recover recall; compare against\n",
    "an IVF-style index (k-means coarse quantizer + probe nearest clusters \u2014 you built the\n",
    "k-means already!)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# SOLUTIONS \u2014 no peeking until your attempt"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def solution_nearest_centroid(point, centroids):\n",
    "    best_idx, best_dist = 0, float(\"inf\")\n",
    "    for idx in range(len(centroids)):\n",
    "        dist = float(((point - centroids[idx]) ** 2).sum())\n",
    "        if dist < best_dist:\n",
    "            best_idx, best_dist = idx, dist\n",
    "    return best_idx\n",
    "\n",
    "def solution_kmeans(X, k, max_iters=100, init_centroids=None):\n",
    "    centroids = X[rng.choice(len(X), k, replace=False)] if init_centroids is None \\\n",
    "        else init_centroids.copy()\n",
    "    assignments = np.full(len(X), -1)\n",
    "    for _ in range(max_iters):\n",
    "        new_assign = np.array([solution_nearest_centroid(p, centroids) for p in X])\n",
    "        if (new_assign == assignments).all():\n",
    "            break\n",
    "        assignments = new_assign\n",
    "        for c in range(k):\n",
    "            members = X[assignments == c]\n",
    "            if len(members):  # empty cluster: keep its previous centroid\n",
    "                centroids[c] = members.mean(axis=0)\n",
    "    inertia = sum(((X[i] - centroids[assignments[i]]) ** 2).sum() for i in range(len(X)))\n",
    "    return centroids, assignments, inertia\n",
    "\n",
    "def solution_kmeans_pp_init(X, k):\n",
    "    centroids = [X[rng.integers(len(X))]]\n",
    "    for _ in range(k - 1):\n",
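    "        # D^2 weighting: squared distance from every point to its nearest chosen centroid\n",
    "        d2 = np.min([((X - c) ** 2).sum(axis=1) for c in centroids], axis=0)\n",
    "        probs = d2 / d2.sum()  # already-chosen points get probability 0\n",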
    "        centroids.append(X[rng.choice(len(X), p=probs)])\n",
    "    return np.array(centroids)\n",
    "\n",
    "class SolutionLSHIndex(LSHIndex):\n",
    "    def _hash(self, vec):\n",
    "        return tuple((self.planes @ vec > 0).astype(int))\n",
    "    def query(self, vec, top_k=10):\n",
    "        candidates = self.buckets.get(self._hash(vec), [])\n",
    "        dists = [(((self.X[i] - vec) ** 2).sum(), i) for i in candidates]\n",
    "        dists.sort()\n",
    "        return [i for _, i in dists[:top_k]]"
   ]
  },
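  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Stretch sketch (Stage 4).** One possible design, not a reference solution: every name in the next cell (`MultiTableLSH`, `IVFIndex`, `kmeans_vectorized`, `make_demo_blobs`, `evaluate`) is hypothetical and defined locally,\n",
    "with its own RNG and data, so its numbers will not match the main sweep. Multi-table LSH ORs several independent sign-hash tables, so a true neighbor is missed\n",
    "only if *every* table separates it from the query. The IVF index uses k-means centroids (vectorized and ++-seeded here) as a coarse quantizer and scans only the\n",
    "inverted lists of the `n_probe` nearest centroids. Raising `use_tables` or `n_probe` can only grow the candidate set, so recall never drops while the scan cost rises."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Stretch sketch (one possible design): multi-table LSH vs an IVF-style index.\n",
    "# Own RNG and data so this cell runs on its own; all names here are hypothetical.\n",
    "rng_demo = np.random.default_rng(0)\n",
    "\n",
    "def make_demo_blobs(num_clusters=5, points_per=200, dim=16, spread=0.6):\n",
    "    centers = rng_demo.normal(0, 4, (num_clusters, dim))\n",
    "    return np.vstack([c + rng_demo.normal(0, spread, (points_per, dim)) for c in centers])\n",
    "\n",
    "def exact_topk(X, q, top_k=10):\n",
    "    return set(np.argsort(((X - q) ** 2).sum(axis=1))[:top_k].tolist())\n",
    "\n",
    "class MultiTableLSH:\n",
    "    \"\"\"Several independent sign-hash tables; candidates = union (OR) of the query's buckets.\"\"\"\n",
    "    def __init__(self, X, num_bits, num_tables):\n",
    "        self.planes = rng_demo.normal(0, 1, (num_tables, num_bits, X.shape[1]))\n",
    "        self.tables = []\n",
    "        for planes in self.planes:\n",
    "            table = {}\n",
    "            for idx, code in enumerate((X @ planes.T > 0).astype(int)):\n",
    "                table.setdefault(tuple(code), []).append(idx)\n",
    "            self.tables.append(table)\n",
    "\n",
    "    def candidates(self, q, use_tables=None):\n",
    "        # use_tables=None scans every table; a smaller number uses only the first few.\n",
    "        found = set()\n",
    "        for planes, table in list(zip(self.planes, self.tables))[:use_tables]:\n",
    "            found.update(table.get(tuple((planes @ q > 0).astype(int)), []))\n",
    "        return found\n",
    "\n",
    "def kmeans_vectorized(X, k, iters=20):\n",
    "    \"\"\"k-means++ seeding, then Lloyd iterations using broadcasting instead of loops.\"\"\"\n",
    "    cents = [X[rng_demo.integers(len(X))]]\n",
    "    for _ in range(k - 1):\n",
    "        d2 = ((X[:, None, :] - np.array(cents)[None]) ** 2).sum(-1).min(axis=1)\n",
    "        cents.append(X[rng_demo.choice(len(X), p=d2 / d2.sum())])\n",
    "    cents = np.array(cents)\n",
    "    for _ in range(iters):\n",
    "        assign = ((X[:, None, :] - cents[None]) ** 2).sum(-1).argmin(axis=1)\n",
    "        for c in range(k):\n",
    "            if (assign == c).any():  # empty cluster: keep its previous centroid\n",
    "                cents[c] = X[assign == c].mean(axis=0)\n",
    "    # Final assignment against the final centroids, used to build the inverted lists.\n",
    "    assign = ((X[:, None, :] - cents[None]) ** 2).sum(-1).argmin(axis=1)\n",
    "    return cents, assign\n",
    "\n",
    "class IVFIndex:\n",
    "    \"\"\"Coarse k-means quantizer; a query scans the lists of its n_probe nearest centroids.\"\"\"\n",
    "    def __init__(self, X, num_lists):\n",
    "        self.centroids, assign = kmeans_vectorized(X, num_lists)\n",
    "        self.lists = [np.flatnonzero(assign == c) for c in range(num_lists)]\n",
    "\n",
    "    def candidates(self, q, n_probe=1):\n",
    "        nearest = np.argsort(((self.centroids - q) ** 2).sum(axis=1))[:n_probe]\n",
    "        return set(np.concatenate([self.lists[c] for c in nearest]).tolist())\n",
    "\n",
    "def evaluate(index, X, queries, top_k=10, **query_kwargs):\n",
    "    \"\"\"Mean recall@top_k and mean candidates scanned (the speed proxy).\"\"\"\n",
    "    recalls, scanned = [], []\n",
    "    for q in queries:\n",
    "        cand = sorted(index.candidates(q, **query_kwargs))\n",
    "        found = set()\n",
    "        if cand:\n",
    "            dists = ((X[cand] - q) ** 2).sum(axis=1)\n",
    "            found = {cand[i] for i in np.argsort(dists)[:top_k]}\n",
    "        recalls.append(len(found & exact_topk(X, q, top_k)) / top_k)\n",
    "        scanned.append(len(cand))\n",
    "    return float(np.mean(recalls)), float(np.mean(scanned))\n",
    "\n",
    "X_demo = make_demo_blobs()\n",
    "queries = X_demo[rng_demo.integers(len(X_demo), size=50)] + rng_demo.normal(0, 0.05, (50, X_demo.shape[1]))\n",
    "lsh = MultiTableLSH(X_demo, num_bits=10, num_tables=8)\n",
    "ivf = IVFIndex(X_demo, num_lists=20)\n",
    "results = {}\n",
    "for t in [1, 4, 8]:\n",
    "    results[f\"lsh b=10 tables={t}\"] = evaluate(lsh, X_demo, queries, use_tables=t)\n",
    "for p in [1, 3]:\n",
    "    results[f\"ivf lists=20 probe={p}\"] = evaluate(ivf, X_demo, queries, n_probe=p)\n",
    "for name, (recall, scan) in results.items():\n",
    "    print(f\"{name:<22} recall@10={recall:.2f} scanned={scan:.0f}\")"
   ]
  },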
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    cents, assign, inertia = solution_kmeans(X, 5, init_centroids=solution_kmeans_pp_init(X, 5))\n",
    "    print(f\"k-means++ inertia/pt: {inertia/len(X):.2f}\")\n",
    "    print(f\"{'bits':>5} {'recall@10':>10} {'scanned':>8}\")\n",
    "    for bits in [2, 4, 6, 8, 10, 12]:\n",
    "        index = SolutionLSHIndex(X, num_bits=bits)\n",
    "        recalls, scanned = [], []\n",
    "        for _ in range(50):\n",
    "            q = X[rng.integers(len(X))] + rng.normal(0, 0.05, X.shape[1])\n",
    "            approx = set(index.query(q, 10))\n",
    "            exact = set(brute_force_topk(X, q, 10))\n",
    "            recalls.append(len(approx & exact) / 10)\n",
    "            scanned.append(len(index.buckets.get(index._hash(q), [])))\n",
    "        print(f\"{bits:>5} {np.mean(recalls):>10.2f} {np.mean(scanned):>8.0f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}