diff --git a/docs/encoders/fastembed.ipynb b/docs/encoders/fastembed.ipynb new file mode 100644 index 00000000..27bad545 --- /dev/null +++ b/docs/encoders/fastembed.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/encoders/fastembed.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/encoders/fastembed.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using FastEmbedEncoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "FastEmbed is a _lightweight and fast_ embedding library built for generating embeddings. It can be run locally and supports many open source encoders.\n", + "\n", + "Beyond being a local, open source library, there are two key reasons we might want to run this library over other open source alternatives:\n", + "\n", + "* **Lightweight and Fast**: The library uses an ONNX runtime so there is no heavy PyTorch dependency, supports quantized model weights (smaller memory footprint), is developed for running on CPU, and uses data-parallelism for encoding large datasets.\n", + "\n", + "* **Open-weight models**: FastEmbed supports many open source and open-weight models, included some that outperform popular encoders like OpenAI's Ada-002." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing semantic-router with the `[fastembed]` flag to include all necessary dependencies for `FastEmbedEncoder`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU \"semantic-router[fastembed]==0.0.15\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping routes to example phrases that should trigger those routes." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\",\n", + " \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model, you can find a list of [all available embedding models here](https://qdrant.github.io/fastembed/examples/Supported_Models/):" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.encoders import FastEmbedEncoder\n", + "\n", + "encoder = FastEmbedEncoder()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_**⚠️ If you see an ImportError, you must install the FastEmbed library. You can do so by installing Semantic Router using `pip install -qU \"semantic-router[fastembed]\"`.**_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-06 16:53:16 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router.layer import RouteLayer\n", + "\n", + "rl = RouteLayer(encoder=encoder, routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can test it:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects?" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"I'm interested in learning about llama 2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we return `None` because no matches were identified." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/encoders/huggingface.ipynb b/docs/encoders/huggingface.ipynb new file mode 100644 index 00000000..4e9c28cd --- /dev/null +++ b/docs/encoders/huggingface.ipynb @@ -0,0 +1,271 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/encoders/huggingface.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/encoders/huggingface.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using HuggingFaceEncoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "HuggingFace is a huge ecosystem of open source models. It can be run locally and supports the largest library of encoders." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing semantic-router with the `[local]` flag to include all necessary dependencies for `HuggingFaceEncoder`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU \"semantic-router[local]==0.0.16\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by defining a dictionary mapping routes to example phrases that should trigger those routes." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\",\n", + " \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_**⚠️ If you see an ImportError, you must install local dependencies. You can do so by installing Semantic Router using `pip install -qU \"semantic-router[local]\"`.**_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "tokenizer_config.json: 100%|██████████| 350/350 [00:00<00:00, 1.06MB/s]\n", + "vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.05MB/s]\n", + "tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 1.43MB/s]\n", + "special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 386kB/s]\n", + "config.json: 100%|██████████| 612/612 [00:00<00:00, 2.90MB/s]\n", + "pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:01<00:00, 63.2MB/s]\n" + ] + } + ], + "source": [ + "from semantic_router.encoders import HuggingFaceEncoder\n", + "\n", + "encoder = HuggingFaceEncoder()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-09 00:22:35 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router.layer import RouteLayer\n", + "\n", + "rl = RouteLayer(encoder=encoder, routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can test it:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects?" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"I'm interested in learning about llama 2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we return `None` because no matches were identified." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/poetry.lock b/poetry.lock index d0f80d03..68a8e12e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1069,13 +1069,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" [[package]] name = "ipython" -version = "8.19.0" +version = "8.20.0" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.10" files = [ - {file = "ipython-8.19.0-py3-none-any.whl", hash = "sha256:2f55d59370f59d0d2b2212109fe0e6035cfea436b1c0e6150ad2244746272ec5"}, - {file = "ipython-8.19.0.tar.gz", hash = "sha256:ac4da4ecf0042fb4e0ce57c60430c2db3c719fa8bdf92f8631d6bd8a5785d1f0"}, + {file = "ipython-8.20.0-py3-none-any.whl", hash = "sha256:bc9716aad6f29f36c449e30821c9dd0c1c1a7b59ddcc26931685b87b4c569619"}, + {file = "ipython-8.20.0.tar.gz", hash = "sha256:2f21bd3fc1d51550c89ee3944ae04bbc7bc79e129ea0937da6e6c68bfdbf117a"}, ] [package.dependencies] @@ -1122,6 +1122,23 @@ docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alab qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] +[[package]] +name = "jinja2" +version = "3.1.2" +description = "A very fast and expressive template engine." +optional = true +python-versions = ">=3.7" +files = [ + {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, + {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + [[package]] name = "joblib" version = "1.3.2" @@ -1157,13 +1174,13 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt [[package]] name = "jupyter-core" -version = "5.7.0" +version = "5.7.1" description = "Jupyter core package. A base package on which Jupyter projects rely." optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_core-5.7.0-py3-none-any.whl", hash = "sha256:16eea462f7dad23ba9f86542bdf17f830804e2028eb48d609b6134d91681e983"}, - {file = "jupyter_core-5.7.0.tar.gz", hash = "sha256:cb8d3ed92144d2463a3c5664fdd686a3f0c1442ea45df8babb1c1a9e6333fe03"}, + {file = "jupyter_core-5.7.1-py3-none-any.whl", hash = "sha256:c65c82126453a723a2804aa52409930434598fd9d35091d63dfb919d2b765bb7"}, + {file = "jupyter_core-5.7.1.tar.gz", hash = "sha256:de61a9d7fc71240f688b2fb5ab659fbb56979458dc66a71decd098e03c79e218"}, ] [package.dependencies] @@ -1175,6 +1192,75 @@ traitlets = ">=5.3" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] +[[package]] +name = "markupsafe" +version = "2.1.3" +description = "Safely add untrusted strings to HTML/XML markup." +optional = true +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, + {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, +] + [[package]] name = "matplotlib-inline" version = "0.1.6" @@ -1402,6 +1488,24 @@ files = [ {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"}, ] +[[package]] +name = "networkx" +version = "3.2.1" +description = "Python package for creating and manipulating graphs and networks" +optional = true +python-versions = ">=3.9" +files = [ + {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, + {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, +] + +[package.extras] +default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "nltk" version = "3.8.1" @@ -1461,6 +1565,147 @@ files = [ {file = "numpy-1.25.2.tar.gz", hash = "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760"}, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.18.1" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.3.101" +description = "Nvidia JIT LTO Library" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + [[package]] name = "onnx" version = "1.15.0" @@ -2273,6 +2518,125 @@ files = [ {file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"}, ] +[[package]] +name = "safetensors" +version = "0.4.1" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "safetensors-0.4.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cba01c6b76e01ec453933b3b3c0157c59b52881c83eaa0f7666244e71aa75fd1"}, + {file = "safetensors-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7a8f6f679d97ea0135c7935c202feefbd042c149aa70ee759855e890c01c7814"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbc2ce1f5ae5143a7fb72b71fa71db6a42b4f6cf912aa3acdc6b914084778e68"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2d87d993eaefe6611a9c241a8bd364a5f1ffed5771c74840363a6c4ed8d868f6"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:097e9af2efa8778cd2f0cba451784253e62fa7cc9fc73c0744d27212f7294e25"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d10a9f7bae608ccfdc009351f01dc3d8535ff57f9488a58a4c38e45bf954fe93"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:270b99885ec14abfd56c1d7f28ada81740a9220b4bae960c3de1c6fe84af9e4d"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:285b52a481e7ba93e29ad4ec5841ef2c4479ef0a6c633c4e2629e0508453577b"}, + {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c3c9f0ca510e0de95abd6424789dcbc879942a3a4e29b0dfa99d9427bf1da75c"}, + {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:88b4653059c903015284a9722f9a46838c654257173b279c8f6f46dbe80b612d"}, + {file = "safetensors-0.4.1-cp310-none-win32.whl", hash = "sha256:2fe6926110e3d425c4b684a4379b7796fdc26ad7d16922ea1696c8e6ea7e920f"}, + {file = "safetensors-0.4.1-cp310-none-win_amd64.whl", hash = "sha256:a79e16222106b2f5edbca1b8185661477d8971b659a3c814cc6f15181a9b34c8"}, + {file = "safetensors-0.4.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:d93321eea0dd7e81b283e47a1d20dee6069165cc158286316d0d06d340de8fe8"}, + {file = "safetensors-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8ff8e41c8037db17de0ea2a23bc684f43eaf623be7d34906fe1ac10985b8365e"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39d36f1d88468a87c437a1bc27c502e71b6ca44c385a9117a9f9ba03a75cc9c6"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7ef010e9afcb4057fb6be3d0a0cfa07aac04fe97ef73fe4a23138d8522ba7c17"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b287304f2b2220d51ccb51fd857761e78bcffbeabe7b0238f8dc36f2edfd9542"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e09000b2599e1836314430f81a3884c66a5cbabdff5d9f175b5d560d4de38d78"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9c80ce0001efa16066358d2dd77993adc25f5a6c61850e4ad096a2232930bce"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:413e1f6ac248f7d1b755199a06635e70c3515493d3b41ba46063dec33aa2ebb7"}, + {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3ac139377cfe71ba04573f1cda66e663b7c3e95be850e9e6c2dd4b5984bd513"}, + {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:04157d008385bea66d12fe90844a80d4a76dc25ec5230b5bd9a630496d1b7c03"}, + {file = "safetensors-0.4.1-cp311-none-win32.whl", hash = "sha256:5f25297148ec665f0deb8bd67e9564634d8d6841041ab5393ccfe203379ea88b"}, + {file = "safetensors-0.4.1-cp311-none-win_amd64.whl", hash = "sha256:b2f8877990a72ff595507b80f4b69036a9a1986a641f8681adf3425d97d3d2a5"}, + {file = "safetensors-0.4.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:eb2c1da1cc39509d1a55620a5f4d14f8911c47a89c926a96e6f4876e864375a3"}, + {file = "safetensors-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:303d2c0415cf15a28f8d7f17379ea3c34c2b466119118a34edd9965983a1a8a6"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb4cb3e37a9b961ddd68e873b29fe9ab4a081e3703412e34aedd2b7a8e9cafd9"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ae5497adc68669db2fed7cb2dad81e6a6106e79c9a132da3efdb6af1db1014fa"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b30abd0cddfe959d1daedf92edcd1b445521ebf7ddefc20860ed01486b33c90"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d784a98c492c751f228a4a894c3b8a092ff08b24e73b5568938c28b8c0e8f8df"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e57a5ab08b0ec7a7caf30d2ac79bb30c89168431aca4f8854464bb9461686925"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:edcf3121890b5f0616aa5a54683b1a5d2332037b970e507d6bb7841a3a596556"}, + {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fdb58dee173ef33634c3016c459d671ca12d11e6acf9db008261cbe58107e579"}, + {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:780dc21eb3fd32ddd0e8c904bdb0290f2454f4ac21ae71e94f9ce72db1900a5a"}, + {file = "safetensors-0.4.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:48901bd540f8a3c1791314bc5c8a170927bf7f6acddb75bf0a263d081a3637d4"}, + {file = "safetensors-0.4.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:3b0b7b2d5976fbed8a05e2bbdce5816a59e6902e9e7c7e07dc723637ed539787"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f69903ff49cb30b9227fb5d029bea276ea20d04b06803877a420c5b1b74c689"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0ddd050e01f3e843aa8c1c27bf68675b8a08e385d0045487af4d70418c3cb356"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a82bc2bd7a9a0e08239bdd6d7774d64121f136add93dfa344a2f1a6d7ef35fa"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6ace9e66a40f98a216ad661245782483cf79cf56eb2b112650bb904b0baa9db5"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82cbb8f4d022f2e94498cbefca900698b8ded3d4f85212f47da614001ff06652"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:791edc10a3c359a2f5f52d5cddab0df8a45107d91027d86c3d44e57162e5d934"}, + {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:83c2cfbe8c6304f0891e7bb378d56f66d2148972eeb5f747cd8a2246886f0d8c"}, + {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:04dd14f53f5500eb4c4149674216ba1000670efbcf4b1b5c2643eb244e7882ea"}, + {file = "safetensors-0.4.1-cp37-none-win32.whl", hash = "sha256:d5b3defa74f3723a388bfde2f5d488742bc4879682bd93267c09a3bcdf8f869b"}, + {file = "safetensors-0.4.1-cp37-none-win_amd64.whl", hash = "sha256:25a043cbb59d4f75e9dd87fdf5c009dd8830105a2c57ace49b72167dd9808111"}, + {file = "safetensors-0.4.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:3f6a520af7f2717c5ecba112041f2c8af1ca6480b97bf957aba81ed9642e654c"}, + {file = "safetensors-0.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c3807ac3b16288dffebb3474b555b56fe466baa677dfc16290dcd02dca1ab228"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b58ba13a9e82b4bc3fc221914f6ef237fe6c2adb13cede3ace64d1aacf49610"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dac4bb42f8679aadc59bd91a4c5a1784a758ad49d0912995945cd674089f628e"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:911b48dc09e321a194def3a7431662ff4f03646832f3a8915bbf0f449b8a5fcb"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82571d20288c975c1b30b08deb9b1c3550f36b31191e1e81fae87669a92217d0"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da52ee0dc8ba03348ffceab767bd8230842fdf78f8a996e2a16445747143a778"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2536b11ce665834201072e9397404170f93f3be10cca9995b909f023a04501ee"}, + {file = "safetensors-0.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:998fbac99ca956c3a09fe07cc0b35fac26a521fa8865a690686d889f0ff4e4a6"}, + {file = "safetensors-0.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:845be0aafabf2a60c2d482d4e93023fecffe5e5443d801d7a7741bae9de41233"}, + {file = "safetensors-0.4.1-cp38-none-win32.whl", hash = "sha256:ce7a28bc8af685a69d7e869d09d3e180a275e3281e29cf5f1c7319e231932cc7"}, + {file = "safetensors-0.4.1-cp38-none-win_amd64.whl", hash = "sha256:e056fb9e22d118cc546107f97dc28b449d88274207dd28872bd668c86216e4f6"}, + {file = "safetensors-0.4.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:bdc0d039e44a727824639824090bd8869535f729878fa248addd3dc01db30eae"}, + {file = "safetensors-0.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3c1b1d510c7aba71504ece87bf393ea82638df56303e371e5e2cf09d18977dd7"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bd0afd95c1e497f520e680ea01e0397c0868a3a3030e128438cf6e9e3fcd671"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f603bdd8deac6726d39f41688ed353c532dd53935234405d79e9eb53f152fbfb"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8a85e3e47e0d4eebfaf9a58b40aa94f977a56050cb5598ad5396a9ee7c087c6"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0ccb5aa0f3be2727117e5631200fbb3a5b3a2b3757545a92647d6dd8be6658f"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d784938534e255473155e4d9f276ee69eb85455b6af1292172c731409bf9adee"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a257de175c254d39ccd6a21341cd62eb7373b05c1e618a78096a56a857e0c316"}, + {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6fd80f7794554091836d4d613d33a7d006e2b8d6ba014d06f97cebdfda744f64"}, + {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:35803201d980efcf964b75a0a2aee97fe5e9ecc5f3ad676b38fafdfe98e0620d"}, + {file = "safetensors-0.4.1-cp39-none-win32.whl", hash = "sha256:7ff8a36e0396776d3ed9a106fc9a9d7c55d4439ca9a056a24bf66d343041d3e6"}, + {file = "safetensors-0.4.1-cp39-none-win_amd64.whl", hash = "sha256:bfa2e20342b81921b98edba52f8deb68843fa9c95250739a56b52ceda5ea5c61"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ae2d5a31cfb8a973a318f7c4d2cffe0bd1fe753cdf7bb41a1939d45a0a06f964"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a45dbf03e8334d3a5dc93687d98b6dc422f5d04c7d519dac09b84a3c87dd7c6"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2297b359d91126c0f9d4fd17bae3cfa2fe3a048a6971b8db07db746ad92f850c"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bda3d98e2bcece388232cfc551ebf063b55bdb98f65ab54df397da30efc7dcc5"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8934bdfd202ebd0697040a3dff40dd77bc4c5bbf3527ede0532f5e7fb4d970f"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:42c3710cec7e5c764c7999697516370bee39067de0aa089b7e2cfb97ac8c6b20"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:53134226053e56bd56e73f7db42596e7908ed79f3c9a1016e4c1dade593ac8e5"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:257d59e40a1b367cb544122e7451243d65b33c3f34d822a347f4eea6fdf97fdf"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d54c2f1826e790d1eb2d2512bfd0ee443f0206b423d6f27095057c7f18a0687"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:645b3f1138fce6e818e79d4128afa28f0657430764cc045419c1d069ff93f732"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e9a7ffb1e551c6df51d267f5a751f042b183df22690f6feceac8d27364fd51d7"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:44e230fbbe120de564b64f63ef3a8e6ff02840fa02849d9c443d56252a1646d4"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:9d16b3b2fcc6fca012c74bd01b5619c655194d3e3c13e4d4d0e446eefa39a463"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:5d95ea4d8b32233910734a904123bdd3979c137c461b905a5ed32511defc075f"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:dab431699b5d45e0ca043bc580651ce9583dda594e62e245b7497adb32e99809"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16d8bbb7344e39cb9d4762e85c21df94ebeb03edac923dd94bb9ed8c10eac070"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1faf5111c66a6ba91f85dff2e36edaaf36e6966172703159daeef330de4ddc7b"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:660ca1d8bff6c7bc7c6b30b9b32df74ef3ab668f5df42cefd7588f0d40feadcb"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ae2f67f04ed0bb2e56fd380a8bd3eef03f609df53f88b6f5c7e89c08e52aae00"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:c8ed5d2c04cdc1afc6b3c28d59580448ac07732c50d94c15e14670f9c473a2ce"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2b6a2814278b6660261aa9a9aae524616de9f1ec364e3716d219b6ed8f91801f"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3cfd1ca35eacc635f0eaa894e5c5ed83ffebd0f95cac298fd430014fa7323631"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4177b456c6b0c722d82429127b5beebdaf07149d265748e97e0a34ff0b3694c8"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:313e8472197bde54e3ec54a62df184c414582979da8f3916981b6a7954910a1b"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fdb4adb76e21bad318210310590de61c9f4adcef77ee49b4a234f9dc48867869"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:1d568628e9c43ca15eb96c217da73737c9ccb07520fafd8a1eba3f2750614105"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:573b6023a55a2f28085fc0a84e196c779b6cbef4d9e73acea14c8094fee7686f"}, + {file = "safetensors-0.4.1.tar.gz", hash = "sha256:2304658e6ada81a5223225b4efe84748e760c46079bffedf7e321763cafb36c9"}, +] + +[package.extras] +all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"] +dev = ["safetensors[all]"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"] +pinned-tf = ["safetensors[numpy]", "tensorflow (==2.11.0)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"] +torch = ["safetensors[numpy]", "torch (>=1.10)"] + [[package]] name = "six" version = "1.16.0" @@ -2465,6 +2829,59 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "torch" +version = "2.1.2" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, + {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, + {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, + {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, + {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, + {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, + {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, + {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, + {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, + {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, + {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, + {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, + {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, + {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, + {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, + {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, + {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, + {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, + {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, + {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +jinja2 = "*" +networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +sympy = "*" +triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +typing-extensions = "*" + +[package.extras] +dynamo = ["jinja2"] +opt-einsum = ["opt-einsum (>=3.3)"] + [[package]] name = "tornado" version = "6.4" @@ -2520,6 +2937,99 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "transformers" +version = "4.36.2" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"}, + {file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.19.3,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.3.1" +tokenizers = ">=0.14,<0.19" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.21.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +docs-specific = ["hf-doc-builder"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +timm = ["timm"] +tokenizers = ["tokenizers (>=0.14,<0.19)"] +torch = ["accelerate (>=0.21.0)", "torch (>=1.10,!=1.12.0)"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + +[[package]] +name = "triton" +version = "2.1.0" +description = "A language and compiler for custom Deep Learning operations" +optional = true +python-versions = "*" +files = [ + {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, + {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, + {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, + {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, + {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, + {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, + {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, + {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.18)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + [[package]] name = "types-pyyaml" version = "6.0.12.12" @@ -2700,8 +3210,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] fastembed = ["fastembed"] hybrid = ["pinecone-text"] +local = ["torch", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "42a58d13a0f9d9a1bca34b4c29cafee6a5c884b80d47848cb7c552ad91e54743" +content-hash = "9cfa7fae3109942320a1e65fa30ee94e5b355beda5e1b285228b179bc462b4ac" diff --git a/pyproject.toml b/pyproject.toml index 5b0e1f30..02cde056 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,13 @@ pinecone-text = {version = "^0.7.0", optional = true} colorlog = "^6.8.0" pyyaml = "^6.0.1" fastembed = {version = "^0.1.3", optional = true} +torch = {version = "^2.1.2", optional = true} +transformers = {version = "^4.36.2", optional = true} [tool.poetry.extras] hybrid = ["pinecone-text"] fastembed = ["fastembed"] +local = ["torch", "transformers"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.26.0" diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 9d3a027e..e4bcaf14 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -3,6 +3,7 @@ from semantic_router.encoders.cohere import CohereEncoder from semantic_router.encoders.fastembed import FastEmbedEncoder from semantic_router.encoders.openai import OpenAIEncoder +from semantic_router.encoders.huggingface import HuggingFaceEncoder __all__ = [ "BaseEncoder", @@ -10,4 +11,5 @@ "OpenAIEncoder", "BM25Encoder", "FastEmbedEncoder", + "HuggingFaceEncoder", ] diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 68150cb7..b1298b37 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -9,7 +9,12 @@ class BM25Encoder(BaseEncoder): idx_mapping: dict[int, int] | None = None type: str = "sparse" - def __init__(self, name: str = "bm25", score_threshold: float = 0.82): + def __init__( + self, + name: str = "bm25", + score_threshold: float = 0.82, + use_default_params: bool = True, + ): super().__init__(name=name, score_threshold=score_threshold) try: from pinecone_text.sparse import BM25Encoder as encoder @@ -18,9 +23,15 @@ def __init__(self, name: str = "bm25", score_threshold: float = 0.82): "Please install pinecone-text to use BM25Encoder. " "You can install it with: `pip install semantic-router[hybrid]`" ) - logger.info("Downloading and initializing BM25 model parameters.") - self.model = encoder.default() + self.model = encoder() + + if use_default_params: + logger.info("Downloading and initializing default sBM25 model parameters.") + self.model = encoder.default() + self._set_idx_mapping() + + def _set_idx_mapping(self): params = self.model.get_params() doc_freq = params["doc_freq"] if isinstance(doc_freq, dict): @@ -53,3 +64,4 @@ def fit(self, docs: list[str]): if self.model is None: raise ValueError("Model is not initialized.") self.model.fit(docs) + self._set_idx_mapping() diff --git a/semantic_router/encoders/huggingface.py b/semantic_router/encoders/huggingface.py new file mode 100644 index 00000000..f84b402a --- /dev/null +++ b/semantic_router/encoders/huggingface.py @@ -0,0 +1,112 @@ +from typing import Any +from pydantic import PrivateAttr +from semantic_router.encoders import BaseEncoder + + +class HuggingFaceEncoder(BaseEncoder): + name: str = "sentence-transformers/all-MiniLM-L6-v2" + type: str = "huggingface" + score_threshold: float = 0.5 + tokenizer_kwargs: dict = {} + model_kwargs: dict = {} + device: str | None = None + _tokenizer: Any = PrivateAttr() + _model: Any = PrivateAttr() + _torch: Any = PrivateAttr() + + def __init__(self, **data): + super().__init__(**data) + self._tokenizer, self._model = self._initialize_hf_model() + + def _initialize_hf_model(self): + try: + from transformers import AutoTokenizer, AutoModel + except ImportError: + raise ImportError( + "Please install transformers to use HuggingFaceEncoder. " + "You can install it with: " + "`pip install semantic-router[local]`" + ) + + try: + import torch + except ImportError: + raise ImportError( + "Please install Pytorch to use HuggingFaceEncoder. " + "You can install it with: " + "`pip install semantic-router[local]`" + ) + + self._torch = torch + + tokenizer = AutoTokenizer.from_pretrained( + self.name, + **self.tokenizer_kwargs, + ) + + model = AutoModel.from_pretrained(self.name, **self.model_kwargs) + + if self.device: + model.to(self.device) + + else: + device = "cuda" if self._torch.cuda.is_available() else "cpu" + model.to(device) + self.device = device + + return tokenizer, model + + def __call__( + self, + docs: list[str], + batch_size: int = 32, + normalize_embeddings: bool = True, + pooling_strategy: str = "mean", + ) -> list[list[float]]: + all_embeddings = [] + for i in range(0, len(docs), batch_size): + batch_docs = docs[i : i + batch_size] + + encoded_input = self._tokenizer( + batch_docs, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + + with self._torch.no_grad(): + model_output = self._model(**encoded_input) + + if pooling_strategy == "mean": + embeddings = self._mean_pooling( + model_output, encoded_input["attention_mask"] + ) + elif pooling_strategy == "max": + embeddings = self._max_pooling( + model_output, encoded_input["attention_mask"] + ) + else: + raise ValueError( + "Invalid pooling_strategy. Please use 'mean' or 'max'." + ) + + if normalize_embeddings: + embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1) + + embeddings = embeddings.tolist() + all_embeddings.extend(embeddings) + return all_embeddings + + def _mean_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return self._torch.sum( + token_embeddings * input_mask_expanded, 1 + ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + def _max_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + token_embeddings[input_mask_expanded == 0] = -1e9 + return self._torch.max(token_embeddings, 1)[0] diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index cd9f7ccb..5273f531 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -16,11 +16,21 @@ class HybridRouteLayer: score_threshold: float def __init__( - self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 + self, + encoder: BaseEncoder, + sparse_encoder: BM25Encoder | None = None, + routes: list[Route] = [], + alpha: float = 0.3, ): self.encoder = encoder self.score_threshold = self.encoder.score_threshold - self.sparse_encoder = BM25Encoder() + + if sparse_encoder is None: + logger.warning("No sparse_encoder provided. Using default BM25Encoder.") + self.sparse_encoder = BM25Encoder() + else: + self.sparse_encoder = sparse_encoder + self.alpha = alpha # if routes list has been passed, we initialize index now if routes: diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py index e654d7bb..174453d2 100644 --- a/tests/unit/encoders/test_bm25.py +++ b/tests/unit/encoders/test_bm25.py @@ -5,7 +5,11 @@ @pytest.fixture def bm25_encoder(): - return BM25Encoder() + sparse_encoder = BM25Encoder(use_default_params=False) + sparse_encoder.fit( + ["The quick brown fox", "jumps over the lazy dog", "Hello, world!"] + ) + return sparse_encoder class TestBM25Encoder: diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py new file mode 100644 index 00000000..0aa8cb79 --- /dev/null +++ b/tests/unit/encoders/test_huggingface.py @@ -0,0 +1,62 @@ +import pytest +import numpy as np +from unittest.mock import patch +from semantic_router.encoders.huggingface import HuggingFaceEncoder + + +encoder = HuggingFaceEncoder() + + +class TestHuggingFaceEncoder: + def test_huggingface_encoder_import_errors_transformers(self): + with patch.dict("sys.modules", {"transformers": None}): + with pytest.raises(ImportError) as error: + HuggingFaceEncoder() + + assert "Please install transformers to use HuggingFaceEncoder" in str( + error.value + ) + + def test_huggingface_encoder_import_errors_torch(self): + with patch.dict("sys.modules", {"torch": None}): + with pytest.raises(ImportError) as error: + HuggingFaceEncoder() + + assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value) + + def test_huggingface_encoder_mean_pooling(self): + test_docs = ["This is a test", "This is another test"] + embeddings = encoder(test_docs, pooling_strategy="mean") + assert isinstance(embeddings, list) + assert len(embeddings) == len(test_docs) + assert all(isinstance(embedding, list) for embedding in embeddings) + assert all(len(embedding) > 0 for embedding in embeddings) + + def test_huggingface_encoder_max_pooling(self): + test_docs = ["This is a test", "This is another test"] + embeddings = encoder(test_docs, pooling_strategy="max") + assert isinstance(embeddings, list) + assert len(embeddings) == len(test_docs) + assert all(isinstance(embedding, list) for embedding in embeddings) + assert all(len(embedding) > 0 for embedding in embeddings) + + def test_huggingface_encoder_normalized_embeddings(self): + docs = ["This is a test document.", "Another test document."] + unnormalized_embeddings = encoder(docs, normalize_embeddings=False) + normalized_embeddings = encoder(docs, normalize_embeddings=True) + assert len(unnormalized_embeddings) == len(normalized_embeddings) + + for unnormalized, normalized in zip( + unnormalized_embeddings, normalized_embeddings + ): + norm_unnormalized = np.linalg.norm(unnormalized, ord=2) + norm_normalized = np.linalg.norm(normalized, ord=2) + # Ensure the norm of the normalized embeddings is approximately 1 + assert np.isclose(norm_normalized, 1.0) + # Ensure the normalized embeddings are actually normalized versions of unnormalized embeddings + np.testing.assert_allclose( + normalized, + np.divide(unnormalized, norm_unnormalized), + rtol=1e-5, + atol=1e-5, # Adjust tolerance levels + ) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 6896c4de..df530149 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -1,6 +1,11 @@ import pytest -from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder +from semantic_router.encoders import ( + BaseEncoder, + BM25Encoder, + CohereEncoder, + OpenAIEncoder, +) from semantic_router.hybrid_layer import HybridRouteLayer from semantic_router.route import Route @@ -42,9 +47,15 @@ def routes(): ] +sparse_encoder = BM25Encoder(use_default_params=False) +sparse_encoder.fit(["The quick brown fox", "jumps over the lazy dog", "Hello, world!"]) + + class TestHybridRouteLayer: def test_initialization(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) assert route_layer.index is not None and route_layer.categories is not None assert openai_encoder.score_threshold == 0.82 assert route_layer.score_threshold == 0.82 @@ -52,14 +63,20 @@ def test_initialization(self, openai_encoder, routes): assert len(set(route_layer.categories)) == 2 def test_initialization_different_encoders(self, cohere_encoder, openai_encoder): - route_layer_cohere = HybridRouteLayer(encoder=cohere_encoder) + route_layer_cohere = HybridRouteLayer( + encoder=cohere_encoder, sparse_encoder=sparse_encoder + ) assert route_layer_cohere.score_threshold == 0.3 - route_layer_openai = HybridRouteLayer(encoder=openai_encoder) + route_layer_openai = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) assert route_layer_openai.score_threshold == 0.82 def test_add_route(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) route = Route(name="Route 3", utterances=["Yes", "No"]) route_layer._add_routes([route]) assert route_layer.index is not None and route_layer.categories is not None @@ -67,7 +84,9 @@ def test_add_route(self, openai_encoder): assert len(set(route_layer.categories)) == 1 def test_add_multiple_routes(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) for route in routes: route_layer.add(route) assert route_layer.index is not None and route_layer.categories is not None @@ -75,16 +94,22 @@ def test_add_multiple_routes(self, openai_encoder, routes): assert len(set(route_layer.categories)) == 2 def test_query_and_classification(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) query_result = route_layer("Hello") assert query_result in ["Route 1", "Route 2"] def test_query_with_no_index(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) assert route_layer("Anything") is None def test_semantic_classify(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -95,7 +120,9 @@ def test_semantic_classify(self, openai_encoder, routes): assert score == [0.9] def test_semantic_classify_multiple_routes(self, openai_encoder, routes): - route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -107,12 +134,16 @@ def test_semantic_classify_multiple_routes(self, openai_encoder, routes): assert score == [0.9, 0.8] def test_pass_threshold(self, openai_encoder): - route_layer = HybridRouteLayer(encoder=openai_encoder) + route_layer = HybridRouteLayer( + encoder=openai_encoder, sparse_encoder=sparse_encoder + ) assert not route_layer._pass_threshold([], 0.5) assert route_layer._pass_threshold([0.6, 0.7], 0.5) def test_failover_score_threshold(self, base_encoder): - route_layer = HybridRouteLayer(encoder=base_encoder) + route_layer = HybridRouteLayer( + encoder=base_encoder, sparse_encoder=sparse_encoder + ) assert base_encoder.score_threshold == 0.50 assert route_layer.score_threshold == 0.50