diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5649b46..6dc22fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,6 +7,7 @@ on: # Push events on main and dev branch - main - dev + - dynamic-lmi # Sequence of patterns matched against refs/tags tags: '*' @@ -113,10 +114,5 @@ jobs: conda install matplotlib pandas scikit-learn pip install h5py setuptools tqdm faiss-cpu pip install torch --index-url https://download.pytorch.org/whl/cpu - - name: Run learned index - shell: bash -el {0} - run: | - conda activate env - pip install --editable . - python3 search/search.py - python3 eval/eval.py + - name: Run tests + run: pytest diff --git a/Development.ipynb b/Development.ipynb new file mode 100644 index 0000000..eac0811 --- /dev/null +++ b/Development.ipynb @@ -0,0 +1,3105 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "4e93aafa-d72b-4660-a13e-1e65a18efbf8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4e8783fd-4590-4665-9beb-85adf879f6ae", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import logging\n", + "import numpy as np\n", + "np.random.seed(2023)\n", + "\n", + "logging.basicConfig(\n", + " level=logging.DEBUG,\n", + " format='[%(asctime)s][%(levelname)-5.5s][%(name)-.20s] %(message)s'\n", + ")\n", + "LOG = logging.getLogger(__name__)" + ] + }, + { + "cell_type": "markdown", + "id": "98460326-c0a2-4309-a147-e7b5e2699746", + "metadata": {}, + "source": [ + "# 1. Load the data\n", + "The data are from SISAP 2023 indexing challenge (LAION dataset). There are `100K`, `300K`, and `10M` versions (also `100M`, but that one wasn't tested with LMI). The queries are not included in the data (they are outside of the dataset)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5a368ad6-f582-4013-b092-1f135740f042", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-03-18 10:58:52,732][DEBUG][h5py._conv] Creating converter from 7 to 5\n", + "[2024-03-18 10:58:52,732][DEBUG][h5py._conv] Creating converter from 5 to 7\n", + "[2024-03-18 10:58:52,732][DEBUG][h5py._conv] Creating converter from 7 to 5\n", + "[2024-03-18 10:58:52,732][DEBUG][h5py._conv] Creating converter from 5 to 7\n" + ] + } + ], + "source": [ + "import os\n", + "from urllib.request import urlretrieve\n", + "from pathlib import Path\n", + "import h5py\n", + "\n", + "def download(src, dst):\n", + " if not os.path.exists(dst):\n", + " os.makedirs(Path(dst).parent, exist_ok=True)\n", + " LOG.info('downloading %s -> %s...' 
% (src, dst))\n", + " urlretrieve(src, dst)\n", + "\n", + "def prepare(kind, size):\n", + " url = \"https://sisap-23-challenge.s3.amazonaws.com/SISAP23-Challenge\"\n", + " task = {\n", + " \"query\": f\"{url}/public-queries-10k-{kind}.h5\",\n", + " \"dataset\": f\"{url}/laion2B-en-{kind}-n={size}.h5\",\n", + " }\n", + "\n", + " for version, url in task.items():\n", + " target_path = os.path.join(\"data\", kind, size, f\"{version}.h5\")\n", + " download(url, target_path)\n", + " assert os.path.exists(target_path), f\"Failed to download {url}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "30603999-c63c-4b96-b4a0-7902f207baf9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "config = {\n", + " # get the smallest version of the LAION dataset\n", + " 'dataset': 'pca32v2',\n", + " 'emb': 'pca32',\n", + " 'size': '100K',\n", + " # n. of nearest neighbors\n", + " 'k': 10,\n", + " # normalize the data to be able to use K-Means\n", + " 'preprocess': True\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "19ec8167-261d-473b-b363-26727c5caeef", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((100000, 32), (10000, 32))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# download the data\n", + "prepare(config['dataset'], config['size'])\n", + "\n", + "def get_data(data_part, **config):\n", + " return np.array(\n", + " h5py.File(\n", + " os.path.join(\n", + " 'data',\n", + " config['dataset'],\n", + " config['size'],\n", + " data_part\n", + " ),\n", + " 'r'\n", + " )[config['emb']]\n", + " )\n", + "\n", + "# load the data \n", + "data = get_data(\"dataset.h5\", **config)\n", + "queries = get_data(\"query.h5\", **config)\n", + "data.shape, queries.shape" + ] + }, + { + "cell_type": "markdown", + "id": "a24a39e7-fbc9-404d-84ce-5e31e223b23a", + "metadata": {}, + "source": [ + "## 1.2. Pre-process the data\n", + "The default distance metric for the LAION dataset is the cosine distance. To use K-Means for partitioning (which operates only with Euclidean distances), we need to **normalize the data to unit length** (i.e., each vector will have an L2 norm of 1). Data normalized like this can then be searched with the Euclidean distance."
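A quick sanity check (not part of the original notebook, uses only NumPy) of why unit-length normalization is enough: `sklearn.preprocessing.normalize` scales each vector to L2 norm 1, and for unit vectors the squared Euclidean distance is a monotone function of the cosine distance, so Euclidean-based K-Means partitioning and cosine-based ranking agree.

```python
# Sketch: for unit-length vectors x and y,
# ||x - y||^2 = 2 - 2 * (x . y) = 2 * cosine_distance(x, y),
# so ranking neighbours by Euclidean distance equals ranking by cosine distance.
import numpy as np

rng = np.random.default_rng(2023)
x, y = rng.normal(size=32), rng.normal(size=32)
x, y = x / np.linalg.norm(x), y / np.linalg.norm(y)

euclidean_sq = np.sum((x - y) ** 2)
cosine_dist = 1.0 - np.dot(x, y)
assert np.isclose(euclidean_sq, 2.0 * cosine_dist)
```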
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "43066688-f16a-4131-a98b-f1e043d88b74", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.4463985259644687" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# data characteristic before:\n", + "sum(data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ffcdd826-9e23-4259-8d1d-0d29a5337c03", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from sklearn import preprocessing\n", + "if config['preprocess']:\n", + " data = preprocessing.normalize(data)\n", + " queries = preprocessing.normalize(queries)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "75e923c4-952f-47ec-adff-a13f66f07de2", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1.004468702711165" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# data characteristics after\n", + "sum(data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b23007bd-4b5e-4096-b1bf-a3d3a688e9e1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "# data to pandas\n", + "data = pd.DataFrame(data)\n", + "# index from one (needed to fit the evaluation procedure later)\n", + "data.index += 1" + ] + }, + { + "cell_type": "markdown", + "id": "f3f9e9bc-c91a-442d-b68f-7cfbaa5a22a1", + "metadata": {}, + "source": [ + "# 2. Build the index" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "83796d0d-04ca-4317-98f2-afe6529a14a6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-03-18 10:58:56,559][INFO ][faiss.loader] Loading faiss with AVX2 support.\n", + "[2024-03-18 10:58:56,559][INFO ][faiss.loader] Could not load library with AVX2 support due to:\n", + "ModuleNotFoundError(\"No module named 'faiss.swigfaiss_avx2'\")\n", + "[2024-03-18 10:58:56,560][INFO ][faiss.loader] Loading faiss.\n", + "[2024-03-18 10:58:56,579][INFO ][faiss.loader] Successfully loaded faiss.\n" + ] + } + ], + "source": [ + "from li.BuildConfiguration import BuildConfiguration\n", + "from li.clustering import algorithms\n", + "from li.LearnedIndexBuilder import LearnedIndexBuilder" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "92b4c129-bd5e-49b3-b19d-349d350a77f8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "n_categories = [10, 10, 10]\n", + "\n", + "build_config = BuildConfiguration(\n", + " # which clustering algorithm to use\n", + " algorithms['faiss_kmeans'],\n", + " # how many epochs to train for\n", + " 200,\n", + " # what model to use (see li/model.py\n", + " 'MLP',\n", + " # what learning rate to use\n", + " 0.01,\n", + " # how many categories at what level to build LMI for\n", + " # 10, 10 results in 100 buckets in total\n", + " n_categories\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5b7a7740", + "metadata": {}, + "source": [ + "## 1, 2, 3 levels" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a293906e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-03-18 10:58:59,528][DEBUG][li.LearnedIndexBuild] Training the root model.\n", + "[2024-03-18 10:58:59,981][DEBUG][li.LearnedIndexBuild] Epochs: 200, step: 20\n", + "[2024-03-18 
10:59:03,845][DEBUG][li.LearnedIndexBuild] Epoch 20 | Loss 1.36070\n", + "[2024-03-18 10:59:07,599][DEBUG][li.LearnedIndexBuild] Epoch 40 | Loss 0.59687\n", + "[2024-03-18 10:59:11,184][DEBUG][li.LearnedIndexBuild] Epoch 60 | Loss 0.50983\n", + "[2024-03-18 10:59:15,137][DEBUG][li.LearnedIndexBuild] Epoch 80 | Loss 0.26845\n", + "[2024-03-18 10:59:18,882][DEBUG][li.LearnedIndexBuild] Epoch 100 | Loss 0.21265\n", + "[2024-03-18 10:59:22,764][DEBUG][li.LearnedIndexBuild] Epoch 120 | Loss 0.17393\n", + "[2024-03-18 10:59:27,404][DEBUG][li.LearnedIndexBuild] Epoch 140 | Loss 0.05081\n", + "[2024-03-18 10:59:31,285][DEBUG][li.LearnedIndexBuild] Epoch 160 | Loss 0.16542\n", + "[2024-03-18 10:59:34,970][DEBUG][li.LearnedIndexBuild] Epoch 180 | Loss 0.29115\n", + "[2024-03-18 10:59:38,699][DEBUG][li.LearnedIndexBuild] Trained the model in: 39.17102861404419\n", + "[2024-03-18 10:59:38,706][DEBUG][li.LearnedIndexBuild] Training [10, 10] internal models.\n", + "[2024-03-18 10:59:38,706][DEBUG][li.LearnedIndexBuild] Training level 1.\n", + " 0%| | 0/10 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "1 0.540832 0.294648 0.170303 0.300801 -0.334042 0.030603 0.010370 \n", + "2 -0.475594 -0.244831 -0.088942 0.291982 0.047102 0.044693 -0.112934 \n", + "\n", + " 7 8 9 ... 22 23 24 25 \\\n", + "1 0.064329 -0.099502 0.102232 ... -0.235122 0.021346 -0.032289 0.133272 \n", + "2 0.098264 0.037613 -0.101401 ... 0.041023 -0.023160 -0.241260 0.064872 \n", + "\n", + " 26 27 28 29 30 31 \n", + "1 -0.248649 0.047122 0.119099 -0.107760 0.186530 -0.043702 \n", + "2 -0.096614 0.078653 0.024111 0.075076 -0.160184 -0.300525 \n", + "\n", + "[2 rows x 32 columns]" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_to_insert.head(2)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "44bce059", + "metadata": {}, + "outputs": [], + "source": [ + "n_levels = builder.config.n_levels\n", + "data_prediction: npt.NDArray[np.int64] = np.full(\n", + " (data_to_insert.shape[0], n_levels), fill_value=EMPTY_VALUE, dtype=np.int64\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "43c2680d", + "metadata": {}, + "outputs": [], + "source": [ + "data_prediction[:, 0] = builder.root_model.predict(data_X_to_torch(data_to_insert))" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "4b85efe2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 8, -1, -1],\n", + " [ 9, -1, -1],\n", + " [ 5, -1, -1],\n", + " ...,\n", + " [ 6, -1, -1],\n", + " [ 1, -1, -1],\n", + " [ 6, -1, -1]], dtype=int64)" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_prediction" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "3755aa04", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: use as _generate_internal_node_paths\n", + "def generate_internal_node_paths(\n", + " level: int, n_levels: int, n_categories: List[int]\n", + ") -> List[Tuple]:\n", + " \"\"\"Generates all possible paths to internal nodes at the given `level`.\n", + "\n", + " Parameters\n", + " ----------\n", + " level : int\n", + " Desired level of the internal nodes.\n", + " n_levels : int\n", + " Total number of levels in the index.\n", + " n_categories : List[int]\n", + " Number of categories for each level of the index.\n", + "\n", + " Returns\n", + " -------\n", + " List[Tuple]\n", + " List of all possible paths to internal nodes at the given `level`.\n", + " \"\"\"\n", + " path_combinations = [range(n_categories[lvl]) for lvl in range(level)]\n", + " padding = [[EMPTY_VALUE]] * (n_levels - level)\n", + "\n", + " return list(product(*path_combinations, *padding))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "b9948379", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-03-18 10:10:24,321][INFO ][li.LearnedIndexBuild] Training level 1.\n", + " 0%| | 0/10 [00:00 data/clip768v2/100K/query.h5...\n", + "[2023-10-04 09:59:11,056][INFO ][__main__] downloading https://sisap-23-challenge.s3.amazonaws.com/SISAP23-Challenge/laion2B-en-clip768v2-n=100K.h5 -> data/clip768v2/100K/dataset.h5...\n" + ] + }, + { + "data": { + "text/plain": [ + "((100000, 768), (10000, 768))" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config['dataset'] = 'clip768v2'\n", + "config['emb'] = 'emb'\n", + "\n", + "prepare(config['dataset'], config['size'])\n", + 
"data_search = get_data(\"dataset.h5\", **config)\n", + "queries_search = get_data(\"query.h5\", **config)\n", + "\n", + "data_search = pd.DataFrame(data_search)\n", + "data_search.index += 1\n", + "\n", + "data_search.shape, queries_search.shape" + ] + }, + { + "cell_type": "markdown", + "id": "fbd59797-8dc9-46c8-bd6c-28f85c933bc2", + "metadata": {}, + "source": [ + "# 4. Search in the index" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "f92c58de-8939-4648-97fb-bca324633ac0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# specify the stop condition\n", + "bucket=10\n", + "# specify the n. of neighbors\n", + "k=10" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "69146a6e-c4fe-4c82-a01c-cc0f83139a6e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-10-04 10:07:00,406][INFO ][li.LearnedIndex.Lear] Precomputed bucket order time: 0.4756293296813965\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 170.06it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 174.12it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 177.48it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 177.95it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 179.18it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 180.52it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 180.18it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 181.15it/s]\n", + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 181.54it/s]\n", + 
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 180.81it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 7.44 s, sys: 55.5 ms, total: 7.49 s\n", + "Wall time: 7.75 s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "%%time\n", + "dists, nns, measured_time = li.search(\n", + " # the 'navigation' data\n", + " data_navigation=data,\n", + " queries_navigation=queries,\n", + " # the 'sequential filtering' data\n", + " data_search=data_search,\n", + " queries_search=queries_search,\n", + " # mapping of object -> bucket\n", + " data_prediction=data_prediction,\n", + " # n. of categories present in index\n", + " n_categories=n_categories,\n", + " # stop condition for the search\n", + " n_buckets=bucket,\n", + " # number of nearest neighbors we're interested in\n", + " k=k\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "65440198-8335-430a-9936-e500fa3ab143", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "defaultdict(float,\n", + " {'inference': 0.05044746398925781,\n", + " 'search_within_buckets': 7.249154806137085,\n", + " 'seq_search': 4.535206317901611,\n", + " 'sort': 0.0,\n", + " 'search': 7.75262188911438})" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Time to search (broken down into various search parts)\n", + "measured_time" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "af84d793-a15f-4322-83c9-ae1fc85f1042", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(10000, 10)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# matrix of the nearest neighbors (`k` for each query)\n", + "nns.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "eb0a131a-c02f-4f48-bc9c-b665f8d7dde8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[79172, 15735, 22337, 74173, 41079, 38159, 69015, 92811, 79896,\n", + " 13236],\n", + " [14347, 82848, 79302, 85923, 6016, 67067, 54566, 34591, 11620,\n", + " 53783]], dtype=uint32)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nns[:2]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "8018f307-21b4-4365-8fca-869f4f3ce97b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(10000, 10)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# matrix of distances to the closest neighbors (`k` for each query)\n", + "dists.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "13aa6923-3fce-4cf2-9694-192204015dae", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.27291209, 0.30623567, 0.3131932 , 0.32404494, 0.33161247,\n", + " 0.33278447, 0.34032881, 0.34535122, 0.35354602, 0.36600691],\n", + " [0.19766825, 0.21139383, 0.22871637, 0.23902297, 0.25272477,\n", + " 0.25969118, 0.2700808 , 0.2767331 , 0.27809215, 0.28464031]])" + ] + }, + "execution_count": 30, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "dists[:2]" + ] + }, + { + "cell_type": "markdown", + "id": "b53c9269-32cb-43b2-94a7-3cac3765b01d", + "metadata": {}, + "source": [ + "# 5. Evaluate the search performance" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "6a6c721c-925a-4c3e-a290-d0e58c6fa84c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-10-04 10:03:59,376][INFO ][__main__] downloading https://sisap-23-challenge.s3.amazonaws.com/SISAP23-Challenge/laion2B-en-public-gold-standard-v2-100K.h5 -> data/groundtruth-100K.h5...\n" + ] + } + ], + "source": [ + "def get_groundtruth(size=\"100K\"):\n", + " url = f\"https://sisap-23-challenge.s3.amazonaws.com/SISAP23-Challenge/laion2B-en-public-gold-standard-v2-{size}.h5\"\n", + "\n", + " out_fn = os.path.join(\"data\", f\"groundtruth-{size}.h5\")\n", + " download(url, out_fn)\n", + " gt_f = h5py.File(out_fn, \"r\")\n", + " true_I = np.array(gt_f['knns'])\n", + " gt_f.close()\n", + " return true_I\n", + "\n", + "gt = get_groundtruth(config['size'])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "dbbb8831-8ae8-4dd8-a850-b7147febd6a5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def get_recall(I, gt, k):\n", + " assert k <= I.shape[1]\n", + " assert len(I) == len(gt)\n", + "\n", + " n = len(I)\n", + " recall = 0\n", + " for i in range(n):\n", + " recall += len(set(I[i, :k]) & set(gt[i, :k]))\n", + " return recall / (n * k)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "7df2cb73-e965-4b9b-831d-cd24ca1a59b9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.87099" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recall = get_recall(nns, gt, k)\n", + "recall" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/README.md b/README.md index 22eb69f..5e27611 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,8 @@ python3 eval/plot.py res.csv - ~6h of runtime (waries depending on the hardware) # LMI in action -🌐 [**Similarity search in 214M protein structures (AlphaFold DB)**](https://alphafind.fi.muni.cz/search) +- 🌐 [**Similarity search in 1M images**](https://web.lmi.dyn.cloud.e-infra.cz/images) +- 🌐 [**Similarity search in protein structures**](https://staging.proteins.dyn.cloud.e-infra.cz/protein-search) # Publications @@ -82,15 +83,8 @@ python3 eval/plot.py res.csv - [**Mendeley data**](https://data.mendeley.com/datasets/8wp73zxr47/12) > T. Slanináková, M. Antol, J. Ol'ha, V. Kaňa, V. Dohnal, S. Ladra, M. A. Martinez-Prieto: [Reproducible experiments with Learned Metric Index Framework](https://www.sciencedirect.com/science/article/pii/S0306437923000911). 
Information Systems, Volume 118, September 2023, 102255 (2023) -**"LMI in a large (214M) protein database" (2024):** -- [**Web**](https://alphafind.fi.muni.cz/search) -- [**Repository**](https://github.com/Coda-Research-Group/AlphaFind) -- [**Data**](https://data.narodni-repozitar.cz/general/datasets/d35zf-1ja47) -> PROCHÁZKA, David, Terézia SLANINÁKOVÁ, Jaroslav OĽHA, Adrián ROŠINEC, Katarína GREŠOVÁ, Miriama JÁNOŠOVÁ, Jakub ČILLÍK, Jana PORUBSKÁ, Radka SVOBODOVÁ, Vlastislav DOHNAL a Matej ANTOL.: [AlphaFind: Discover structure similarity across the entire known proteome](https://www.biorxiv.org/content/10.1101/2024.02.15.580465v1). BioRxiv (pre-print version) - ## Team -🔎[**Complex data analysis research group**](https://disa.fi.muni.cz/complex-data-analysis) - [Terézia Slanináková](https://github.com/TerkaSlan), Masaryk University - [David Procházka](https://github.com/ProchazkaDavid), Masaryk University - [Jaroslav Oľha](https://github.com/JaroOlha), Masaryk University diff --git a/search/li/LearnedIndex.py b/search/li/LearnedIndex.py index 09fd33c..acd7e1e 100644 --- a/search/li/LearnedIndex.py +++ b/search/li/LearnedIndex.py @@ -106,7 +106,7 @@ def search( # Search in the `n_buckets` most similar buckets for bucket_order_idx in range(n_buckets): self.logger.debug( - f"Searching in bucket {bucket_order_idx + 1} out of {n_buckets}" + f"Searching in bucket {bucket_order_idx + 1} out of {n_buckets} | bucket order={bucket_order[:, bucket_order_idx, :]}" ) (dists, anns, t_all, t_seq_search, t_sort) = self._search_single_bucket( data_navigation=data_navigation, @@ -348,6 +348,8 @@ def _search_single_bucket( t_sort = 0.0 for path, g in tqdm(data_navigation.groupby(possible_bucket_paths)): + #self.logger.info(f'path: {path}, g: {g}') + bucket_obj_indexes = g.index relevant_query_idxs = filter_path_idxs(bucket_path, path) @@ -355,8 +357,9 @@ def _search_single_bucket( if bucket_obj_indexes.shape[0] != 0 and relevant_query_idxs.shape[0] != 0: queries_for_this_bucket = queries_search[relevant_query_idxs] data_in_this_bucket = data_search.loc[bucket_obj_indexes].to_numpy() - + #self.logger.info(f"queries_for_this_bucket: {queries_for_this_bucket}") s = time.time() + similarity, indices = faiss.knn( queries_for_this_bucket, data_in_this_bucket, diff --git a/search/li/LearnedIndexBuilder.py b/search/li/LearnedIndexBuilder.py index fa33d41..28d5a69 100644 --- a/search/li/LearnedIndexBuilder.py +++ b/search/li/LearnedIndexBuilder.py @@ -105,6 +105,65 @@ def build( time.time() - s, root_cluster_t + internal_cluster_t, ) + + def insert(self, data) -> Tuple[LearnedIndex, npt.NDArray[np.int64], int, float, float]: + """ + Inserts data into the index. + + Parameters + ---------- + data : pd.DataFrame + Data to build the index on. + config : BuildConfiguration + Configuration for the training. + + Returns + ------- + Tuple[npt.NDArray[np.int64], int, float, float] + An array of shape (data.shape[0], len(config.n_levels)) with predicted paths for each data point, + number of buckets, time it took to build the index, time it took to cluster the data. 
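+
+        Notes
+        -----
+        A sketch of the intended call pattern; it mirrors the new
+        ``search/test_insert.py`` and the unpacking follows the actual
+        ``return`` statements below (names and sizes are illustrative)::
+
+            builder = LearnedIndexBuilder(initial_data, build_config)
+            li, prediction, n_buckets, build_t, cluster_t = builder.build()
+            new_prediction, n_buckets, insert_t = builder.insert(new_data)
+            all_predictions = np.vstack((prediction, new_prediction))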
+ """ + s = time.time() + + n_levels = self.config.n_levels + + # Where should the training data be placed with respect to each level + data_prediction: npt.NDArray[np.int64] = np.full( + (data.shape[0], n_levels), fill_value=EMPTY_VALUE, dtype=np.int64 + ) + # Extend the data with the new data to be inserted + data = data.reset_index(drop=True) + data.index = data.index + 1 + self.data = pd.concat([self.data, data], ignore_index=True) + + self.logger.debug("Predicting the root model.") + data_prediction[:, 0] = self.root_model.predict(data_X_to_torch(data)) + + if n_levels == 1: + + return ( + #self._create_index(), + data_prediction, + len(self.bucket_paths), + time.time() - s, + ) + + self.logger.debug(f"Predicting {self.config.n_categories[:-1]} internal models.") + s_internal = time.time() + self._insert_internal_models( + data, + data_prediction, + self.config, + ) + self.logger.debug( + f"Trained {self.config.n_categories[:-1]} internal models in {time.time()-s_internal:.2f}s." + ) + + return ( + data_prediction, + len(self.bucket_paths), + time.time() - s + ) def _create_index(self) -> LearnedIndex: """Creates the index from the trained models.""" @@ -279,6 +338,60 @@ def _train_internal_models( return overall_cluster_t + def _insert_internal_models( + self, + data: pd.DataFrame, + data_prediction: npt.NDArray[np.int64], + config: BuildConfiguration, + ): + """ + Predicts the data to be inserted on internal models. + + ! The `data_prediction` array is modified in-place. + + Parameters + ---------- + data : pd.DataFrame + Data to train the models on. + data_prediction : npt.NDArray[np.int64] + Predicted paths for each data point. + config : BuildConfiguration + Configuration for the training. + + Returns + ------- + float + Time it took to cluster the data. + """ + assert ( + self.root_model is not None + ), "The root model is not trained, call `_train_root_model` first." + + for level in range(1, config.n_levels): + internal_node_paths = self._generate_internal_node_paths( + level, config.n_levels, config.n_categories + ) + self.logger.debug(f"Predicting level {level}.") + + for path in tqdm(internal_node_paths): + self.logger.debug(f"Predicting model on path {path}.") + + data_idxs = filter_path_idxs(data_prediction, path) + assert ( + data_idxs.shape[0] != 0 + ), "There are no data points associated with the given path." + + # +1 as the data is indexed from 1 + data_prediction_per_path = data.loc[data_idxs + 1] + + # The subset needs to be reindexed; otherwise, the object accesses are invalid. 
+ original_pd_indices = data_prediction_per_path.index.values + predictions = self.internal_models[path].predict(data_X_to_torch(data_prediction_per_path)) + + # original_pd_indices-1 as data is indexed from 1 + # level as we are predicting the next level but the indexing is 0-based + data_prediction[original_pd_indices - 1, level] = predictions + def _cluster( self, data: pd.DataFrame, diff --git a/search/test_insert.py b/search/test_insert.py new file mode 100644 index 0000000..7daa18d --- /dev/null +++ b/search/test_insert.py @@ -0,0 +1,96 @@ +import pytest +from search import download, prepare +from li.BuildConfiguration import BuildConfiguration +from li.clustering import algorithms +from li.LearnedIndexBuilder import LearnedIndexBuilder + +from sklearn import preprocessing +import pandas as pd + +def get_data(data_part, **config): + return np.array( + h5py.File( + os.path.join( + 'data', + config['dataset'], + config['size'], + data_part + ), + 'r' + )[config['emb']] + ) + + +# Fixture for common setup tasks +@pytest.fixture(scope="session") +def setup(): + # Perform setup tasks here + print("\nSetting up tests...") + + config = { + # get the smallest version of the LAION dataset + 'dataset': 'pca32v2', + 'emb': 'pca32', + 'size': '100K', + # n. of nearest neighbors + 'k': 10, + # normalize the data to be able to use K-Means + 'preprocess': True + } + prepare(config['dataset'], config['size']) + data = get_data("dataset.h5", **config) + queries = get_data("query.h5", **config) + if config['preprocess']: + data = preprocessing.normalize(data) + queries = preprocessing.normalize(queries) + yield data, queries + +# Example tests using the setup fixture +def test_one_level(setup): + data, queries = setup + n_categories = [10] + + build_config = BuildConfiguration( + # which clustering algorithm to use + algorithms['faiss_kmeans'], + # how many epochs to train for + 100, + # what model to use (see li/model.py + 'MLP', + # what learning rate to use + 0.01, + # how many categories at what level to build LMI for + # 10, 10 results in 100 buckets in total + n_categories + ) + dimensionality = data.shape[1] + sample = 1_000 + increment = 100 + + first_build_data = data.iloc[:sample] + builder = LearnedIndexBuilder(first_build_data, build_config) + li, data_prediction, n_buckets_in_index, build_t, cluster_t = builder.build() + assert data_prediction.shape == (sample, 1), "Data prediction shape is not correct: " + str(data_prediction.shape) + assert builder.data.shape == (sample, dimensionality), "Data shape is not correct: " + str(builder.data.shape) + + insert_data = data.iloc[sample:sample+increment] + data_prediction_2, n_buckets_in_index_2, insert_t = builder.insert(insert_data) + assert data_prediction_2.shape == (increment, 1), "Data prediction shape is not correct: " + str(data_prediction.shape) + # 32 + assert builder.data.shape == (sample+increment, dimensionality+len(n_categories)), "Data shape is not correct: " + str(builder.data.shape) + + data_prediction_all = np.vstack((data_prediction, data_prediction_2)) + + n_queries=5 + k=10 + dists, nns, measured_time = li.search( + data_navigation=builder.data, + queries_navigation=queries[:n_queries], + data_search=builder.data[[col for col in builder.data.columns if type(col) is int]], + queries_search=queries[:n_queries], + data_prediction=data_prediction_all, + n_categories=n_categories, + n_buckets=1, + k=k, + ) + assert dists.shape == (n_queries, k), "Dists shape is not correct: " + str(dists.shape) \ No newline at end of file
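The new test module above calls `os.path.join`, `h5py.File`, `np.array`, and `np.vstack` but only declares `pytest`, `sklearn`, `pandas`, and the `li`/`search` imports; a minimal sketch of the imports it presumably also needs (placement at the top of `search/test_insert.py`, alongside the existing imports, is an assumption):

```python
# Imports used by test_insert.py but not declared in the diff.
import os

import h5py
import numpy as np
```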