From e88a0c25fa57175ec6e700c09fd1335562090039 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Mon, 18 Dec 2023 16:38:23 +0000 Subject: [PATCH 01/37] added tfidf added tfidf encoder, edited hybrid layer to let user choose sparse encoder type --- docs/examples/hybrid-layer.ipynb | 133 ++++++++++++++++++++------- poetry.lock | 102 +++++++++++++++++++- pyproject.toml | 1 + semantic_router/encoders/__init__.py | 3 +- semantic_router/encoders/tfidf.py | 33 +++++++ semantic_router/hybrid_layer.py | 33 ++++--- 6 files changed, 258 insertions(+), 47 deletions(-) create mode 100644 semantic_router/encoders/tfidf.py diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 8b1da5ae..89965b4e 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -30,11 +30,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "!pip install -qU semantic-router==0.0.6" + "# !pip install -qU semantic-router==0.0.6" ] }, { @@ -46,18 +46,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [ { - "ename": "ImportError", - "evalue": "cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msemantic_router\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mschema\u001b[39;00m \u001b[39mimport\u001b[39;00m Route\n\u001b[1;32m 3\u001b[0m politics \u001b[39m=\u001b[39m Route(\n\u001b[1;32m 4\u001b[0m name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mpolitics\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 5\u001b[0m utterances\u001b[39m=\u001b[39m[\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 12\u001b[0m ],\n\u001b[1;32m 13\u001b[0m )\n", - "\u001b[0;31mImportError\u001b[0m: cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)" + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -86,21 +83,10 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "chitchat = Route(\n", - " name=\"chitchat\",\n", - " utterances=[\n", - " \"how's the weather today?\",\n", - " \"how are things going?\",\n", - " \"lovely weather today\",\n", - " \"the weather is horrendous\",\n", - " \"let's go to the chippy\",\n", - " ],\n", - ")\n", - "\n", "chitchat = Route(\n", " name=\"chitchat\",\n", " utterances=[\n", @@ -124,19 +110,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import os\n", - "from semantic_router.encoders import CohereEncoder\n", + "from semantic_router.encoders import CohereEncoder, BM25Encoder, TfidfEncoder\n", "from getpass import getpass\n", "\n", "os.environ[\"COHERE_API_KEY\"] = os.environ[\"COHERE_API_KEY\"] or getpass(\n", " \"Enter Cohere API Key: \"\n", ")\n", "\n", - "encoder = CohereEncoder()" + "dense_encoder = CohereEncoder()\n", + "# sparse_encoder = BM25Encoder()\n", + "sparse_encoder = TfidfEncoder()" ] }, { @@ -148,33 +136,110 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2/2 [00:00<00:00, 2.58it/s]\n" + ] + } + ], "source": [ "from semantic_router.hybrid_layer import HybridRouteLayer\n", "\n", - "dl = HybridRouteLayer(encoder=encoder, routes=routes)" + "dl = HybridRouteLayer(dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'politics'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'chitchat'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"how's the weather today?\")" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "religion = Route(\n", + " name=\"religion\",\n", + " utterances=[\n", + " \"what do you know about Buddhism?\",\n", + " \"tell me about Christianity\",\n", + " \"explain the principles of Hinduism\",\n", + " \"describe the teachings of Islam\",\n", + " \"what are the main beliefs of Judaism?\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "dl.add(religion)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'religion'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl(\"what do you think of Hinduism?\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -199,7 +264,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/poetry.lock b/poetry.lock index f5d58647..b6fff470 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1916,6 +1916,95 @@ files = [ {file = "ruff-0.1.8.tar.gz", hash = "sha256:f7ee467677467526cfe135eab86a40a0e8db43117936ac4f9b469ce9cdb3fb62"}, ] +[[package]] +name = "scikit-learn" +version = "1.3.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.8" +files = [ + {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"}, + {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"}, + {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"}, + {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"}, + {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"}, + {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"}, + {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"}, + {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"}, +] + +[package.dependencies] +joblib = ">=1.1.1" +numpy = ">=1.17.3,<2.0" +scipy = ">=1.5.0" +threadpoolctl = ">=2.0.0" + +[package.extras] +benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.10.1)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"] +tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"] + +[[package]] +name = "scipy" +version = "1.11.4" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"}, + {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"}, + {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9999c008ccf00e8fbcce1236f85ade5c569d13144f77a1946bef8863e8f6eb4"}, + {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:933baf588daa8dc9a92c20a0be32f56d43faf3d1a60ab11b3f08c356430f6e56"}, + {file = "scipy-1.11.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8fce70f39076a5aa62e92e69a7f62349f9574d8405c0a5de6ed3ef72de07f446"}, + {file = "scipy-1.11.4-cp310-cp310-win_amd64.whl", hash = "sha256:6550466fbeec7453d7465e74d4f4b19f905642c89a7525571ee91dd7adabb5a3"}, + {file = "scipy-1.11.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f313b39a7e94f296025e3cffc2c567618174c0b1dde173960cf23808f9fae4be"}, + {file = "scipy-1.11.4-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1b7c3dca977f30a739e0409fb001056484661cb2541a01aba0bb0029f7b68db8"}, + {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00150c5eae7b610c32589dda259eacc7c4f1665aedf25d921907f4d08a951b1c"}, + {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:530f9ad26440e85766509dbf78edcfe13ffd0ab7fec2560ee5c36ff74d6269ff"}, + {file = "scipy-1.11.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5e347b14fe01003d3b78e196e84bd3f48ffe4c8a7b8a1afbcb8f5505cb710993"}, + {file = "scipy-1.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:acf8ed278cc03f5aff035e69cb511741e0418681d25fbbb86ca65429c4f4d9cd"}, + {file = "scipy-1.11.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:028eccd22e654b3ea01ee63705681ee79933652b2d8f873e7949898dda6d11b6"}, + {file = "scipy-1.11.4-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c6ff6ef9cc27f9b3db93a6f8b38f97387e6e0591600369a297a50a8e96e835d"}, + {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b030c6674b9230d37c5c60ab456e2cf12f6784596d15ce8da9365e70896effc4"}, + {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad669df80528aeca5f557712102538f4f37e503f0c5b9541655016dd0932ca79"}, + {file = "scipy-1.11.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce7fff2e23ab2cc81ff452a9444c215c28e6305f396b2ba88343a567feec9660"}, + {file = "scipy-1.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:36750b7733d960d7994888f0d148d31ea3017ac15eef664194b4ef68d36a4a97"}, + {file = "scipy-1.11.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e619aba2df228a9b34718efb023966da781e89dd3d21637b27f2e54db0410d7"}, + {file = "scipy-1.11.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:f3cd9e7b3c2c1ec26364856f9fbe78695fe631150f94cd1c22228456404cf1ec"}, + {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d10e45a6c50211fe256da61a11c34927c68f277e03138777bdebedd933712fea"}, + {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91af76a68eeae0064887a48e25c4e616fa519fa0d38602eda7e0f97d65d57937"}, + {file = "scipy-1.11.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6df1468153a31cf55ed5ed39647279beb9cfb5d3f84369453b49e4b8502394fd"}, + {file = "scipy-1.11.4-cp39-cp39-win_amd64.whl", hash = "sha256:ee410e6de8f88fd5cf6eadd73c135020bfbbbdfcd0f6162c36a7638a1ea8cc65"}, + {file = "scipy-1.11.4.tar.gz", hash = "sha256:90a2b78e7f5733b9de748f589f09225013685f9b218275257f8a8168ededaeaa"}, +] + +[package.dependencies] +numpy = ">=1.21.6,<1.28.0" + +[package.extras] +dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] +test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "six" version = "1.16.0" @@ -1957,6 +2046,17 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "threadpoolctl" +version = "3.2.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"}, + {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, +] + [[package]] name = "tokenize-rt" version = "5.2.0" @@ -2203,4 +2303,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "3300c77d6b6fab3faca403e2f3064a23e9f5ddcd34d63cad42d39b94b1ae5c2b" +content-hash = "7e705f5c5f2a8bba630031c0ff6752972e7cddc8ec95f3fb05b5be2ad7962268" diff --git a/pyproject.toml b/pyproject.toml index 030a8f72..9f42964b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" +scikit-learn = "^1.3.2" [tool.poetry.group.dev.dependencies] diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 30ad624a..2769b31d 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -2,5 +2,6 @@ from .bm25 import BM25Encoder from .cohere import CohereEncoder from .openai import OpenAIEncoder +from .tfidf import TfidfEncoder -__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"] +__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder", "TfidfEncoder"] diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py new file mode 100644 index 00000000..41487a06 --- /dev/null +++ b/semantic_router/encoders/tfidf.py @@ -0,0 +1,33 @@ +from typing import Any +from sklearn.feature_extraction.text import TfidfVectorizer +from semantic_router.encoders import BaseEncoder +from semantic_router.schema import Route + +class TfidfEncoder(BaseEncoder): + vectorizer: TfidfVectorizer | None = None + + def __init__(self, name: str = "tfidf"): + super().__init__(name=name) + self.vectorizer = TfidfVectorizer() + + def __call__(self, docs: list[str]) -> list[list[float]]: + if self.vectorizer is None: + raise ValueError("Vectorizer is not initialized.") + if len(docs) == 0: + raise ValueError("No documents to encode.") + + embeds = self.vectorizer.transform(docs).toarray() + return embeds.tolist() + + def fit(self, routes: list[Route]): + if self.vectorizer is None: + raise ValueError("Vectorizer is not initialized.") + docs = self._get_all_utterances(routes) + self.vectorizer.fit(docs) + + def _get_all_utterances(self, routes: list[Route]) -> list[str]: + utterances = [] + for route in routes: + for utterance in route.utterances: + utterances.append(utterance) + return utterances \ No newline at end of file diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index dec6336e..a68472d3 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -19,19 +19,22 @@ class HybridRouteLayer: score_threshold = 0.82 def __init__( - self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 + self, dense_encoder: BaseEncoder, sparse_encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 ): - self.encoder = encoder - self.sparse_encoder = BM25Encoder() + self.dense_encoder = dense_encoder + self.sparse_encoder = sparse_encoder self.alpha = alpha + self.routes = routes # decide on default threshold based on encoder - if isinstance(encoder, OpenAIEncoder): + if isinstance(dense_encoder, OpenAIEncoder): self.score_threshold = 0.82 - elif isinstance(encoder, CohereEncoder): + elif isinstance(dense_encoder, CohereEncoder): self.score_threshold = 0.3 else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now + if self.sparse_encoder.name == 'tfidf': + self.sparse_encoder.fit(routes) if routes: # initialize index now for route in tqdm(routes): @@ -47,15 +50,18 @@ def __call__(self, text: str) -> str | None: return None def add(self, route: Route): + if self.sparse_encoder.name == 'tfidf': + self.sparse_encoder.fit(self.routes + [route]) + self.sparse_index = None + for r in self.routes: + self.calculate_sparse_embeds(r) + self.routes.append(route) self._add_route(route=route) def _add_route(self, route: Route): # create embeddings - dense_embeds = np.array(self.encoder(route.utterances)) # * self.alpha - sparse_embeds = np.array( - self.sparse_encoder(route.utterances) - ) # * (1 - self.alpha) - + dense_embeds = np.array(self.dense_encoder(route.utterances)) # * self.alpha + self.compute_and_store_sparse_embeddings(route) # create route array if self.categories is None: self.categories = np.array([route.name] * len(route.utterances)) @@ -71,6 +77,11 @@ def _add_route(self, route: Route): self.index = dense_embeds else: self.index = np.concatenate([self.index, dense_embeds]) + + def compute_and_store_sparse_embeddings(self, route: Route): + sparse_embeds = np.array( + self.sparse_encoder(route.utterances) + ) # * (1 - self.alpha) # create sparse utterance array if self.sparse_index is None: self.sparse_index = sparse_embeds @@ -82,7 +93,7 @@ def _query(self, text: str, top_k: int = 5): retrieve the top_k most similar records. """ # create dense query vector - xq_d = np.array(self.encoder([text])) + xq_d = np.array(self.dense_encoder([text])) xq_d = np.squeeze(xq_d) # Reduce to 1d array. # create sparse query vector xq_s = np.array(self.sparse_encoder([text])) From 4394759e4b96969544a723a14f169e1d8e82e8f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Mon, 18 Dec 2023 19:08:40 +0000 Subject: [PATCH 02/37] update sparse_encoder if statements to type update sparse_encoder if statements to type rather than name --- semantic_router/hybrid_layer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index a68472d3..33a3269f 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -7,6 +7,7 @@ BM25Encoder, CohereEncoder, OpenAIEncoder, + TfidfEncoder ) from semantic_router.schema import Route from semantic_router.utils.logger import logger @@ -33,7 +34,7 @@ def __init__( else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now - if self.sparse_encoder.name == 'tfidf': + if isinstance(sparse_encoder, TfidfEncoder): self.sparse_encoder.fit(routes) if routes: # initialize index now @@ -50,11 +51,11 @@ def __call__(self, text: str) -> str | None: return None def add(self, route: Route): - if self.sparse_encoder.name == 'tfidf': + if isinstance(self.sparse_encoder, TfidfEncoder): self.sparse_encoder.fit(self.routes + [route]) self.sparse_index = None for r in self.routes: - self.calculate_sparse_embeds(r) + self.compute_and_store_sparse_embeddings(r) self.routes.append(route) self._add_route(route=route) From 4d3ba4d387ff56c7ea9951f1754f593e6a60ea0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Tue, 19 Dec 2023 10:36:51 +0000 Subject: [PATCH 03/37] fixed tests --- docs/examples/hybrid-layer.ipynb | 6 ++++-- poetry.lock | 2 +- semantic_router/encoders/__init__.py | 8 +++++++- semantic_router/encoders/tfidf.py | 7 ++++--- semantic_router/hybrid_layer.py | 8 ++++++-- tests/unit/test_hybrid_layer.py | 20 +++++++++++++++++--- 6 files changed, 39 insertions(+), 12 deletions(-) diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 89965b4e..1257e0a1 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -143,14 +143,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:00<00:00, 2.58it/s]\n" + "100%|██████████| 2/2 [00:00<00:00, 4.22it/s]\n" ] } ], "source": [ "from semantic_router.hybrid_layer import HybridRouteLayer\n", "\n", - "dl = HybridRouteLayer(dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes)" + "dl = HybridRouteLayer(\n", + " dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes\n", + ")" ] }, { diff --git a/poetry.lock b/poetry.lock index b6fff470..e6d799b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2302,5 +2302,5 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" -python-versions = "^3.10" +python-versions = "^3.9" content-hash = "7e705f5c5f2a8bba630031c0ff6752972e7cddc8ec95f3fb05b5be2ad7962268" diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 2769b31d..c2bde1e5 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -4,4 +4,10 @@ from .openai import OpenAIEncoder from .tfidf import TfidfEncoder -__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder", "TfidfEncoder"] +__all__ = [ + "BaseEncoder", + "CohereEncoder", + "OpenAIEncoder", + "BM25Encoder", + "TfidfEncoder", +] diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 41487a06..ea7a7726 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -3,6 +3,7 @@ from semantic_router.encoders import BaseEncoder from semantic_router.schema import Route + class TfidfEncoder(BaseEncoder): vectorizer: TfidfVectorizer | None = None @@ -28,6 +29,6 @@ def fit(self, routes: list[Route]): def _get_all_utterances(self, routes: list[Route]) -> list[str]: utterances = [] for route in routes: - for utterance in route.utterances: - utterances.append(utterance) - return utterances \ No newline at end of file + for utterance in route.utterances: + utterances.append(utterance) + return utterances diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 33a3269f..8dfedb14 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -7,7 +7,7 @@ BM25Encoder, CohereEncoder, OpenAIEncoder, - TfidfEncoder + TfidfEncoder, ) from semantic_router.schema import Route from semantic_router.utils.logger import logger @@ -20,7 +20,11 @@ class HybridRouteLayer: score_threshold = 0.82 def __init__( - self, dense_encoder: BaseEncoder, sparse_encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 + self, + dense_encoder: BaseEncoder, + sparse_encoder: BaseEncoder, + routes: list[Route] = [], + alpha: float = 0.3, ): self.dense_encoder = dense_encoder self.sparse_encoder = sparse_encoder diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 94720cd8..a9d35ea5 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -1,6 +1,12 @@ import pytest -from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder +from semantic_router.encoders import ( + BaseEncoder, + CohereEncoder, + OpenAIEncoder, + TfidfEncoder, + BM25Encoder, +) from semantic_router.hybrid_layer import HybridRouteLayer from semantic_router.schema import Route @@ -34,6 +40,12 @@ def openai_encoder(mocker): return OpenAIEncoder(name="test-openai-encoder", openai_api_key="test_api_key") +@pytest.fixture +def bm25_encoder(mocker): + mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call) + return BM25Encoder(name="test-bm25-encoder") + + @pytest.fixture def routes(): return [ @@ -73,8 +85,10 @@ def test_add_multiple_routes(self, openai_encoder, routes): assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 - def test_query_and_classification(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + def test_query_and_classification(self, openai_encoder, bm25_encoder, routes): + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes + ) query_result = route_layer("Hello") assert query_result in ["Route 1", "Route 2"] From 724480cb4eb175dd241f75502c7efefc26e8d594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Tue, 19 Dec 2023 13:43:07 +0000 Subject: [PATCH 04/37] linter used, tests changed to include sparse and dense encoders --- semantic_router/encoders/tfidf.py | 1 - semantic_router/hybrid_layer.py | 1 - tests/unit/test_hybrid_layer.py | 44 +++++++++++++++++-------------- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index ea7a7726..226e9dd0 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,4 +1,3 @@ -from typing import Any from sklearn.feature_extraction.text import TfidfVectorizer from semantic_router.encoders import BaseEncoder from semantic_router.schema import Route diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 8dfedb14..3993ca45 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -4,7 +4,6 @@ from semantic_router.encoders import ( BaseEncoder, - BM25Encoder, CohereEncoder, OpenAIEncoder, TfidfEncoder, diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index a9d35ea5..0a5dba6c 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -4,8 +4,8 @@ BaseEncoder, CohereEncoder, OpenAIEncoder, - TfidfEncoder, BM25Encoder, + TfidfEncoder, ) from semantic_router.hybrid_layer import HybridRouteLayer from semantic_router.schema import Route @@ -45,6 +45,10 @@ def bm25_encoder(mocker): mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call) return BM25Encoder(name="test-bm25-encoder") +@pytest.fixture +def tfidf_encoder(mocker): + mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call) + return TfidfEncoder(name="test-tfidf-encoder") @pytest.fixture def routes(): @@ -55,30 +59,30 @@ def routes(): class TestHybridRouteLayer: - def test_initialization(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + def test_initialization(self, openai_encoder, bm25_encoder, routes): + route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) assert route_layer.index is not None and route_layer.categories is not None assert route_layer.score_threshold == 0.82 assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 - def test_initialization_different_encoders(self, cohere_encoder, openai_encoder): - route_layer_cohere = HybridRouteLayer(encoder=cohere_encoder) + def test_initialization_different_encoders(self, cohere_encoder, openai_encoder, bm25_encoder): + route_layer_cohere = HybridRouteLayer(dense_encoder=cohere_encoder, sparse_encoder=bm25_encoder) assert route_layer_cohere.score_threshold == 0.3 - route_layer_openai = HybridRouteLayer(encoder=openai_encoder) + route_layer_openai = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) assert route_layer_openai.score_threshold == 0.82 - def test_add_route(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + def test_add_route(self, openai_encoder, bm25_encoder): + route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) route = Route(name="Route 3", utterances=["Yes", "No"]) route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 2 assert len(set(route_layer.categories)) == 1 - def test_add_multiple_routes(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder) + def test_add_multiple_routes(self, openai_encoder, bm25_encoder, routes): + route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) for route in routes: route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None @@ -92,12 +96,12 @@ def test_query_and_classification(self, openai_encoder, bm25_encoder, routes): query_result = route_layer("Hello") assert query_result in ["Route 1", "Route 2"] - def test_query_with_no_index(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + def test_query_with_no_index(self, openai_encoder, bm25_encoder): + route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) assert route_layer("Anything") is None - def test_semantic_classify(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + def test_semantic_classify(self, openai_encoder, bm25_encoder, routes): + route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -107,8 +111,8 @@ def test_semantic_classify(self, openai_encoder, routes): assert classification == "Route 1" assert score == [0.9] - def test_semantic_classify_multiple_routes(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + def test_semantic_classify_multiple_routes(self, openai_encoder, bm25_encoder, routes): + route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -119,13 +123,13 @@ def test_semantic_classify_multiple_routes(self, openai_encoder, routes): assert classification == "Route 1" assert score == [0.9, 0.8] - def test_pass_threshold(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + def test_pass_threshold(self, openai_encoder, bm25_encoder): + route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) assert not route_layer._pass_threshold([], 0.5) assert route_layer._pass_threshold([0.6, 0.7], 0.5) - def test_failover_score_threshold(self, base_encoder): - route_layer = HybridRouteLayer(encoder=base_encoder) + def test_failover_score_threshold(self, base_encoder, bm25_encoder): + route_layer = HybridRouteLayer(dense_encoder=base_encoder, sparse_encoder=bm25_encoder) assert route_layer.score_threshold == 0.82 From 99a0983fe9f705b7202f93704714f2697c61790a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 20 Dec 2023 20:33:20 +0000 Subject: [PATCH 05/37] removed sci-kit learn, added tests --- coverage.xml | 113 +++++++++++++++++++++--------- docs/examples/hybrid-layer.ipynb | 2 +- poetry.lock | 104 +-------------------------- pyproject.toml | 1 - semantic_router/encoders/tfidf.py | 59 ++++++++++++---- tests/unit/encoders/test_tfidf.py | 61 ++++++++++++++++ tests/unit/test_hybrid_layer.py | 50 +++++++++---- 7 files changed, 224 insertions(+), 166 deletions(-) create mode 100644 tests/unit/encoders/test_tfidf.py diff --git a/coverage.xml b/coverage.xml index 9af9ebee..9c68d3d9 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,12 +1,12 @@ - - + + - /Users/jakit/customers/aurelio/semantic-router/semantic_router + /Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/semantic_router - + @@ -16,7 +16,7 @@ - + @@ -31,84 +31,95 @@ - - - + - + - + + - - - + + + + + + - + - + + - + + - + + - - + - - - - + - - + + - + - - - + + + + + - - + + + + + + + + + + @@ -250,7 +261,7 @@ - + @@ -259,7 +270,8 @@ - + + @@ -274,7 +286,7 @@ - + @@ -298,7 +310,7 @@ - + @@ -387,6 +399,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 1257e0a1..9c0a02fc 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -143,7 +143,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:00<00:00, 4.22it/s]\n" + "100%|██████████| 2/2 [00:00<00:00, 3.41it/s]\n" ] } ], diff --git a/poetry.lock b/poetry.lock index 5d9fc23d..0a9be4e1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1935,95 +1935,6 @@ files = [ {file = "ruff-0.1.8.tar.gz", hash = "sha256:f7ee467677467526cfe135eab86a40a0e8db43117936ac4f9b469ce9cdb3fb62"}, ] -[[package]] -name = "scikit-learn" -version = "1.3.2" -description = "A set of python modules for machine learning and data mining" -optional = false -python-versions = ">=3.8" -files = [ - {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"}, - {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"}, - {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"}, - {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"}, - {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"}, - {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"}, - {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"}, - {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"}, - {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"}, - {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"}, - {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"}, - {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"}, - {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"}, - {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"}, - {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"}, - {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"}, - {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"}, - {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"}, - {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"}, - {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"}, - {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"}, - {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"}, - {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"}, - {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"}, - {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"}, - {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"}, -] - -[package.dependencies] -joblib = ">=1.1.1" -numpy = ">=1.17.3,<2.0" -scipy = ">=1.5.0" -threadpoolctl = ">=2.0.0" - -[package.extras] -benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"] -docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.10.1)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] -examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"] -tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"] - -[[package]] -name = "scipy" -version = "1.11.4" -description = "Fundamental algorithms for scientific computing in Python" -optional = false -python-versions = ">=3.9" -files = [ - {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"}, - {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"}, - {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9999c008ccf00e8fbcce1236f85ade5c569d13144f77a1946bef8863e8f6eb4"}, - {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:933baf588daa8dc9a92c20a0be32f56d43faf3d1a60ab11b3f08c356430f6e56"}, - {file = "scipy-1.11.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8fce70f39076a5aa62e92e69a7f62349f9574d8405c0a5de6ed3ef72de07f446"}, - {file = "scipy-1.11.4-cp310-cp310-win_amd64.whl", hash = "sha256:6550466fbeec7453d7465e74d4f4b19f905642c89a7525571ee91dd7adabb5a3"}, - {file = "scipy-1.11.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f313b39a7e94f296025e3cffc2c567618174c0b1dde173960cf23808f9fae4be"}, - {file = "scipy-1.11.4-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1b7c3dca977f30a739e0409fb001056484661cb2541a01aba0bb0029f7b68db8"}, - {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00150c5eae7b610c32589dda259eacc7c4f1665aedf25d921907f4d08a951b1c"}, - {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:530f9ad26440e85766509dbf78edcfe13ffd0ab7fec2560ee5c36ff74d6269ff"}, - {file = "scipy-1.11.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5e347b14fe01003d3b78e196e84bd3f48ffe4c8a7b8a1afbcb8f5505cb710993"}, - {file = "scipy-1.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:acf8ed278cc03f5aff035e69cb511741e0418681d25fbbb86ca65429c4f4d9cd"}, - {file = "scipy-1.11.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:028eccd22e654b3ea01ee63705681ee79933652b2d8f873e7949898dda6d11b6"}, - {file = "scipy-1.11.4-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c6ff6ef9cc27f9b3db93a6f8b38f97387e6e0591600369a297a50a8e96e835d"}, - {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b030c6674b9230d37c5c60ab456e2cf12f6784596d15ce8da9365e70896effc4"}, - {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad669df80528aeca5f557712102538f4f37e503f0c5b9541655016dd0932ca79"}, - {file = "scipy-1.11.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce7fff2e23ab2cc81ff452a9444c215c28e6305f396b2ba88343a567feec9660"}, - {file = "scipy-1.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:36750b7733d960d7994888f0d148d31ea3017ac15eef664194b4ef68d36a4a97"}, - {file = "scipy-1.11.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e619aba2df228a9b34718efb023966da781e89dd3d21637b27f2e54db0410d7"}, - {file = "scipy-1.11.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:f3cd9e7b3c2c1ec26364856f9fbe78695fe631150f94cd1c22228456404cf1ec"}, - {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d10e45a6c50211fe256da61a11c34927c68f277e03138777bdebedd933712fea"}, - {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91af76a68eeae0064887a48e25c4e616fa519fa0d38602eda7e0f97d65d57937"}, - {file = "scipy-1.11.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6df1468153a31cf55ed5ed39647279beb9cfb5d3f84369453b49e4b8502394fd"}, - {file = "scipy-1.11.4-cp39-cp39-win_amd64.whl", hash = "sha256:ee410e6de8f88fd5cf6eadd73c135020bfbbbdfcd0f6162c36a7638a1ea8cc65"}, - {file = "scipy-1.11.4.tar.gz", hash = "sha256:90a2b78e7f5733b9de748f589f09225013685f9b218275257f8a8168ededaeaa"}, -] - -[package.dependencies] -numpy = ">=1.21.6,<1.28.0" - -[package.extras] -dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] -doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] -test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] - [[package]] name = "six" version = "1.16.0" @@ -2065,17 +1976,6 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] -[[package]] -name = "threadpoolctl" -version = "3.2.0" -description = "threadpoolctl" -optional = false -python-versions = ">=3.8" -files = [ - {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"}, - {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, -] - [[package]] name = "tokenize-rt" version = "5.2.0" @@ -2322,4 +2222,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "7e705f5c5f2a8bba630031c0ff6752972e7cddc8ec95f3fb05b5be2ad7962268" \ No newline at end of file +content-hash = "f2735c243faa3d788c0f6268d6cb550648ed0d1fffec27a084344dafa4590a80" diff --git a/pyproject.toml b/pyproject.toml index c491ed1c..e45e5f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" -scikit-learn = "^1.3.2" [tool.poetry.group.dev.dependencies] diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 226e9dd0..e7c5782f 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,33 +1,62 @@ -from sklearn.feature_extraction.text import TfidfVectorizer +import numpy as np +from collections import Counter from semantic_router.encoders import BaseEncoder from semantic_router.schema import Route +from numpy.linalg import norm class TfidfEncoder(BaseEncoder): - vectorizer: TfidfVectorizer | None = None + idf: dict | None = None + word_index: dict | None = None def __init__(self, name: str = "tfidf"): super().__init__(name=name) - self.vectorizer = TfidfVectorizer() + self.word_index = None + self.idf = None def __call__(self, docs: list[str]) -> list[list[float]]: - if self.vectorizer is None: + if self.word_index is None or self.idf is None: raise ValueError("Vectorizer is not initialized.") if len(docs) == 0: raise ValueError("No documents to encode.") - embeds = self.vectorizer.transform(docs).toarray() - return embeds.tolist() + tf = self._compute_tf(docs) + tfidf = tf * self.idf + return tfidf.tolist() def fit(self, routes: list[Route]): - if self.vectorizer is None: - raise ValueError("Vectorizer is not initialized.") - docs = self._get_all_utterances(routes) - self.vectorizer.fit(docs) - - def _get_all_utterances(self, routes: list[Route]) -> list[str]: - utterances = [] + docs = [] for route in routes: for utterance in route.utterances: - utterances.append(utterance) - return utterances + docs.append(utterance) + self.word_index = self._build_word_index(docs) + self.idf = self._compute_idf(docs) + + def _build_word_index(self, docs: list[str]) -> dict: + words = set() + for doc in docs: + for word in doc.split(): + words.add(word) + word_index = {word: i for i, word in enumerate(words)} + return word_index + + def _compute_tf(self, docs: list[str]) -> np.ndarray: + tf = np.zeros((len(docs), len(self.word_index))) + for i, doc in enumerate(docs): + word_counts = Counter(doc.split()) + for word, count in word_counts.items(): + if word in self.word_index: + tf[i, self.word_index[word]] = count + # L2 normalization + tf = tf / norm(tf, axis=1, keepdims=True) + return tf + + def _compute_idf(self, docs: list[str]) -> np.ndarray: + idf = np.zeros(len(self.word_index)) + for doc in docs: + words = set(doc.split()) + for word in words: + if word in self.word_index: + idf[self.word_index[word]] += 1 + idf = np.log(len(docs) / (idf + 1)) + return idf diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py new file mode 100644 index 00000000..93a96639 --- /dev/null +++ b/tests/unit/encoders/test_tfidf.py @@ -0,0 +1,61 @@ +import pytest +from semantic_router.encoders import TfidfEncoder +from semantic_router.schema import Route + + +@pytest.fixture +def tfidf_encoder(): + return TfidfEncoder() + + +class TestTfidfEncoder: + def test_initialization(self, tfidf_encoder): + assert tfidf_encoder.word_index is None + assert tfidf_encoder.idf is None + + def test_fit(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + assert tfidf_encoder.word_index is not None + assert tfidf_encoder.idf is not None + + def test_call_method(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + result = tfidf_encoder(["test"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + + def test_call_method_no_docs(self, tfidf_encoder): + with pytest.raises(ValueError): + tfidf_encoder([]) + + def test_call_method_no_word(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + result = tfidf_encoder(["doc with fake word gta5jabcxyz"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + + def test_call_method_with_uninitialized_model(self, tfidf_encoder): + with pytest.raises(ValueError): + tfidf_encoder(["test"]) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 0a5dba6c..ee7d8f6b 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -45,11 +45,13 @@ def bm25_encoder(mocker): mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call) return BM25Encoder(name="test-bm25-encoder") + @pytest.fixture def tfidf_encoder(mocker): mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call) return TfidfEncoder(name="test-tfidf-encoder") + @pytest.fixture def routes(): return [ @@ -60,21 +62,31 @@ def routes(): class TestHybridRouteLayer: def test_initialization(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes + ) assert route_layer.index is not None and route_layer.categories is not None assert route_layer.score_threshold == 0.82 assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 - def test_initialization_different_encoders(self, cohere_encoder, openai_encoder, bm25_encoder): - route_layer_cohere = HybridRouteLayer(dense_encoder=cohere_encoder, sparse_encoder=bm25_encoder) + def test_initialization_different_encoders( + self, cohere_encoder, openai_encoder, bm25_encoder + ): + route_layer_cohere = HybridRouteLayer( + dense_encoder=cohere_encoder, sparse_encoder=bm25_encoder + ) assert route_layer_cohere.score_threshold == 0.3 - route_layer_openai = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer_openai = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) assert route_layer_openai.score_threshold == 0.82 def test_add_route(self, openai_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) route = Route(name="Route 3", utterances=["Yes", "No"]) route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None @@ -82,7 +94,9 @@ def test_add_route(self, openai_encoder, bm25_encoder): assert len(set(route_layer.categories)) == 1 def test_add_multiple_routes(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) for route in routes: route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None @@ -97,11 +111,15 @@ def test_query_and_classification(self, openai_encoder, bm25_encoder, routes): assert query_result in ["Route 1", "Route 2"] def test_query_with_no_index(self, openai_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) assert route_layer("Anything") is None def test_semantic_classify(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -111,8 +129,12 @@ def test_semantic_classify(self, openai_encoder, bm25_encoder, routes): assert classification == "Route 1" assert score == [0.9] - def test_semantic_classify_multiple_routes(self, openai_encoder, bm25_encoder, routes): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes) + def test_semantic_classify_multiple_routes( + self, openai_encoder, bm25_encoder, routes + ): + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -124,12 +146,16 @@ def test_semantic_classify_multiple_routes(self, openai_encoder, bm25_encoder, r assert score == [0.9, 0.8] def test_pass_threshold(self, openai_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=openai_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=openai_encoder, sparse_encoder=bm25_encoder + ) assert not route_layer._pass_threshold([], 0.5) assert route_layer._pass_threshold([0.6, 0.7], 0.5) def test_failover_score_threshold(self, base_encoder, bm25_encoder): - route_layer = HybridRouteLayer(dense_encoder=base_encoder, sparse_encoder=bm25_encoder) + route_layer = HybridRouteLayer( + dense_encoder=base_encoder, sparse_encoder=bm25_encoder + ) assert route_layer.score_threshold == 0.82 From 2ae56f5b7218385bf85dbe61a929de463be16cd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 20 Dec 2023 20:56:33 +0000 Subject: [PATCH 06/37] added text preprocessing --- semantic_router/encoders/tfidf.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index e7c5782f..d6be6da5 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -3,6 +3,7 @@ from semantic_router.encoders import BaseEncoder from semantic_router.schema import Route from numpy.linalg import norm +import string class TfidfEncoder(BaseEncoder): @@ -20,6 +21,7 @@ def __call__(self, docs: list[str]) -> list[list[float]]: if len(docs) == 0: raise ValueError("No documents to encode.") + docs = [self._preprocess(doc) for doc in docs] tf = self._compute_tf(docs) tfidf = tf * self.idf return tfidf.tolist() @@ -27,8 +29,8 @@ def __call__(self, docs: list[str]) -> list[list[float]]: def fit(self, routes: list[Route]): docs = [] for route in routes: - for utterance in route.utterances: - docs.append(utterance) + for doc in route.utterances: + docs.append(self._preprocess(doc)) self.word_index = self._build_word_index(docs) self.idf = self._compute_idf(docs) @@ -60,3 +62,10 @@ def _compute_idf(self, docs: list[str]) -> np.ndarray: idf[self.word_index[word]] += 1 idf = np.log(len(docs) / (idf + 1)) return idf + + def _preprocess(self, doc: str) -> str: + lowercased_doc = doc.lower() + no_punctuation_doc = lowercased_doc.translate( + str.maketrans("", "", string.punctuation) + ) + return no_punctuation_doc From fb5da0c14a56c9c5de127539486156f4d68df95e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 20 Dec 2023 20:59:50 +0000 Subject: [PATCH 07/37] revert --- docs/examples/hybrid-layer.ipynb | 135 ++++++++----------------------- 1 file changed, 34 insertions(+), 101 deletions(-) diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 9c0a02fc..8b1da5ae 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -30,11 +30,11 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# !pip install -qU semantic-router==0.0.6" + "!pip install -qU semantic-router==0.0.6" ] }, { @@ -46,15 +46,18 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "ename": "ImportError", + "evalue": "cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msemantic_router\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mschema\u001b[39;00m \u001b[39mimport\u001b[39;00m Route\n\u001b[1;32m 3\u001b[0m politics \u001b[39m=\u001b[39m Route(\n\u001b[1;32m 4\u001b[0m name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mpolitics\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 5\u001b[0m utterances\u001b[39m=\u001b[39m[\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 12\u001b[0m ],\n\u001b[1;32m 13\u001b[0m )\n", + "\u001b[0;31mImportError\u001b[0m: cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)" ] } ], @@ -83,10 +86,21 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", "chitchat = Route(\n", " name=\"chitchat\",\n", " utterances=[\n", @@ -110,21 +124,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", - "from semantic_router.encoders import CohereEncoder, BM25Encoder, TfidfEncoder\n", + "from semantic_router.encoders import CohereEncoder\n", "from getpass import getpass\n", "\n", "os.environ[\"COHERE_API_KEY\"] = os.environ[\"COHERE_API_KEY\"] or getpass(\n", " \"Enter Cohere API Key: \"\n", ")\n", "\n", - "dense_encoder = CohereEncoder()\n", - "# sparse_encoder = BM25Encoder()\n", - "sparse_encoder = TfidfEncoder()" + "encoder = CohereEncoder()" ] }, { @@ -136,110 +148,31 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 2/2 [00:00<00:00, 3.41it/s]\n" - ] - } - ], + "outputs": [], "source": [ "from semantic_router.hybrid_layer import HybridRouteLayer\n", "\n", - "dl = HybridRouteLayer(\n", - " dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes\n", - ")" + "dl = HybridRouteLayer(encoder=encoder, routes=routes)" ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'politics'" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dl(\"don't you love politics?\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'chitchat'" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dl(\"how's the weather today?\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "religion = Route(\n", - " name=\"religion\",\n", - " utterances=[\n", - " \"what do you know about Buddhism?\",\n", - " \"tell me about Christianity\",\n", - " \"explain the principles of Hinduism\",\n", - " \"describe the teachings of Islam\",\n", - " \"what are the main beliefs of Judaism?\",\n", - " ],\n", - ")" + "dl(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "dl.add(religion)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'religion'" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dl(\"what do you think of Hinduism?\")" + "dl(\"how's the weather today?\")" ] }, { @@ -266,7 +199,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.3" } }, "nbformat": 4, From d0390428f3099044f1adbd4dbcb0a4adcfcd36cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 20 Dec 2023 22:45:49 +0000 Subject: [PATCH 08/37] original coverage --- coverage.xml | 113 ++++++++++++++++----------------------------------- 1 file changed, 35 insertions(+), 78 deletions(-) diff --git a/coverage.xml b/coverage.xml index 9c68d3d9..9af9ebee 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,12 +1,12 @@ - - + + - /Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/semantic_router + /Users/jakit/customers/aurelio/semantic-router/semantic_router - + @@ -16,7 +16,7 @@ - + @@ -31,95 +31,84 @@ + + + - - + + - - - - - - - - + + + - + - - + - - + + - - - + + + + + - + + - + - - + + + - - + - - - + - - - - - - - - - - + @@ -261,7 +250,7 @@ - + @@ -270,8 +259,7 @@ - - + @@ -286,7 +274,7 @@ - + @@ -310,7 +298,7 @@ - + @@ -399,37 +387,6 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - From 8c79849dc08647f4c5dd53f63a1670650887d58a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Thu, 21 Dec 2023 10:25:12 +0000 Subject: [PATCH 09/37] schemas in seperate files due to circular import --- README.md | 7 +- docs/examples/function_calling.ipynb | 1250 ++++++++--------- docs/examples/hybrid-layer.ipynb | 398 +++--- semantic_router/encoders/tfidf.py | 2 +- semantic_router/hybrid_layer.py | 2 +- semantic_router/layer.py | 2 +- .../{schema.py => schemas/encoder.py} | 21 - semantic_router/schemas/route.py | 7 + semantic_router/schemas/semantic_space.py | 17 + tests/unit/encoders/test_tfidf.py | 2 +- tests/unit/test_hybrid_layer.py | 2 +- tests/unit/test_layer.py | 2 +- tests/unit/test_schema.py | 8 +- walkthrough.ipynb | 439 +++--- 14 files changed, 1096 insertions(+), 1063 deletions(-) rename semantic_router/{schema.py => schemas/encoder.py} (67%) create mode 100644 semantic_router/schemas/route.py create mode 100644 semantic_router/schemas/semantic_space.py diff --git a/README.md b/README.md index 2db0e9b7..4ab443ce 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ [![Aurelio AI](https://pbs.twimg.com/profile_banners/1671498317455581184/1696285195/1500x500)](https://aurelio.ai) # Semantic Router +

GitHub Contributors GitHub Last Commit @@ -24,7 +25,7 @@ pip install -qU semantic-router We begin by defining a set of `Decision` objects. These are the decision paths that the semantic router can decide to use, let's try two simple decisions for now — one for talk on _politics_ and another for _chitchat_: ```python -from semantic_router.schema import Route +from semantic_router.schemas.route import Route # we could use this as a guide for our chatbot to avoid political conversations politics = Route( @@ -112,6 +113,6 @@ In this case, no decision could be made as we had no matches — so our decision ## 📚 Resources -| | | -| --------------------------------------------------------------------------------------------------------------- | -------------------------- | +| | | +| ------------------------------------------------------------------------------------------------------------------ | -------------------------- | | 🏃[Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook | diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 5d3be2fb..f487f77d 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -1,647 +1,647 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define LLMs" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "# OpenAI\n", - "import openai\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "# Docs # https://platform.openai.com/docs/guides/function-calling\n", - "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n", - " try:\n", - " logger.info(f\"Calling {model} model\")\n", - " response = openai.chat.completions.create(\n", - " model=model,\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", - " ],\n", - " )\n", - " ai_message = response.choices[0].message.content\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message\n", - " except Exception as e:\n", - " raise Exception(\"Failed to call OpenAI API\", e)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "# Mistral\n", - "import os\n", - "import requests\n", - "\n", - "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", - "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", - "\n", - "\n", - "def llm_mistral(prompt: str) -> str:\n", - " api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n", - " headers = {\n", - " \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n", - " \"Content-Type\": \"application/json\",\n", - " }\n", - "\n", - " logger.info(\"Calling Mistral model\")\n", - " response = requests.post(\n", - " api_url,\n", - " headers=headers,\n", - " json={\n", - " \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n", - " \"parameters\": {\n", - " \"max_new_tokens\": 200,\n", - " \"temperature\": 0.01,\n", - " \"num_beams\": 5,\n", - " \"num_return_sequences\": 1,\n", - " },\n", - " },\n", - " )\n", - " if response.status_code != 200:\n", - " raise Exception(\"Failed to call HuggingFace API\", response.text)\n", - "\n", - " ai_message = response.json()[0][\"generated_text\"]\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Now we need to generate config from function schema using LLM" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import inspect\n", - "from typing import Any\n", - "\n", - "\n", - "def get_function_schema(function) -> dict[str, Any]:\n", - " schema = {\n", - " \"name\": function.__name__,\n", - " \"description\": str(inspect.getdoc(function)),\n", - " \"signature\": str(inspect.signature(function)),\n", - " \"output\": str(\n", - " inspect.signature(function).return_annotation,\n", - " ),\n", - " }\n", - " return schema" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "\n", - "def is_valid_config(route_config_str: str) -> bool:\n", - " try:\n", - " output_json = json.loads(route_config_str)\n", - " return all(key in output_json for key in [\"name\", \"utterances\"])\n", - " except json.JSONDecodeError:\n", - " return False" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def generate_route(function) -> dict:\n", - " logger.info(\"Generating config...\")\n", - "\n", - " function_schema = get_function_schema(function)\n", - "\n", - " prompt = f\"\"\"\n", - " You are tasked to generate a JSON configuration based on the provided\n", - " function schema. Please follow the template below:\n", - "\n", - " {{\n", - " \"name\": \"\",\n", - " \"utterances\": [\n", - " \"\",\n", - " \"\",\n", - " \"\",\n", - " \"\",\n", - " \"\"]\n", - " }}\n", - "\n", - " Only include the \"name\" and \"utterances\" keys in your answer.\n", - " The \"name\" should match the function name and the \"utterances\"\n", - " should comprise a list of 5 example phrases that could be used to invoke\n", - " the function.\n", - "\n", - " Input schema:\n", - " {function_schema}\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - "\n", - " # Parse the response\n", - " ai_message = ai_message[ai_message.find(\"{\") :]\n", - " ai_message = (\n", - " ai_message.replace(\"'\", '\"')\n", - " .replace('\"s', \"'s\")\n", - " .strip()\n", - " .rstrip(\",\")\n", - " .replace(\"}\", \"}\")\n", - " )\n", - "\n", - " valid_config = is_valid_config(ai_message)\n", - "\n", - " if not valid_config:\n", - " logger.warning(f\"Mistral failed with error, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Invalid config generated\")\n", - " except Exception as e:\n", - " logger.error(f\"Fall back to OpenAI failed with error {e}\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Failed to generate config\")\n", - "\n", - " try:\n", - " route_config = json.loads(ai_message)\n", - " logger.info(f\"Generated config: {route_config}\")\n", - " return route_config\n", - " except json.JSONDecodeError as json_error:\n", - " logger.error(f\"JSON parsing error {json_error}\")\n", - " print(f\"AI message: {ai_message}\")\n", - " return {\"error\": \"Failed to generate config\"}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Extract function parameters using `Mistral` open-source model" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "def validate_parameters(function, parameters):\n", - " sig = inspect.signature(function)\n", - " for name, param in sig.parameters.items():\n", - " if name not in parameters:\n", - " return False, f\"Parameter {name} missing from query\"\n", - " if not isinstance(parameters[name], param.annotation):\n", - " return False, f\"Parameter {name} is not of type {param.annotation}\"\n", - " return True, \"Parameters are valid\"" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "def extract_parameters(query: str, function) -> dict:\n", - " logger.info(\"Extracting parameters...\")\n", - " example_query = \"How is the weather in Hawaii right now in International units?\"\n", - "\n", - " example_schema = {\n", - " \"name\": \"get_weather\",\n", - " \"description\": \"Useful to get the weather in a specific location\",\n", - " \"signature\": \"(location: str, degree: str) -> str\",\n", - " \"output\": \"\",\n", - " }\n", - "\n", - " example_parameters = {\n", - " \"location\": \"London\",\n", - " \"degree\": \"Celsius\",\n", - " }\n", - "\n", - " prompt = f\"\"\"\n", - " You are a helpful assistant designed to output JSON.\n", - " Given the following function schema\n", - " << {get_function_schema(function)} >>\n", - " and query\n", - " << {query} >>\n", - " extract the parameters values from the query, in a valid JSON format.\n", - " Example:\n", - " Input:\n", - " query: {example_query}\n", - " schema: {example_schema}\n", - "\n", - " Result: {example_parameters}\n", - "\n", - " Input:\n", - " query: {query}\n", - " schema: {get_function_schema(function)}\n", - " Result:\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - " ai_message = (\n", - " ai_message.replace(\"Output:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", - " )\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - "\n", - " try:\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - "\n", - " if not valid:\n", - " logger.warning(\n", - " f\"Invalid parameters from Mistral, falling back to OpenAI: {message}\"\n", - " )\n", - " # Fall back to OpenAI\n", - " ai_message = llm_openai(prompt)\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - " if not valid:\n", - " raise ValueError(message)\n", - "\n", - " logger.info(f\"Extracted parameters: {parameters}\")\n", - " return parameters\n", - " except ValueError as e:\n", - " logger.error(f\"Parameter validation error: {str(e)}\")\n", - " return {\"error\": \"Failed to validate parameters\"}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up the routing layer" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schema import Route\n", - "from semantic_router.encoders import CohereEncoder\n", - "from semantic_router.layer import RouteLayer\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def create_router(routes: list[dict]) -> RouteLayer:\n", - " logger.info(\"Creating route layer...\")\n", - " encoder = CohereEncoder()\n", - "\n", - " route_list: list[Route] = []\n", - " for route in routes:\n", - " if \"name\" in route and \"utterances\" in route:\n", - " print(f\"Route: {route}\")\n", - " route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", - " else:\n", - " logger.warning(f\"Misconfigured route: {route}\")\n", - "\n", - " return RouteLayer(encoder=encoder, routes=route_list)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up calling functions" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Callable\n", - "from semantic_router.layer import RouteLayer\n", - "\n", - "\n", - "def call_function(function: Callable, parameters: dict[str, str]):\n", - " try:\n", - " return function(**parameters)\n", - " except TypeError as e:\n", - " logger.error(f\"Error calling function: {e}\")\n", - "\n", - "\n", - "def call_llm(query: str) -> str:\n", - " try:\n", - " ai_message = llm_mistral(query)\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(query)\n", - "\n", - " return ai_message\n", - "\n", - "\n", - "def call(query: str, functions: list[Callable], router: RouteLayer):\n", - " function_name = router(query)\n", - " if not function_name:\n", - " logger.warning(\"No function found\")\n", - " return call_llm(query)\n", - "\n", - " for function in functions:\n", - " if function.__name__ == function_name:\n", - " parameters = extract_parameters(query, function)\n", - " print(f\"parameters: {parameters}\")\n", - " return call_function(function, parameters)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Workflow" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"what is the time in new york\",\n", - " \"can you tell me the time in london\",\n", - " \"get me the current time in tokyo\",\n", - " \"i need to know the time in sydney\",\n", - " \"please tell me the current time in paris\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Can I get the latest news in Canada?\",\n", - " \"Show me the recent news in the US\",\n", - " \"I would like to know about the sports news in England\",\n", - " \"Let's check the technology news in Japan\",\n", - " \"Show me the health related news in Germany\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define LLMs" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n", - "Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n" - ] - } - ], - "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# OpenAI\n", + "import openai\n", + "from semantic_router.utils.logger import logger\n", + "\n", + "\n", + "# Docs # https://platform.openai.com/docs/guides/function-calling\n", + "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n", + " try:\n", + " logger.info(f\"Calling {model} model\")\n", + " response = openai.chat.completions.create(\n", + " model=model,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", + " ],\n", + " )\n", + " ai_message = response.choices[0].message.content\n", + " if not ai_message:\n", + " raise Exception(\"AI message is empty\", ai_message)\n", + " logger.info(f\"AI message: {ai_message}\")\n", + " return ai_message\n", + " except Exception as e:\n", + " raise Exception(\"Failed to call OpenAI API\", e)" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"What's the time in New York?\",\n", - " \"Tell me the time in Tokyo.\",\n", - " \"Can you give me the time in London?\",\n", - " \"What's the current time in Sydney?\",\n", - " \"Can you tell me the time in Berlin?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Tell me the latest news from the US\",\n", - " \"What's happening in India today?\",\n", - " \"Get me the top stories from Japan\",\n", - " \"Can you give me the breaking news from Brazil?\",\n", - " \"What's the latest news from Germany?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# Mistral\n", + "import os\n", + "import requests\n", + "\n", + "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", + "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", + "\n", + "\n", + "def llm_mistral(prompt: str) -> str:\n", + " api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n", + " headers = {\n", + " \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n", + " \"Content-Type\": \"application/json\",\n", + " }\n", + "\n", + " logger.info(\"Calling Mistral model\")\n", + " response = requests.post(\n", + " api_url,\n", + " headers=headers,\n", + " json={\n", + " \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n", + " \"parameters\": {\n", + " \"max_new_tokens\": 200,\n", + " \"temperature\": 0.01,\n", + " \"num_beams\": 5,\n", + " \"num_return_sequences\": 1,\n", + " },\n", + " },\n", + " )\n", + " if response.status_code != 200:\n", + " raise Exception(\"Failed to call HuggingFace API\", response.text)\n", + "\n", + " ai_message = response.json()[0][\"generated_text\"]\n", + " if not ai_message:\n", + " raise Exception(\"AI message is empty\", ai_message)\n", + " logger.info(f\"AI message: {ai_message}\")\n", + " return ai_message" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", - "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" - ] - } - ], - "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Now we need to generate config from function schema using LLM" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"location\": \"Stockholm\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" - ] + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import inspect\n", + "from typing import Any\n", + "\n", + "\n", + "def get_function_schema(function) -> dict[str, Any]:\n", + " schema = {\n", + " \"name\": function.__name__,\n", + " \"description\": str(inspect.getdoc(function)),\n", + " \"signature\": str(inspect.signature(function)),\n", + " \"output\": str(\n", + " inspect.signature(function).return_annotation,\n", + " ),\n", + " }\n", + " return schema" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'location': 'Stockholm'}\n", - "Calling `get_time` function with location: Stockholm\n" - ] + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "\n", + "def is_valid_config(route_config_str: str) -> bool:\n", + " try:\n", + " output_json = json.loads(route_config_str)\n", + " return all(key in output_json for key in [\"name\", \"utterances\"])\n", + " except json.JSONDecodeError:\n", + " return False" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"category\": \"tech\",\n", - " \"country\": \"Lithuania\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" - ] + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "from semantic_router.utils.logger import logger\n", + "\n", + "\n", + "def generate_route(function) -> dict:\n", + " logger.info(\"Generating config...\")\n", + "\n", + " function_schema = get_function_schema(function)\n", + "\n", + " prompt = f\"\"\"\n", + " You are tasked to generate a JSON configuration based on the provided\n", + " function schema. Please follow the template below:\n", + "\n", + " {{\n", + " \"name\": \"\",\n", + " \"utterances\": [\n", + " \"\",\n", + " \"\",\n", + " \"\",\n", + " \"\",\n", + " \"\"]\n", + " }}\n", + "\n", + " Only include the \"name\" and \"utterances\" keys in your answer.\n", + " The \"name\" should match the function name and the \"utterances\"\n", + " should comprise a list of 5 example phrases that could be used to invoke\n", + " the function.\n", + "\n", + " Input schema:\n", + " {function_schema}\n", + " \"\"\"\n", + "\n", + " try:\n", + " ai_message = llm_mistral(prompt)\n", + "\n", + " # Parse the response\n", + " ai_message = ai_message[ai_message.find(\"{\") :]\n", + " ai_message = (\n", + " ai_message.replace(\"'\", '\"')\n", + " .replace('\"s', \"'s\")\n", + " .strip()\n", + " .rstrip(\",\")\n", + " .replace(\"}\", \"}\")\n", + " )\n", + "\n", + " valid_config = is_valid_config(ai_message)\n", + "\n", + " if not valid_config:\n", + " logger.warning(f\"Mistral failed with error, falling back to OpenAI\")\n", + " ai_message = llm_openai(prompt)\n", + " if not is_valid_config(ai_message):\n", + " raise Exception(\"Invalid config generated\")\n", + " except Exception as e:\n", + " logger.error(f\"Fall back to OpenAI failed with error {e}\")\n", + " ai_message = llm_openai(prompt)\n", + " if not is_valid_config(ai_message):\n", + " raise Exception(\"Failed to generate config\")\n", + "\n", + " try:\n", + " route_config = json.loads(ai_message)\n", + " logger.info(f\"Generated config: {route_config}\")\n", + " return route_config\n", + " except json.JSONDecodeError as json_error:\n", + " logger.error(f\"JSON parsing error {json_error}\")\n", + " print(f\"AI message: {ai_message}\")\n", + " return {\"error\": \"Failed to generate config\"}" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'category': 'tech', 'country': 'Lithuania'}\n", - "Calling `get_news` function with category: tech and country: Lithuania\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Extract function parameters using `Mistral` open-source model" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" - ] + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def validate_parameters(function, parameters):\n", + " sig = inspect.signature(function)\n", + " for name, param in sig.parameters.items():\n", + " if name not in parameters:\n", + " return False, f\"Parameter {name} missing from query\"\n", + " if not isinstance(parameters[name], param.annotation):\n", + " return False, f\"Parameter {name} is not of type {param.annotation}\"\n", + " return True, \"Parameters are valid\"" + ] }, { - "data": { - "text/plain": [ - "' How can I help you today?'" + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_parameters(query: str, function) -> dict:\n", + " logger.info(\"Extracting parameters...\")\n", + " example_query = \"How is the weather in Hawaii right now in International units?\"\n", + "\n", + " example_schema = {\n", + " \"name\": \"get_weather\",\n", + " \"description\": \"Useful to get the weather in a specific location\",\n", + " \"signature\": \"(location: str, degree: str) -> str\",\n", + " \"output\": \"\",\n", + " }\n", + "\n", + " example_parameters = {\n", + " \"location\": \"London\",\n", + " \"degree\": \"Celsius\",\n", + " }\n", + "\n", + " prompt = f\"\"\"\n", + " You are a helpful assistant designed to output JSON.\n", + " Given the following function schema\n", + " << {get_function_schema(function)} >>\n", + " and query\n", + " << {query} >>\n", + " extract the parameters values from the query, in a valid JSON format.\n", + " Example:\n", + " Input:\n", + " query: {example_query}\n", + " schema: {example_schema}\n", + "\n", + " Result: {example_parameters}\n", + "\n", + " Input:\n", + " query: {query}\n", + " schema: {get_function_schema(function)}\n", + " Result:\n", + " \"\"\"\n", + "\n", + " try:\n", + " ai_message = llm_mistral(prompt)\n", + " ai_message = (\n", + " ai_message.replace(\"Output:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", + " )\n", + " except Exception as e:\n", + " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", + " ai_message = llm_openai(prompt)\n", + "\n", + " try:\n", + " parameters = json.loads(ai_message)\n", + " valid, message = validate_parameters(function, parameters)\n", + "\n", + " if not valid:\n", + " logger.warning(\n", + " f\"Invalid parameters from Mistral, falling back to OpenAI: {message}\"\n", + " )\n", + " # Fall back to OpenAI\n", + " ai_message = llm_openai(prompt)\n", + " parameters = json.loads(ai_message)\n", + " valid, message = validate_parameters(function, parameters)\n", + " if not valid:\n", + " raise ValueError(message)\n", + "\n", + " logger.info(f\"Extracted parameters: {parameters}\")\n", + " return parameters\n", + " except ValueError as e:\n", + " logger.error(f\"Parameter validation error: {str(e)}\")\n", + " return {\"error\": \"Failed to validate parameters\"}" ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up the routing layer" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.schemas.route import Route\n", + "from semantic_router.encoders import CohereEncoder\n", + "from semantic_router.layer import RouteLayer\n", + "from semantic_router.utils.logger import logger\n", + "\n", + "\n", + "def create_router(routes: list[dict]) -> RouteLayer:\n", + " logger.info(\"Creating route layer...\")\n", + " encoder = CohereEncoder()\n", + "\n", + " route_list: list[Route] = []\n", + " for route in routes:\n", + " if \"name\" in route and \"utterances\" in route:\n", + " print(f\"Route: {route}\")\n", + " route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", + " else:\n", + " logger.warning(f\"Misconfigured route: {route}\")\n", + "\n", + " return RouteLayer(encoder=encoder, routes=route_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up calling functions" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable\n", + "from semantic_router.layer import RouteLayer\n", + "\n", + "\n", + "def call_function(function: Callable, parameters: dict[str, str]):\n", + " try:\n", + " return function(**parameters)\n", + " except TypeError as e:\n", + " logger.error(f\"Error calling function: {e}\")\n", + "\n", + "\n", + "def call_llm(query: str) -> str:\n", + " try:\n", + " ai_message = llm_mistral(query)\n", + " except Exception as e:\n", + " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", + " ai_message = llm_openai(query)\n", + "\n", + " return ai_message\n", + "\n", + "\n", + "def call(query: str, functions: list[Callable], router: RouteLayer):\n", + " function_name = router(query)\n", + " if not function_name:\n", + " logger.warning(\"No function found\")\n", + " return call_llm(query)\n", + "\n", + " for function in functions:\n", + " if function.__name__ == function_name:\n", + " parameters = extract_parameters(query, function)\n", + " print(f\"parameters: {parameters}\")\n", + " return call_function(function, parameters)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n", + " \"name\": \"get_time\",\n", + " \"utterances\": [\n", + " \"what is the time in new york\",\n", + " \"can you tell me the time in london\",\n", + " \"get me the current time in tokyo\",\n", + " \"i need to know the time in sydney\",\n", + " \"please tell me the current time in paris\"\n", + " ]\n", + "}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n", + " \"name\": \"get_news\",\n", + " \"utterances\": [\n", + " \"Can I get the latest news in Canada?\",\n", + " \"Show me the recent news in the US\",\n", + " \"I would like to know about the sports news in England\",\n", + " \"Let's check the technology news in Japan\",\n", + " \"Show me the health related news in Germany\"\n", + " ]\n", + "}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n", + "Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n" + ] + } + ], + "source": [ + "def get_time(location: str) -> str:\n", + " \"\"\"Useful to get the time in a specific location\"\"\"\n", + " print(f\"Calling `get_time` function with location: {location}\")\n", + " return \"get_time\"\n", + "\n", + "\n", + "def get_news(category: str, country: str) -> str:\n", + " \"\"\"Useful to get the news in a specific country\"\"\"\n", + " print(\n", + " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", + " )\n", + " return \"get_news\"\n", + "\n", + "\n", + "# Registering functions to the router\n", + "route_get_time = generate_route(get_time)\n", + "route_get_news = generate_route(get_news)\n", + "\n", + "routes = [route_get_time, route_get_news]\n", + "router = create_router(routes)\n", + "\n", + "# Tools\n", + "tools = [get_time, get_news]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n", + " Example output:\n", + " {\n", + " \"name\": \"get_time\",\n", + " \"utterances\": [\n", + " \"What's the time in New York?\",\n", + " \"Tell me the time in Tokyo.\",\n", + " \"Can you give me the time in London?\",\n", + " \"What's the current time in Sydney?\",\n", + " \"Can you tell me the time in Berlin?\"\n", + " ]\n", + " }\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n", + " Example output:\n", + " {\n", + " \"name\": \"get_news\",\n", + " \"utterances\": [\n", + " \"Tell me the latest news from the US\",\n", + " \"What's happening in India today?\",\n", + " \"Get me the top stories from Japan\",\n", + " \"Can you give me the breaking news from Brazil?\",\n", + " \"What's the latest news from Germany?\"\n", + " ]\n", + " }\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", + "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" + ] + } + ], + "source": [ + "def get_time(location: str) -> str:\n", + " \"\"\"Useful to get the time in a specific location\"\"\"\n", + " print(f\"Calling `get_time` function with location: {location}\")\n", + " return \"get_time\"\n", + "\n", + "\n", + "def get_news(category: str, country: str) -> str:\n", + " \"\"\"Useful to get the news in a specific country\"\"\"\n", + " print(\n", + " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", + " )\n", + " return \"get_news\"\n", + "\n", + "\n", + "# Registering functions to the router\n", + "route_get_time = generate_route(get_time)\n", + "route_get_news = generate_route(get_news)\n", + "\n", + "routes = [route_get_time, route_get_news]\n", + "router = create_router(routes)\n", + "\n", + "# Tools\n", + "tools = [get_time, get_news]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger AI message: \n", + " {\n", + " \"location\": \"Stockholm\"\n", + " }\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: {'location': 'Stockholm'}\n", + "Calling `get_time` function with location: Stockholm\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n", + " {\n", + " \"category\": \"tech\",\n", + " \"country\": \"Lithuania\"\n", + " }\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: {'category': 'tech', 'country': 'Lithuania'}\n", + "Calling `get_news` function with category: tech and country: Lithuania\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "' How can I help you today?'" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", + "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", + "call(query=\"Hi!\", functions=tools, router=router)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" } - ], - "source": [ - "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", - "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", - "call(query=\"Hi!\", functions=tools, router=router)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 8b1da5ae..9e5eca66 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -1,207 +1,199 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Semantic Router: Hybrid Layer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The Hybrid Layer in the Semantic Router library can improve making performance particularly for niche use-cases that contain specific terminology, such as finance or medical. It helps us provide more importance to making based on the keywords contained in our utterances and user queries." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Getting Started" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We start by installing the library:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -qU semantic-router==0.0.6" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We start by defining a dictionary mapping s to example phrases that should trigger those s." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "ename": "ImportError", - "evalue": "cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/Users/jakit/customers/aurelio/semantic-router/docs/examples/hybrid-layer.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msemantic_router\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mschema\u001b[39;00m \u001b[39mimport\u001b[39;00m Route\n\u001b[1;32m 3\u001b[0m politics \u001b[39m=\u001b[39m Route(\n\u001b[1;32m 4\u001b[0m name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mpolitics\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 5\u001b[0m utterances\u001b[39m=\u001b[39m[\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 12\u001b[0m ],\n\u001b[1;32m 13\u001b[0m )\n", - "\u001b[0;31mImportError\u001b[0m: cannot import name 'Route' from 'semantic_router.schema' (/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/semantic_router/schema.py)" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semantic Router: Hybrid Layer\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Hybrid Layer in the Semantic Router library can improve making performance particularly for niche use-cases that contain specific terminology, such as finance or medical. It helps us provide more importance to making based on the keywords contained in our utterances and user queries.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing the library:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install -qU semantic-router==0.0.6" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping s to example phrases that should trigger those s.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.schemas.route import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\",\n", + " \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from semantic_router.encoders import CohereEncoder, BM25Encoder, TfidfEncoder\n", + "from getpass import getpass\n", + "\n", + "os.environ[\"COHERE_API_KEY\"] = os.environ[\"COHERE_API_KEY\"] or getpass(\n", + " \"Enter Cohere API Key: \"\n", + ")\n", + "\n", + "dense_encoder = CohereEncoder()\n", + "# sparse_encoder = BM25Encoder()\n", + "sparse_encoder = TfidfEncoder()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.hybrid_layer import HybridRouteLayer\n", + "\n", + "dl = HybridRouteLayer(\n", + " dense_encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" } - ], - "source": [ - "from semantic_router.schema import Route\n", - "\n", - "politics = Route(\n", - " name=\"politics\",\n", - " utterances=[\n", - " \"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\",\n", - " \"don't you just love the president\",\n", - " \"don't you just hate the president\",\n", - " \"they're going to destroy this country!\",\n", - " \"they will save the country!\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's define another for good measure:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chitchat = Route(\n", - " name=\"chitchat\",\n", - " utterances=[\n", - " \"how's the weather today?\",\n", - " \"how are things going?\",\n", - " \"lovely weather today\",\n", - " \"the weather is horrendous\",\n", - " \"let's go to the chippy\",\n", - " ],\n", - ")\n", - "\n", - "chitchat = Route(\n", - " name=\"chitchat\",\n", - " utterances=[\n", - " \"how's the weather today?\",\n", - " \"how are things going?\",\n", - " \"lovely weather today\",\n", - " \"the weather is horrendous\",\n", - " \"let's go to the chippy\",\n", - " ],\n", - ")\n", - "\n", - "routes = [politics, chitchat]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we initialize our embedding model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from semantic_router.encoders import CohereEncoder\n", - "from getpass import getpass\n", - "\n", - "os.environ[\"COHERE_API_KEY\"] = os.environ[\"COHERE_API_KEY\"] or getpass(\n", - " \"Enter Cohere API Key: \"\n", - ")\n", - "\n", - "encoder = CohereEncoder()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.hybrid_layer import HybridRouteLayer\n", - "\n", - "dl = HybridRouteLayer(encoder=encoder, routes=routes)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dl(\"don't you love politics?\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dl(\"how's the weather today?\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "decision-layer", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index d6be6da5..6fc420eb 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,7 +1,7 @@ import numpy as np from collections import Counter from semantic_router.encoders import BaseEncoder -from semantic_router.schema import Route +from semantic_router.schemas.route import Route from numpy.linalg import norm import string diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 3993ca45..2901871a 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -8,7 +8,7 @@ OpenAIEncoder, TfidfEncoder, ) -from semantic_router.schema import Route +from semantic_router.schemas.route import Route from semantic_router.utils.logger import logger diff --git a/semantic_router/layer.py b/semantic_router/layer.py index cb408c5c..af08a9c1 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -6,7 +6,7 @@ OpenAIEncoder, ) from semantic_router.linear import similarity_matrix, top_scores -from semantic_router.schema import Route +from semantic_router.schemas.route import Route from semantic_router.utils.logger import logger diff --git a/semantic_router/schema.py b/semantic_router/schemas/encoder.py similarity index 67% rename from semantic_router/schema.py rename to semantic_router/schemas/encoder.py index 007cddcb..1b2ad74c 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schemas/encoder.py @@ -1,6 +1,5 @@ from enum import Enum -from pydantic import BaseModel from pydantic.dataclasses import dataclass from semantic_router.encoders import ( @@ -10,12 +9,6 @@ ) -class Route(BaseModel): - name: str - utterances: list[str] - description: str | None = None - - class EncoderType(Enum): HUGGINGFACE = "huggingface" OPENAI = "openai" @@ -40,17 +33,3 @@ def __init__(self, type: str, name: str): def __call__(self, texts: list[str]) -> list[list[float]]: return self.model(texts) - - -@dataclass -class SemanticSpace: - id: str - routes: list[Route] - encoder: str = "" - - def __init__(self, routes: list[Route] = []): - self.id = "" - self.routes = routes - - def add(self, route: Route): - self.routes.append(route) diff --git a/semantic_router/schemas/route.py b/semantic_router/schemas/route.py new file mode 100644 index 00000000..b70bc60f --- /dev/null +++ b/semantic_router/schemas/route.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class Route(BaseModel): + name: str + utterances: list[str] + description: str | None = None diff --git a/semantic_router/schemas/semantic_space.py b/semantic_router/schemas/semantic_space.py new file mode 100644 index 00000000..92e7adaf --- /dev/null +++ b/semantic_router/schemas/semantic_space.py @@ -0,0 +1,17 @@ +from pydantic.dataclasses import dataclass + +from semantic_router.schemas.route import Route + + +@dataclass +class SemanticSpace: + id: str + routes: list[Route] + encoder: str = "" + + def __init__(self, routes: list[Route] = []): + self.id = "" + self.routes = routes + + def add(self, route: Route): + self.routes.append(route) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 93a96639..68e37d9e 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -1,6 +1,6 @@ import pytest from semantic_router.encoders import TfidfEncoder -from semantic_router.schema import Route +from semantic_router.schemas.route import Route @pytest.fixture diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index ee7d8f6b..2506c199 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -8,7 +8,7 @@ TfidfEncoder, ) from semantic_router.hybrid_layer import HybridRouteLayer -from semantic_router.schema import Route +from semantic_router.schemas.route import Route def mock_encoder_call(utterances): diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 66e0d53b..d049243f 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -2,7 +2,7 @@ from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.layer import RouteLayer -from semantic_router.schema import Route +from semantic_router.schemas.route import Route def mock_encoder_call(utterances): diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index f471755c..f47643c9 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -1,11 +1,17 @@ import pytest -from semantic_router.schema import ( +from semantic_router.schemas.encoder import ( CohereEncoder, Encoder, EncoderType, OpenAIEncoder, +) + +from semantic_router.schemas.route import ( Route, +) + +from semantic_router.schemas.semantic_space import ( SemanticSpace, ) diff --git a/walkthrough.ipynb b/walkthrough.ipynb index d008739c..346b576c 100644 --- a/walkthrough.ipynb +++ b/walkthrough.ipynb @@ -1,206 +1,237 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Semantic Router Walkthrough" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The Semantic Router library can be used as a super fast route making layer on top of LLMs. That means rather than waiting on a slow agent to decide what to do, we can use the magic of semantic vector space to make routes. Cutting route making time down from seconds to milliseconds." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Getting Started" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We start by installing the library:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -qU semantic-router==0.0.8" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We start by defining a dictionary mapping routes to example phrases that should trigger those routes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schema import Route\n", - "\n", - "politics = Route(\n", - " name=\"politics\",\n", - " utterances=[\n", - " \"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\",\n", - " \"don't you just love the president\" \"don't you just hate the president\",\n", - " \"they're going to destroy this country!\",\n", - " \"they will save the country!\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's define another for good measure:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chitchat = Route(\n", - " name=\"chitchat\",\n", - " utterances=[\n", - " \"how's the weather today?\",\n", - " \"how are things going?\",\n", - " \"lovely weather today\",\n", - " \"the weather is horrendous\",\n", - " \"let's go to the chippy\",\n", - " ],\n", - ")\n", - "\n", - "routes = [politics, chitchat]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we initialize our embedding model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from getpass import getpass\n", - "from semantic_router.encoders import CohereEncoder\n", - "\n", - "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n", - " \"Enter Cohere API Key: \"\n", - ")\n", - "\n", - "encoder = CohereEncoder()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.layer import RouteLayer\n", - "\n", - "dl = RouteLayer(encoder=encoder, routes=routes)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can test it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dl(\"don't you love politics?\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dl(\"how's the weather today?\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dl(\"I'm interested in learning about llama 2\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this case, we return `None` because no matches were identified." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "decision-layer", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semantic Router Walkthrough\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Semantic Router library can be used as a super fast route making layer on top of LLMs. That means rather than waiting on a slow agent to decide what to do, we can use the magic of semantic vector space to make routes. Cutting route making time down from seconds to milliseconds.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing the library:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU semantic-router==0.0.8" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping routes to example phrases that should trigger those routes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from semantic_router.schemas.route import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\" \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "from semantic_router.encoders import CohereEncoder\n", + "\n", + "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n", + " \"Enter Cohere API Key: \"\n", + ")\n", + "\n", + "encoder = CohereEncoder()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.layer import RouteLayer\n", + "\n", + "dl = RouteLayer(encoder=encoder, routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can test it:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'politics'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'chitchat'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects?\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "dl(\"I'm interested in learning about llama 2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we return `None` because no matches were identified.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } From 82f31476a1d49fc032d40ee6ba1d0c8896d722b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:50:38 +0000 Subject: [PATCH 10/37] updated with sparse/dense encoder --- docs/examples/hybrid-layer.ipynb | 75 ++++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 18 deletions(-) diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 18404591..7b2ff298 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -30,11 +30,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "!pip install -qU semantic-router==0.0.11" + "#!pip install -qU semantic-router==0.0.11" ] }, { @@ -46,11 +46,20 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schema import Route\n", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from semantic_router.route import Route\n", "\n", "politics = Route(\n", " name=\"politics\",\n", @@ -81,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -119,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -145,9 +154,17 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-03 11:00:59 INFO semantic_router.utils.logger Creating embeddings for all routes...\u001b[0m\n" + ] + } + ], "source": [ "from semantic_router.hybrid_layer import HybridRouteLayer\n", "\n", @@ -158,18 +175,40 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'politics'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'chitchat'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"how's the weather today?\")" ] From 112201e4e9df82c5c2698f5791b7e6a48d0bd3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:50:51 +0000 Subject: [PATCH 11/37] Route path updated --- semantic_router/encoders/tfidf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 6fc420eb..394f32fb 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,7 +1,7 @@ import numpy as np from collections import Counter from semantic_router.encoders import BaseEncoder -from semantic_router.schemas.route import Route +from semantic_router.route import Route from numpy.linalg import norm import string From 85b87f3fbdca1b78f729082d72f5f9a8d6da69cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:51:07 +0000 Subject: [PATCH 12/37] seperate schema files removed --- semantic_router/schemas/encoder.py | 42 ----------------------- semantic_router/schemas/route.py | 7 ---- semantic_router/schemas/semantic_space.py | 17 --------- 3 files changed, 66 deletions(-) delete mode 100644 semantic_router/schemas/encoder.py delete mode 100644 semantic_router/schemas/route.py delete mode 100644 semantic_router/schemas/semantic_space.py diff --git a/semantic_router/schemas/encoder.py b/semantic_router/schemas/encoder.py deleted file mode 100644 index fbbfb2d7..00000000 --- a/semantic_router/schemas/encoder.py +++ /dev/null @@ -1,42 +0,0 @@ -from enum import Enum - -from pydantic.dataclasses import dataclass - -from semantic_router.encoders import ( - BaseEncoder, - CohereEncoder, - OpenAIEncoder, -) - - -class EncoderType(Enum): - HUGGINGFACE = "huggingface" - OPENAI = "openai" - COHERE = "cohere" - - -class RouteChoice(BaseModel): - name: str | None = None - function_call: dict | None = None - - -@dataclass -class Encoder: - type: EncoderType - name: str | None - model: BaseEncoder - - def __init__(self, type: str, name: str | None): - self.type = EncoderType(type) - self.name = name - if self.type == EncoderType.HUGGINGFACE: - raise NotImplementedError - elif self.type == EncoderType.OPENAI: - self.model = OpenAIEncoder(name) - elif self.type == EncoderType.COHERE: - self.model = CohereEncoder(name) - else: - raise ValueError - - def __call__(self, texts: list[str]) -> list[list[float]]: - return self.model(texts) diff --git a/semantic_router/schemas/route.py b/semantic_router/schemas/route.py deleted file mode 100644 index b70bc60f..00000000 --- a/semantic_router/schemas/route.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - - -class Route(BaseModel): - name: str - utterances: list[str] - description: str | None = None diff --git a/semantic_router/schemas/semantic_space.py b/semantic_router/schemas/semantic_space.py deleted file mode 100644 index 92e7adaf..00000000 --- a/semantic_router/schemas/semantic_space.py +++ /dev/null @@ -1,17 +0,0 @@ -from pydantic.dataclasses import dataclass - -from semantic_router.schemas.route import Route - - -@dataclass -class SemanticSpace: - id: str - routes: list[Route] - encoder: str = "" - - def __init__(self, routes: list[Route] = []): - self.id = "" - self.routes = routes - - def add(self, route: Route): - self.routes.append(route) From 99f8c8d7489e464b75612686c93b282625e2f7f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:51:32 +0000 Subject: [PATCH 13/37] dense_encoder instead of encoder --- semantic_router/hybrid_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index d62a996d..a5e6bd4b 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -98,7 +98,7 @@ def _add_routes(self, routes: list[Route]): all_utterances = [ utterance for route in routes for utterance in route.utterances ] - dense_embeds = np.array(self.encoder(all_utterances)) + dense_embeds = np.array(self.dense_encoder(all_utterances)) sparse_embeds = np.array(self.sparse_encoder(all_utterances)) # create route array From 975560c631df636a7f6928458152a3ab8e6aa070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:51:53 +0000 Subject: [PATCH 14/37] schema file restored --- semantic_router/schema.py | 43 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 semantic_router/schema.py diff --git a/semantic_router/schema.py b/semantic_router/schema.py new file mode 100644 index 00000000..b7a3c9fa --- /dev/null +++ b/semantic_router/schema.py @@ -0,0 +1,43 @@ +from enum import Enum + +from pydantic import BaseModel +from pydantic.dataclasses import dataclass + +from semantic_router.encoders import ( + BaseEncoder, + CohereEncoder, + OpenAIEncoder, +) + + +class EncoderType(Enum): + HUGGINGFACE = "huggingface" + OPENAI = "openai" + COHERE = "cohere" + + +class RouteChoice(BaseModel): + name: str | None = None + function_call: dict | None = None + + +@dataclass +class Encoder: + type: EncoderType + name: str | None + model: BaseEncoder + + def __init__(self, type: str, name: str | None): + self.type = EncoderType(type) + self.name = name + if self.type == EncoderType.HUGGINGFACE: + raise NotImplementedError + elif self.type == EncoderType.OPENAI: + self.model = OpenAIEncoder(name) + elif self.type == EncoderType.COHERE: + self.model = CohereEncoder(name) + else: + raise ValueError + + def __call__(self, texts: list[str]) -> list[list[float]]: + return self.model(texts) From 86eb3413fc8a949c104ba699dbc1b95c49c5d633 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:52:06 +0000 Subject: [PATCH 15/37] Route path updated --- tests/unit/encoders/test_tfidf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 68e37d9e..9e4cae29 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -1,6 +1,6 @@ import pytest from semantic_router.encoders import TfidfEncoder -from semantic_router.schemas.route import Route +from semantic_router.route import Route @pytest.fixture From acf01fd0ec97de84864800686e6c8a44ba436e04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:52:51 +0000 Subject: [PATCH 16/37] base encoder mocker path added --- tests/unit/test_hybrid_layer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index b77f51ad..567c5492 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -24,8 +24,15 @@ def mock_encoder_call(utterances): @pytest.fixture -def base_encoder(): - return BaseEncoder(name="test-encoder") +def base_encoder(mocker): + mock_base_encoder = BaseEncoder(name="test-encoder") + mocker.patch.object(BaseEncoder, "__call__", return_value=[[0.1, 0.2, 0.3]]) + return mock_base_encoder + + +# @pytest.fixture +# def base_encoder(): +# return BaseEncoder(name="test-encoder") @pytest.fixture From a8c64a677db6abf605a4795f9ce9ac5795d32f19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:53:45 +0000 Subject: [PATCH 17/37] schema path updated --- tests/unit/test_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index f4f97623..97b5028e 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -1,6 +1,6 @@ import pytest -from semantic_router.schemas.encoder import ( +from semantic_router.schema import ( CohereEncoder, Encoder, EncoderType, From 73258344b861da8a0111e408327b7dfd82825797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Thu, 4 Jan 2024 10:48:08 +0000 Subject: [PATCH 18/37] removed none types for mypy --- semantic_router/encoders/tfidf.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 394f32fb..68baceaa 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -7,16 +7,16 @@ class TfidfEncoder(BaseEncoder): - idf: dict | None = None - word_index: dict | None = None + idf: np.ndarray + word_index: dict def __init__(self, name: str = "tfidf"): super().__init__(name=name) - self.word_index = None - self.idf = None + self.word_index = {} + self.idf = np.array([]) def __call__(self, docs: list[str]) -> list[list[float]]: - if self.word_index is None or self.idf is None: + if len(self.word_index) == 0 or self.idf.size == 0: raise ValueError("Vectorizer is not initialized.") if len(docs) == 0: raise ValueError("No documents to encode.") @@ -43,6 +43,8 @@ def _build_word_index(self, docs: list[str]) -> dict: return word_index def _compute_tf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") tf = np.zeros((len(docs), len(self.word_index))) for i, doc in enumerate(docs): word_counts = Counter(doc.split()) @@ -54,6 +56,8 @@ def _compute_tf(self, docs: list[str]) -> np.ndarray: return tf def _compute_idf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") idf = np.zeros(len(self.word_index)) for doc in docs: words = set(doc.split()) From 48b64dc8a5a4ba4bd132f345fbac1491c035e885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Thu, 4 Jan 2024 10:48:26 +0000 Subject: [PATCH 19/37] added hasattr check for mypy --- semantic_router/hybrid_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index a5e6bd4b..853fb332 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -36,7 +36,7 @@ def __init__( else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now - if isinstance(sparse_encoder, TfidfEncoder): + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(self.sparse_encoder, 'fit'): self.sparse_encoder.fit(routes) if routes: # initialize index now @@ -54,7 +54,7 @@ def __call__(self, text: str) -> str | None: return None def add(self, route: Route): - if isinstance(self.sparse_encoder, TfidfEncoder): + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(self.sparse_encoder, 'fit'): self.sparse_encoder.fit(self.routes + [route]) self.sparse_index = None for r in self.routes: From 6952c566cd815e5485a89bb8a662bd6110d1718e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:50:38 +0000 Subject: [PATCH 20/37] fix: updated with sparse/dense encoder --- docs/examples/hybrid-layer.ipynb | 75 ++++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 18 deletions(-) diff --git a/docs/examples/hybrid-layer.ipynb b/docs/examples/hybrid-layer.ipynb index 18404591..7b2ff298 100644 --- a/docs/examples/hybrid-layer.ipynb +++ b/docs/examples/hybrid-layer.ipynb @@ -30,11 +30,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "!pip install -qU semantic-router==0.0.11" + "#!pip install -qU semantic-router==0.0.11" ] }, { @@ -46,11 +46,20 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schema import Route\n", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danielgriffiths/Coding_files/Aurelio_local/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from semantic_router.route import Route\n", "\n", "politics = Route(\n", " name=\"politics\",\n", @@ -81,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -119,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -145,9 +154,17 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-03 11:00:59 INFO semantic_router.utils.logger Creating embeddings for all routes...\u001b[0m\n" + ] + } + ], "source": [ "from semantic_router.hybrid_layer import HybridRouteLayer\n", "\n", @@ -158,18 +175,40 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'politics'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'chitchat'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dl(\"how's the weather today?\")" ] From 84d40fd8a01b6096951d30f4a0c7a1a324ac835f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:50:51 +0000 Subject: [PATCH 21/37] fix: Route path updated --- semantic_router/encoders/tfidf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 6fc420eb..394f32fb 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,7 +1,7 @@ import numpy as np from collections import Counter from semantic_router.encoders import BaseEncoder -from semantic_router.schemas.route import Route +from semantic_router.route import Route from numpy.linalg import norm import string From b83bfc6a8ff23c008d4cd431bb0565dc42747fc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:51:07 +0000 Subject: [PATCH 22/37] refactor: seperate schema files removed --- semantic_router/schemas/encoder.py | 42 ----------------------- semantic_router/schemas/route.py | 7 ---- semantic_router/schemas/semantic_space.py | 17 --------- 3 files changed, 66 deletions(-) delete mode 100644 semantic_router/schemas/encoder.py delete mode 100644 semantic_router/schemas/route.py delete mode 100644 semantic_router/schemas/semantic_space.py diff --git a/semantic_router/schemas/encoder.py b/semantic_router/schemas/encoder.py deleted file mode 100644 index fbbfb2d7..00000000 --- a/semantic_router/schemas/encoder.py +++ /dev/null @@ -1,42 +0,0 @@ -from enum import Enum - -from pydantic.dataclasses import dataclass - -from semantic_router.encoders import ( - BaseEncoder, - CohereEncoder, - OpenAIEncoder, -) - - -class EncoderType(Enum): - HUGGINGFACE = "huggingface" - OPENAI = "openai" - COHERE = "cohere" - - -class RouteChoice(BaseModel): - name: str | None = None - function_call: dict | None = None - - -@dataclass -class Encoder: - type: EncoderType - name: str | None - model: BaseEncoder - - def __init__(self, type: str, name: str | None): - self.type = EncoderType(type) - self.name = name - if self.type == EncoderType.HUGGINGFACE: - raise NotImplementedError - elif self.type == EncoderType.OPENAI: - self.model = OpenAIEncoder(name) - elif self.type == EncoderType.COHERE: - self.model = CohereEncoder(name) - else: - raise ValueError - - def __call__(self, texts: list[str]) -> list[list[float]]: - return self.model(texts) diff --git a/semantic_router/schemas/route.py b/semantic_router/schemas/route.py deleted file mode 100644 index b70bc60f..00000000 --- a/semantic_router/schemas/route.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - - -class Route(BaseModel): - name: str - utterances: list[str] - description: str | None = None diff --git a/semantic_router/schemas/semantic_space.py b/semantic_router/schemas/semantic_space.py deleted file mode 100644 index 92e7adaf..00000000 --- a/semantic_router/schemas/semantic_space.py +++ /dev/null @@ -1,17 +0,0 @@ -from pydantic.dataclasses import dataclass - -from semantic_router.schemas.route import Route - - -@dataclass -class SemanticSpace: - id: str - routes: list[Route] - encoder: str = "" - - def __init__(self, routes: list[Route] = []): - self.id = "" - self.routes = routes - - def add(self, route: Route): - self.routes.append(route) From ec77a1a5333212643a989f0a17aef3e18a8a5191 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:51:32 +0000 Subject: [PATCH 23/37] fix: dense_encoder instead of encoder --- semantic_router/hybrid_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index d62a996d..a5e6bd4b 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -98,7 +98,7 @@ def _add_routes(self, routes: list[Route]): all_utterances = [ utterance for route in routes for utterance in route.utterances ] - dense_embeds = np.array(self.encoder(all_utterances)) + dense_embeds = np.array(self.dense_encoder(all_utterances)) sparse_embeds = np.array(self.sparse_encoder(all_utterances)) # create route array From b87a4912c71c077d92cfa7b69dcfd05652928ac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:51:53 +0000 Subject: [PATCH 24/37] fix: schema file restored --- semantic_router/schema.py | 43 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 semantic_router/schema.py diff --git a/semantic_router/schema.py b/semantic_router/schema.py new file mode 100644 index 00000000..b7a3c9fa --- /dev/null +++ b/semantic_router/schema.py @@ -0,0 +1,43 @@ +from enum import Enum + +from pydantic import BaseModel +from pydantic.dataclasses import dataclass + +from semantic_router.encoders import ( + BaseEncoder, + CohereEncoder, + OpenAIEncoder, +) + + +class EncoderType(Enum): + HUGGINGFACE = "huggingface" + OPENAI = "openai" + COHERE = "cohere" + + +class RouteChoice(BaseModel): + name: str | None = None + function_call: dict | None = None + + +@dataclass +class Encoder: + type: EncoderType + name: str | None + model: BaseEncoder + + def __init__(self, type: str, name: str | None): + self.type = EncoderType(type) + self.name = name + if self.type == EncoderType.HUGGINGFACE: + raise NotImplementedError + elif self.type == EncoderType.OPENAI: + self.model = OpenAIEncoder(name) + elif self.type == EncoderType.COHERE: + self.model = CohereEncoder(name) + else: + raise ValueError + + def __call__(self, texts: list[str]) -> list[list[float]]: + return self.model(texts) From 2048019ce19ec345fa5db93952b395f51248ed56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:52:06 +0000 Subject: [PATCH 25/37] fix: Route path updated --- tests/unit/encoders/test_tfidf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 68e37d9e..9e4cae29 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -1,6 +1,6 @@ import pytest from semantic_router.encoders import TfidfEncoder -from semantic_router.schemas.route import Route +from semantic_router.route import Route @pytest.fixture From bbd06180c7e8f394238faabab45b4e3a4b8712ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:52:51 +0000 Subject: [PATCH 26/37] fix: base encoder mocker path added --- tests/unit/test_hybrid_layer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index b77f51ad..567c5492 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -24,8 +24,15 @@ def mock_encoder_call(utterances): @pytest.fixture -def base_encoder(): - return BaseEncoder(name="test-encoder") +def base_encoder(mocker): + mock_base_encoder = BaseEncoder(name="test-encoder") + mocker.patch.object(BaseEncoder, "__call__", return_value=[[0.1, 0.2, 0.3]]) + return mock_base_encoder + + +# @pytest.fixture +# def base_encoder(): +# return BaseEncoder(name="test-encoder") @pytest.fixture From c340e08971b008ed0ca6d2d5990dab0c4d7a8152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Wed, 3 Jan 2024 15:53:45 +0000 Subject: [PATCH 27/37] fix: schema path updated --- tests/unit/test_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index f4f97623..97b5028e 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -1,6 +1,6 @@ import pytest -from semantic_router.schemas.encoder import ( +from semantic_router.schema import ( CohereEncoder, Encoder, EncoderType, From f94529e0cfc4c34a00c824cdf153a9a7275964e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Thu, 4 Jan 2024 10:48:08 +0000 Subject: [PATCH 28/37] fixed: removed none types for mypy --- semantic_router/encoders/tfidf.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 394f32fb..68baceaa 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -7,16 +7,16 @@ class TfidfEncoder(BaseEncoder): - idf: dict | None = None - word_index: dict | None = None + idf: np.ndarray + word_index: dict def __init__(self, name: str = "tfidf"): super().__init__(name=name) - self.word_index = None - self.idf = None + self.word_index = {} + self.idf = np.array([]) def __call__(self, docs: list[str]) -> list[list[float]]: - if self.word_index is None or self.idf is None: + if len(self.word_index) == 0 or self.idf.size == 0: raise ValueError("Vectorizer is not initialized.") if len(docs) == 0: raise ValueError("No documents to encode.") @@ -43,6 +43,8 @@ def _build_word_index(self, docs: list[str]) -> dict: return word_index def _compute_tf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") tf = np.zeros((len(docs), len(self.word_index))) for i, doc in enumerate(docs): word_counts = Counter(doc.split()) @@ -54,6 +56,8 @@ def _compute_tf(self, docs: list[str]) -> np.ndarray: return tf def _compute_idf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") idf = np.zeros(len(self.word_index)) for doc in docs: words = set(doc.split()) From dd9f07a0451dfe585979b3737fce0d27e4b8d1f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Thu, 4 Jan 2024 10:48:26 +0000 Subject: [PATCH 29/37] fix: added hasattr check for mypy --- semantic_router/hybrid_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index a5e6bd4b..853fb332 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -36,7 +36,7 @@ def __init__( else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now - if isinstance(sparse_encoder, TfidfEncoder): + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(self.sparse_encoder, 'fit'): self.sparse_encoder.fit(routes) if routes: # initialize index now @@ -54,7 +54,7 @@ def __call__(self, text: str) -> str | None: return None def add(self, route: Route): - if isinstance(self.sparse_encoder, TfidfEncoder): + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(self.sparse_encoder, 'fit'): self.sparse_encoder.fit(self.routes + [route]) self.sparse_index = None for r in self.routes: From 422f5f77c7e2b6c9162eef76bbca2cda5e80d3dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Fri, 5 Jan 2024 09:52:50 +0000 Subject: [PATCH 30/37] fix: added default values to class variables --- semantic_router/encoders/tfidf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 68baceaa..2782b1ed 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -4,11 +4,13 @@ from semantic_router.route import Route from numpy.linalg import norm import string +from typing import Dict +from numpy import ndarray class TfidfEncoder(BaseEncoder): - idf: np.ndarray - word_index: dict + idf: ndarray = np.array([]) + word_index: Dict = {} def __init__(self, name: str = "tfidf"): super().__init__(name=name) From 6bd5da3b0eba9d056bf6e8f9a947629ede956549 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Fri, 5 Jan 2024 09:53:55 +0000 Subject: [PATCH 31/37] fix: black lint reformat hybrid layer --- semantic_router/hybrid_layer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 853fb332..40c95d63 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -36,7 +36,9 @@ def __init__( else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now - if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(self.sparse_encoder, 'fit'): + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( + self.sparse_encoder, "fit" + ): self.sparse_encoder.fit(routes) if routes: # initialize index now @@ -54,7 +56,9 @@ def __call__(self, text: str) -> str | None: return None def add(self, route: Route): - if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(self.sparse_encoder, 'fit'): + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( + self.sparse_encoder, "fit" + ): self.sparse_encoder.fit(self.routes + [route]) self.sparse_index = None for r in self.routes: From 0f1ecf70c6a6f3225b47af6e9fb39bbd5e2bc5b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Fri, 5 Jan 2024 09:54:57 +0000 Subject: [PATCH 32/37] fix: added default values to tfidf tests --- tests/unit/encoders/test_tfidf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 9e4cae29..5aa8dfc8 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -1,6 +1,7 @@ import pytest from semantic_router.encoders import TfidfEncoder from semantic_router.route import Route +import numpy as np @pytest.fixture @@ -10,8 +11,8 @@ def tfidf_encoder(): class TestTfidfEncoder: def test_initialization(self, tfidf_encoder): - assert tfidf_encoder.word_index is None - assert tfidf_encoder.idf is None + assert tfidf_encoder.word_index == {} + assert (tfidf_encoder.idf == np.array([])).all() def test_fit(self, tfidf_encoder): routes = [ @@ -21,8 +22,8 @@ def test_fit(self, tfidf_encoder): ) ] tfidf_encoder.fit(routes) - assert tfidf_encoder.word_index is not None - assert tfidf_encoder.idf is not None + assert tfidf_encoder.word_index != {} + assert not np.array_equal(tfidf_encoder.idf, np.array([])) def test_call_method(self, tfidf_encoder): routes = [ From d96c552230ebba4faddb25a651f3b68a379048cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Mon, 8 Jan 2024 15:33:31 +0000 Subject: [PATCH 33/37] fix: created embedding helper functions --- semantic_router/hybrid_layer.py | 52 ++++++++++++++------------------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 40c95d63..d7200c26 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -56,45 +56,33 @@ def __call__(self, text: str) -> str | None: return None def add(self, route: Route): + self._add_route(route=route) + + def _add_route(self, route: Route): + self.routes += [route] + + self.update_dense_embeddings_index(route.utterances) + if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( self.sparse_encoder, "fit" ): - self.sparse_encoder.fit(self.routes + [route]) + self.sparse_encoder.fit(self.routes) + # re-build index self.sparse_index = None - for r in self.routes: - self.compute_and_store_sparse_embeddings(r) - self.routes.append(route) - self._add_route(route=route) + all_utterances = [ + utterance for route in self.routes for utterance in route.utterances + ] + self.update_sparse_embeddings_index(all_utterances) + else: + self.update_sparse_embeddings_index(route.utterances) - def _add_route(self, route: Route): - # create embeddings - dense_embeds = np.array(self.dense_encoder(route.utterances)) # * self.alpha - self.compute_and_store_sparse_embeddings(route) # create route array if self.categories is None: self.categories = np.array([route.name] * len(route.utterances)) - self.utterances = np.array(route.utterances) else: str_arr = np.array([route.name] * len(route.utterances)) self.categories = np.concatenate([self.categories, str_arr]) - self.utterances = np.concatenate( - [self.utterances, np.array(route.utterances)] - ) - # create utterance array (the dense index) - if self.index is None: - self.index = dense_embeds - else: - self.index = np.concatenate([self.index, dense_embeds]) - - def compute_and_store_sparse_embeddings(self, route: Route): - sparse_embeds = np.array( - self.sparse_encoder(route.utterances) - ) # * (1 - self.alpha) - # create sparse utterance array - if self.sparse_index is None: - self.sparse_index = sparse_embeds - else: - self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds]) + self.routes.append(route) def _add_routes(self, routes: list[Route]): # create embeddings for all routes @@ -102,8 +90,8 @@ def _add_routes(self, routes: list[Route]): all_utterances = [ utterance for route in routes for utterance in route.utterances ] - dense_embeds = np.array(self.dense_encoder(all_utterances)) - sparse_embeds = np.array(self.sparse_encoder(all_utterances)) + self.update_dense_embeddings_index(all_utterances) + self.update_sparse_embeddings_index(all_utterances) # create route array route_names = [route.name for route in routes for _ in route.utterances] @@ -114,6 +102,8 @@ def _add_routes(self, routes: list[Route]): else route_array ) + def update_dense_embeddings_index(self, utterances: list): + dense_embeds = np.array(self.dense_encoder(utterances)) # create utterance array (the dense index) self.index = ( np.concatenate([self.index, dense_embeds]) @@ -121,6 +111,8 @@ def _add_routes(self, routes: list[Route]): else dense_embeds ) + def update_sparse_embeddings_index(self, utterances: list): + sparse_embeds = np.array(self.sparse_encoder(utterances)) # create sparse utterance array self.sparse_index = ( np.concatenate([self.sparse_index, sparse_embeds]) From 138d540b9adc3f141851a76cbdcda176d8567496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Mon, 8 Jan 2024 15:34:30 +0000 Subject: [PATCH 34/37] feat: added more tfidf tests --- tests/unit/encoders/test_tfidf.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 5aa8dfc8..21524c91 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -39,7 +39,7 @@ def test_call_method(self, tfidf_encoder): isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - def test_call_method_no_docs(self, tfidf_encoder): + def test_call_method_no_docs_tfidf(self, tfidf_encoder): with pytest.raises(ValueError): tfidf_encoder([]) @@ -60,3 +60,26 @@ def test_call_method_no_word(self, tfidf_encoder): def test_call_method_with_uninitialized_model(self, tfidf_encoder): with pytest.raises(ValueError): tfidf_encoder(["test"]) + + def test_call_method_no_docs(self, tfidf_encoder): + with pytest.raises(ValueError, match="No documents to encode."): + tfidf_encoder([]) + + def test_compute_tf_no_word_index(self, tfidf_encoder): + with pytest.raises(ValueError, match="Word index is not initialized."): + tfidf_encoder._compute_tf(["some docs"]) + + def test_compute_tf_with_word_in_word_index(self, tfidf_encoder): + routes = [ + Route( + name="test_route", + utterances=["some docs", "and more docs", "and even more docs"], + ) + ] + tfidf_encoder.fit(routes) + tf = tfidf_encoder._compute_tf(["some docs"]) + assert tf.shape == (1, len(tfidf_encoder.word_index)) + + def test_compute_idf_no_word_index(self, tfidf_encoder): + with pytest.raises(ValueError, match="Word index is not initialized."): + tfidf_encoder._compute_idf(["some docs"]) From 3698b07fb4b67aa0a2c8242b488f504168170792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= Date: Mon, 8 Jan 2024 15:35:11 +0000 Subject: [PATCH 35/37] feat: added add route tfidf test --- tests/unit/test_hybrid_layer.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 567c5492..027d8750 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -30,11 +30,6 @@ def base_encoder(mocker): return mock_base_encoder -# @pytest.fixture -# def base_encoder(): -# return BaseEncoder(name="test-encoder") - - @pytest.fixture def cohere_encoder(mocker): mocker.patch.object(CohereEncoder, "__call__", side_effect=mock_encoder_call) @@ -165,5 +160,18 @@ def test_failover_score_threshold(self, base_encoder, bm25_encoder): ) assert route_layer.score_threshold == 0.82 + def test_add_route_tfidf(self, cohere_encoder, tfidf_encoder, routes): + hybrid_route_layer = HybridRouteLayer( + dense_encoder=cohere_encoder, + sparse_encoder=tfidf_encoder, + routes=routes[:-1], + ) + hybrid_route_layer.add(routes[-1]) + all_utterances = [ + utterance for route in routes for utterance in route.utterances + ] + assert hybrid_route_layer.sparse_index is not None + assert len(hybrid_route_layer.sparse_index) == len(all_utterances) + # Add more tests for edge cases and error handling as needed. From 7679e261afa0cc4038e89da5e34c5cdd22fe5559 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Fri, 12 Jan 2024 00:22:38 +0000 Subject: [PATCH 36/37] fixes --- poetry.lock | 137 +++++++------------------------- semantic_router/hybrid_layer.py | 3 +- tests/unit/test_hybrid_layer.py | 2 +- 3 files changed, 32 insertions(+), 110 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1328ceea..96a21bbc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -250,17 +250,6 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] -[[package]] -name = "cachetools" -version = "5.3.2" -description = "Extensible memoizing collections and decorators" -optional = false -python-versions = ">=3.7" -files = [ - {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, - {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, -] - [[package]] name = "certifi" version = "2023.11.17" @@ -336,17 +325,6 @@ files = [ [package.dependencies] pycparser = "*" -[[package]] -name = "chardet" -version = "5.2.0" -description = "Universal encoding detector for Python 3" -optional = false -python-versions = ">=3.7" -files = [ - {file = "chardet-5.2.0-py3-none-any.whl", hash = "sha256:e1cf59446890a00105fe7b7912492ea04b6e6f06d4b742b2c788469e34c82970"}, - {file = "chardet-5.2.0.tar.gz", hash = "sha256:1b3b6ff479a8c414bc3fa2c0852995695c4a026dcd6d0633b2dd092ca39c1cf7"}, -] - [[package]] name = "charset-normalizer" version = "3.3.2" @@ -646,17 +624,6 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] -[[package]] -name = "distlib" -version = "0.3.8" -description = "Distribution utilities" -optional = false -python-versions = "*" -files = [ - {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, - {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, -] - [[package]] name = "distro" version = "1.9.0" @@ -779,7 +746,7 @@ tqdm = ">=4.65,<5.0" name = "filelock" version = "3.13.1" description = "A platform independent file lock." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, @@ -1177,7 +1144,7 @@ i18n = ["Babel (>=2.7)"] name = "joblib" version = "1.3.2" description = "Lightweight pipelining with Python functions" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, @@ -1254,6 +1221,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1304,7 +1281,7 @@ traitlets = "*" name = "mmh3" version = "3.1.0" description = "Python wrapper for MurmurHash (MurmurHash3), a set of fast and robust hash functions." -optional = false +optional = true python-versions = "*" files = [ {file = "mmh3-3.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16ee043b1bac040b4324b8baee39df9fdca480a560a6d74f2eef66a5009a234e"}, @@ -1535,7 +1512,7 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] name = "nltk" version = "3.8.1" description = "Natural Language Toolkit" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, @@ -1891,7 +1868,7 @@ ptyprocess = ">=0.5" name = "pinecone-text" version = "0.7.1" description = "Text utilities library by Pinecone.io" -optional = false +optional = true python-versions = ">=3.8,<4.0" files = [ {file = "pinecone_text-0.7.1-py3-none-any.whl", hash = "sha256:b806b5d66190d09888ed2d3bcdef49534aa9200b9da521371a062e6ccc79bb2c"}, @@ -2105,25 +2082,6 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] -[[package]] -name = "pyproject-api" -version = "1.6.1" -description = "API to interact with the python pyproject.toml based projects" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pyproject_api-1.6.1-py3-none-any.whl", hash = "sha256:4c0116d60476b0786c88692cf4e325a9814965e2469c5998b830bba16b183675"}, - {file = "pyproject_api-1.6.1.tar.gz", hash = "sha256:1817dc018adc0d1ff9ca1ed8c60e1623d5aaca40814b953af14a9cf9a5cae538"}, -] - -[package.dependencies] -packaging = ">=23.1" -tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} - -[package.extras] -docs = ["furo (>=2023.8.19)", "sphinx (<7.2)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "setuptools (>=68.1.2)", "wheel (>=0.41.2)"] - [[package]] name = "pyreadline3" version = "3.4.1" @@ -2261,6 +2219,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2268,8 +2227,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2286,6 +2252,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2293,6 +2260,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2407,7 +2375,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "regex" version = "2023.12.25" description = "Alternative regular expression module, to replace re." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "regex-2023.12.25-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0694219a1d54336fd0445ea382d49d36882415c0134ee1e8332afd1529f0baa5"}, @@ -2936,33 +2904,6 @@ files = [ {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, ] -[[package]] -name = "tox" -version = "4.11.4" -description = "tox is a generic virtualenv management and test command line tool" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tox-4.11.4-py3-none-any.whl", hash = "sha256:2adb83d68f27116812b69aa36676a8d6a52249cb0d173649de0e7d0c2e3e7229"}, - {file = "tox-4.11.4.tar.gz", hash = "sha256:73a7240778fabf305aeb05ab8ea26e575e042ab5a18d71d0ed13e343a51d6ce1"}, -] - -[package.dependencies] -cachetools = ">=5.3.1" -chardet = ">=5.2" -colorama = ">=0.4.6" -filelock = ">=3.12.3" -packaging = ">=23.1" -platformdirs = ">=3.10" -pluggy = ">=1.3" -pyproject-api = ">=1.6.1" -tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} -virtualenv = ">=20.24.3" - -[package.extras] -docs = ["furo (>=2023.8.19)", "sphinx (>=7.2.4)", "sphinx-argparse-cli (>=1.11.1)", "sphinx-autodoc-typehints (>=1.24)", "sphinx-copybutton (>=0.5.2)", "sphinx-inline-tabs (>=2023.4.21)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -testing = ["build[virtualenv] (>=0.10)", "covdefaults (>=2.3)", "detect-test-pollution (>=1.1.1)", "devpi-process (>=1)", "diff-cover (>=7.7)", "distlib (>=0.3.7)", "flaky (>=3.7)", "hatch-vcs (>=0.3)", "hatchling (>=1.18)", "psutil (>=5.9.5)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-xdist (>=3.3.1)", "re-assert (>=1.1)", "time-machine (>=2.12)", "wheel (>=0.41.2)"] - [[package]] name = "tqdm" version = "4.66.1" @@ -3129,26 +3070,6 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] -[[package]] -name = "virtualenv" -version = "20.25.0" -description = "Virtual Python Environment builder" -optional = false -python-versions = ">=3.7" -files = [ - {file = "virtualenv-20.25.0-py3-none-any.whl", hash = "sha256:4238949c5ffe6876362d9c0180fc6c3a824a7b12b80604eeb8085f2ed7460de3"}, - {file = "virtualenv-20.25.0.tar.gz", hash = "sha256:bf51c0d9c7dd63ea8e44086fa1e4fb1093a31e963b86959257378aef020e1f1b"}, -] - -[package.dependencies] -distlib = ">=0.3.7,<1" -filelock = ">=3.12.2,<4" -platformdirs = ">=3.9.1,<5" - -[package.extras] -docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] - [[package]] name = "wcwidth" version = "0.2.13" @@ -3164,7 +3085,7 @@ files = [ name = "wget" version = "3.2" description = "pure python download utility" -optional = false +optional = true python-versions = "*" files = [ {file = "wget-3.2.zip", hash = "sha256:35e630eca2aa50ce998b9b1a127bb26b30dfee573702782aa982f875e3f16061"}, @@ -3296,4 +3217,4 @@ local = ["torch", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4ceb0344e006fc7657c66548321ff51d34a297ee6f9ab069fdc53e1256024a12" +content-hash = "5b459c6820bcf5c2b73daf0ecfcbbac95019311c74d88634bd7188650e48b749" diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 0979d420..b56e1cd5 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -5,6 +5,7 @@ from semantic_router.encoders import ( BaseEncoder, + BM25Encoder, TfidfEncoder, ) from semantic_router.route import Route @@ -24,7 +25,7 @@ def __init__( routes: list[Route] = [], alpha: float = 0.3, ): - self.encoder = dense_encoder + self.encoder = encoder self.score_threshold = self.encoder.score_threshold if sparse_encoder is None: diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index cf1e66c3..6a3d225a 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -164,7 +164,7 @@ def test_failover_score_threshold(self, base_encoder): def test_add_route_tfidf(self, cohere_encoder, tfidf_encoder, routes): hybrid_route_layer = HybridRouteLayer( - dense_encoder=cohere_encoder, + encoder=cohere_encoder, sparse_encoder=tfidf_encoder, routes=routes[:-1], ) From c651b4afb8ced5b12b41b4100cde594b2fb57013 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Fri, 12 Jan 2024 00:24:14 +0000 Subject: [PATCH 37/37] lint --- semantic_router/encoders/tfidf.py | 12 +++++++----- semantic_router/hybrid_layer.py | 4 ++-- tests/unit/encoders/test_tfidf.py | 3 ++- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 5e81a463..0809b5ad 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,11 +1,13 @@ -import numpy as np -from collections import Counter -from semantic_router.encoders import BaseEncoder -from semantic_router.route import Route -from numpy.linalg import norm import string +from collections import Counter from typing import Dict + +import numpy as np from numpy import ndarray +from numpy.linalg import norm + +from semantic_router.encoders import BaseEncoder +from semantic_router.route import Route class TfidfEncoder(BaseEncoder): diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index b56e1cd5..62c87efc 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -27,13 +27,13 @@ def __init__( ): self.encoder = encoder self.score_threshold = self.encoder.score_threshold - + if sparse_encoder is None: logger.warning("No sparse_encoder provided. Using default BM25Encoder.") self.sparse_encoder = BM25Encoder() else: self.sparse_encoder = sparse_encoder - + self.alpha = alpha self.routes = routes if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 6bb9fab3..7664433d 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -1,7 +1,8 @@ +import numpy as np import pytest + from semantic_router.encoders import TfidfEncoder from semantic_router.route import Route -import numpy as np @pytest.fixture