From 838767a564228b36962bfef789c0d79f9eafb5e9 Mon Sep 17 00:00:00 2001 From: Anush Date: Sat, 16 Mar 2024 02:04:15 +0530 Subject: [PATCH] feat: QdrantIndex (#1) * feat: QdrantIndex * chore: poetry.lock --- poetry.lock | 236 +++++++++++++++++++++++++++++- pyproject.toml | 2 + semantic_router/index/__init__.py | 2 + semantic_router/index/qdrant.py | 223 ++++++++++++++++++++++++++++ semantic_router/layer.py | 4 +- tests/unit/test_layer.py | 154 ++++++++++++------- 6 files changed, 564 insertions(+), 57 deletions(-) create mode 100644 semantic_router/index/qdrant.py diff --git a/poetry.lock b/poetry.lock index 4d2fe047..db50db68 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1059,6 +1059,140 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "grpcio" +version = "1.62.1" +description = "HTTP/2-based RPC framework" +optional = true +python-versions = ">=3.7" +files = [ + {file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"}, + {file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"}, + {file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"}, + {file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"}, + {file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"}, + {file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"}, + {file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"}, + {file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"}, + {file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"}, + {file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"}, + {file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"}, + {file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"}, + {file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"}, + {file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"}, + {file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"}, + {file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"}, + {file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"}, + {file = 
"grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"}, + {file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"}, + {file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"}, + {file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"}, + {file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"}, + {file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"}, + {file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"}, + {file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.62.1)"] + +[[package]] +name = "grpcio-tools" +version = "1.62.1" +description = "Protobuf code generator for gRPC" +optional = true +python-versions = ">=3.7" +files = [ + {file = "grpcio-tools-1.62.1.tar.gz", hash = "sha256:a4991e5ee8a97ab791296d3bf7e8700b1445635cc1828cc98df945ca1802d7f2"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:f2b404bcae7e2ef9b0b9803b2a95119eb7507e6dc80ea4a64a78be052c30cebc"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:fdd987a580b4474769adfd40144486f54bcc73838d5ec5d3647a17883ea78e76"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:07af1a6442e2313cff22af93c2c4dd37ae32b5239b38e0d99e2cbf93de65429f"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:41384c9ee18e61ef20cad2774ef71bd8854b63efce263b5177aa06fccb84df1f"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c38006f7702d2ff52122e4c77a47348709374050c76216e84b30a9f06e45afa"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08fecc3c5b4e6dd3278f2b9d12837e423c7dcff551ca1e587018b4a0fc5f8019"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a01e8dcd0f041f6fa6d815c54a2017d032950e310c41d514a8bc041e872c4d12"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-win32.whl", hash = "sha256:dd933b8e0b3c13fe3543d58f849a6a5e0d7987688cb6801834278378c724f695"}, + {file = "grpcio_tools-1.62.1-cp310-cp310-win_amd64.whl", hash = 
"sha256:2b04844a9382f1bde4b4174e476e654ab3976168d2469cb4b29e352f4f35a5aa"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:024380536ba71a96cdf736f0954f6ad03f5da609c09edbcc2ca02fdd639e0eed"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:21f14b99e0cd38ad56754cc0b62b2bf3cf75f9f7fc40647da54669e0da0726fe"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:975ac5fb482c23f3608c16e06a43c8bab4d79c2e2564cdbc25cf753c6e998775"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50739aaab0c8076ad5957204e71f2e0c9876e11fd8338f7f09de12c2d75163c5"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:598c54318f0326cf5020aa43fc95a15e933aba4a71943d3bff2677d2d21ddfa1"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f309bdb33a61f8e049480d41498ee2e525cfb5e959958b326abfdf552bf9b9cb"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f358effd3c11d66c150e0227f983d54a5cd30e14038566dadcf25f9f6844e6e8"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-win32.whl", hash = "sha256:b76aead9b73f1650a091870fe4e9ed15ac4d8ed136f962042367255199c23594"}, + {file = "grpcio_tools-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:d66a5d47eaa427039752fa0a83a425ff2a487b6a0ac30556fd3be2f3a27a0130"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:575535d039b97d63e6a9abee626d6c7cd47bd8cb73dd00a5c84a98254a2164a4"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:22644c90e43d1a888477899af917979e17364fdd6e9bbb92679cd6a54c4d36c3"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:156d3e1b227c16e903003a56881dbe60e40f2b4bd66f0bc3b27c53e466e6384d"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ad7c5691625a85327e5b683443baf73ae790fd5afc938252041ed5cd665e377"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e140bbc08eea8abf51c0274f45fb1e8350220e64758998d7f3c7f985a0b2496"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7444fcab861911525470d398e5638b70d5cbea3b4674a3de92b5c58c5c515d4d"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e643cd14a5d1e59865cba68a5a6f0175d987f36c5f4cb0db80dee9ed60b4c174"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-win32.whl", hash = "sha256:1344a773d2caa9bb7fbea7e879b84f33740c808c34a5bd2a2768e526117a6b44"}, + {file = "grpcio_tools-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:2eea1db3748b2f37b4dce84d8e0c15d9bc811094807cabafe7b0ea47f424dfd5"}, + {file = "grpcio_tools-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:45d2e6cf04d27286b6f73e6e20ba3f0a1f6d8f5535e5dcb1356200419bb457f4"}, + {file = "grpcio_tools-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:46ae58e6926773e7315e9005f0f17aacedbc0895a8752bec087d24efa2f1fb21"}, + {file = "grpcio_tools-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:4c28086df31478023a36f45e50767872ab3aed2419afff09814cb61c88b77db4"}, + {file = "grpcio_tools-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a4fba5b339f4797548591036c9481e6895bf920fab7d3dc664d2697f8fb7c0bf"}, + {file = 
"grpcio_tools-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23eb3d47f78f509fcd201749b1f1e44b76f447913f7fbb3b8bae20f109086295"}, + {file = "grpcio_tools-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:fd5d47707bd6bc2b707ece765c362d2a1d2e8f6cd92b04c99fab49a929f3610c"}, + {file = "grpcio_tools-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1924a6a943df7c73b9ef0048302327c75962b567451479710da729ead241228"}, + {file = "grpcio_tools-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:fe71ca30aabe42591e84ecb9694c0297dc699cc20c5b24d2cb267fb0fc01f947"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:1819fd055c1ae672d1d725ec75eefd1f700c18acba0ed9332202be31d69c401d"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:5dbe1f7481dd14b6d477b4bace96d275090bc7636b9883975a08b802c94e7b78"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:771c051c5ece27ad03e4f2e33624a925f0ad636c01757ab7dbb04a37964af4ba"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98209c438b38b6f1276dbc27b1c04e346a75bfaafe72a25a548f2dc5ce71d226"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2152308e5321cb90fb45aaa84d03d6dedb19735a8779aaf36c624f97b831842d"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ed1f27dc2b2262c8b8d9036276619c1bb18791311c16ccbf1f31b660f2aad7cf"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2744947b6c5e907af21133431809ccca535a037356864e32c122efed8cb9de1f"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-win32.whl", hash = "sha256:13b20e269d14ad629ff9a2c9a2450f3dbb119d5948de63b27ffe624fa7aea85a"}, + {file = "grpcio_tools-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:999823758e9eacd0095863d06cd6d388be769f80c9abb65cdb11c4f2cfce3fea"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:941f8a5c31986053e75fa466bcfa743c2bf1b513b7978cf1f4ab4e96a8219d27"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:b9c02c88c77ef6057c6cbeea8922d7c2424aabf46bfc40ddf42a32765ba91061"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:6abd4eb3ccb444383a40156139acc3aaa73745d395139cb6bc8e2a3429e1e627"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:449503213d142f8470b331a1c2f346f8457f16c7fe20f531bc2500e271f7c14c"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a11bcf609d00cfc9baed77ab308223cabc1f0b22a05774a26dd4c94c0c80f1f"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:5d7bdea33354b55acf40bb4dd3ba7324d6f1ef6b4a1a4da0807591f8c7e87b9a"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d03b645852d605f43003020e78fe6d573cae6ee6b944193e36b8b317e7549a20"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-win32.whl", hash = "sha256:52b185dfc3bf32e70929310367dbc66185afba60492a6a75a9b1141d407e160c"}, + {file = "grpcio_tools-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:63a273b70896d3640b7a883eb4a080c3c263d91662d870a2e9c84b7bbd978e7b"}, +] + +[package.dependencies] +grpcio = ">=1.62.1" +protobuf = ">=4.21.6,<5.0dev" +setuptools = "*" + [[package]] name = "h11" version = "0.14.0" @@ -1070,6 +1204,32 @@ files = [ {file = "h11-0.14.0.tar.gz", hash 
= "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "h2" +version = "4.1.0" +description = "HTTP/2 State-Machine based protocol implementation" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"}, + {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"}, +] + +[package.dependencies] +hpack = ">=4.0,<5" +hyperframe = ">=6.0,<7" + +[[package]] +name = "hpack" +version = "4.0.0" +description = "Pure-Python HPACK header compression" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"}, + {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, +] + [[package]] name = "httpcore" version = "1.0.4" @@ -1105,6 +1265,7 @@ files = [ [package.dependencies] anyio = "*" certifi = "*" +h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} httpcore = "==1.*" idna = "*" sniffio = "*" @@ -1161,6 +1322,17 @@ files = [ [package.dependencies] pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} +[[package]] +name = "hyperframe" +version = "6.0.1" +description = "HTTP/2 framing layer for Python" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"}, + {file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"}, +] + [[package]] name = "idna" version = "3.6" @@ -2522,6 +2694,25 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "portalocker" +version = "2.8.2" +description = "Wraps the portalocker recipe for easy usage" +optional = true +python-versions = ">=3.8" +files = [ + {file = "portalocker-2.8.2-py3-none-any.whl", hash = "sha256:cfb86acc09b9aa7c3b43594e19be1345b9d16af3feb08bf92f23d4dce513a28e"}, + {file = "portalocker-2.8.2.tar.gz", hash = "sha256:2b035aa7828e46c58e9b31390ee1f169b98e1066ab10b9a6a861fe7e25ee4f33"}, +] + +[package.dependencies] +pywin32 = {version = ">=226", markers = "platform_system == \"Windows\""} + +[package.extras] +docs = ["sphinx (>=1.7.1)"] +redis = ["redis"] +tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] + [[package]] name = "prompt-toolkit" version = "3.0.43" @@ -3049,6 +3240,32 @@ files = [ [package.dependencies] cffi = {version = "*", markers = "implementation_name == \"pypy\""} +[[package]] +name = "qdrant-client" +version = "1.8.0" +description = "Client library for the Qdrant vector search engine" +optional = true +python-versions = ">=3.8" +files = [ + {file = "qdrant_client-1.8.0-py3-none-any.whl", hash = "sha256:fa28d3eb64c0c57ec029c7c85c71f6c72c197f92502022655741f3632c518e29"}, + {file = "qdrant_client-1.8.0.tar.gz", hash = "sha256:2a1a3f2cbacc7adba85644cf6cfdee20401cf25764b32da479c81fb63e178d15"}, +] + +[package.dependencies] +grpcio = ">=1.41.0" +grpcio-tools = ">=1.41.0" +httpx = {version = ">=0.14.0", extras = ["http2"]} +numpy = [ + {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26", markers = 
"python_version >= \"3.12\""}, +] +portalocker = ">=2.7.0,<3.0.0" +pydantic = ">=1.10.8" +urllib3 = ">=1.26.14,<3" + +[package.extras] +fastembed = ["fastembed (==0.2.2)"] + [[package]] name = "regex" version = "2023.12.25" @@ -3330,6 +3547,22 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "setuptools" +version = "69.2.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = true +python-versions = ">=3.8" +files = [ + {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"}, + {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + [[package]] name = "six" version = "1.16.0" @@ -4034,9 +4267,10 @@ local = ["llama-cpp-python", "torch", "transformers"] mistralai = ["mistralai"] pinecone = ["pinecone-client"] processing = ["matplotlib"] +qdrant = ["qdrant-client"] vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "b02bb06cd8c09237dafe711992b1cbe9e190a63ad8510e6478b94d052a141901" +content-hash = "424cd1692d7d98c5e4be9774689edb8cf8cebc187fc35f0e80d6013e3b70a9c9" diff --git a/pyproject.toml b/pyproject.toml index bc43ed81..94db45e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ torchvision = { version = "^0.17.0", optional = true} pillow = { version= "^10.2.0", optional = true} tiktoken = "^0.6.0" matplotlib = { version="^3.8.3", optional = true} +qdrant-client = {version="^1.8.0", optional = true} [tool.poetry.extras] hybrid = ["pinecone-text"] @@ -45,6 +46,7 @@ pinecone = ["pinecone-client"] vision = ["torch", "torchvision", "transformers", "pillow"] processing = ["matplotlib"] mistralai = ["mistralai"] +qdrant = ["qdrant-client"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" diff --git a/semantic_router/index/__init__.py b/semantic_router/index/__init__.py index 1ad70df4..9a01b996 100644 --- a/semantic_router/index/__init__.py +++ b/semantic_router/index/__init__.py @@ -1,9 +1,11 @@ from semantic_router.index.base import BaseIndex from semantic_router.index.local import 
LocalIndex
 from semantic_router.index.pinecone import PineconeIndex
+from semantic_router.index.qdrant import QdrantIndex
 
 __all__ = [
     "BaseIndex",
     "LocalIndex",
+    "QdrantIndex",
     "PineconeIndex",
 ]
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
new file mode 100644
index 00000000..b12031f4
--- /dev/null
+++ b/semantic_router/index/qdrant.py
@@ -0,0 +1,223 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+from pydantic.v1 import Field
+
+from semantic_router.index.base import BaseIndex
+
+DEFAULT_COLLECTION_NAME = "semantic-router-collection"
+DEFAULT_UPLOAD_BATCH_SIZE = 100
+SCROLL_SIZE = 1000
+
+
+class QdrantIndex(BaseIndex):
+    """Qdrant-based index for semantic-router routes and utterances."""
+
+    collection_name: str = Field(
+        default=DEFAULT_COLLECTION_NAME,
+        description=f"The name of the Qdrant collection to use. Defaults to '{DEFAULT_COLLECTION_NAME}'.",
+    )
+    location: Optional[str] = Field(
+        default=":memory:",
+        description="If ':memory:' - use an in-memory Qdrant instance. Used as 'url' value otherwise",
+    )
+    url: Optional[str] = Field(
+        default=None,
+        description="Qualified URL of the Qdrant instance. Optional[scheme], host, Optional[port], Optional[prefix]",
+    )
+    port: Optional[int] = Field(
+        default=6333,
+        description="Port of the REST API interface.",
+    )
+    grpc_port: int = Field(
+        default=6334,
+        description="Port of the gRPC interface.",
+    )
+    prefer_grpc: bool = Field(
+        default=None,
+        description="Whether to use the gRPC interface whenever possible in methods.",
+    )
+    https: Optional[bool] = Field(
+        default=None,
+        description="Whether to use the HTTPS (SSL) protocol.",
+    )
+    api_key: Optional[str] = Field(
+        default=None,
+        description="API key for authentication in Qdrant Cloud.",
+    )
+    prefix: Optional[str] = Field(
+        default=None,
+        description="Prefix to the REST URL path. Example: `http://localhost:6333/some/prefix/{qdrant-endpoint}`.",
+    )
+    timeout: Optional[int] = Field(
+        default=None,
+        description="Timeout for REST and gRPC API requests.",
+    )
+    host: Optional[str] = Field(
+        default=None,
+        description="Host name of the Qdrant service. If url and host are None, set to 'localhost'.",
+    )
+    path: Optional[str] = Field(
+        default=None,
+        description="Persistence path for local Qdrant.",
+    )
+    grpc_options: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description="Options to be passed to the low-level Qdrant gRPC client, if used.",
+    )
+    size: Union[int, None] = Field(
+        default=None,
+        description="Embedding dimensions. Defaults to the embedding length of the configured encoder.",
+    )
+    distance: str = Field(
+        default="Cosine", description="Distance metric to use for similarity search."
+    )
+    collection_options: Optional[Dict[str, Any]] = Field(
+        default={},
+        description="Additional options to be passed to `QdrantClient#create_collection`.",
+    )
+    client: Any = Field(default=None, exclude=True)
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        self.type = "qdrant"
+        self.client = self._initialize_client()
+
+    def _initialize_client(self):
+        try:
+            from qdrant_client import QdrantClient
+
+            return QdrantClient(
+                location=self.location,
+                url=self.url,
+                port=self.port,
+                grpc_port=self.grpc_port,
+                prefer_grpc=self.prefer_grpc,
+                https=self.https,
+                api_key=self.api_key,
+                prefix=self.prefix,
+                timeout=self.timeout,
+                host=self.host,
+                path=self.path,
+                grpc_options=self.grpc_options,
+            )
+
+        except ImportError as e:
+            raise ImportError(
+                "Please install 'qdrant-client' to use QdrantIndex. "
+ "You can install it with: " + "`pip install 'semantic-router[qdrant]'`" + ) from e + + def _init_collection(self) -> None: + from qdrant_client import QdrantClient, models + + self.client: QdrantClient + if not self.client.collection_exists(self.collection_name): + if not self.dimensions: + raise ValueError( + "Cannot create a collection without specifying the dimensions." + ) + + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=models.VectorParams( + size=self.dimensions, + distance=self.distance, # type: ignore + ), + **self.collection_options, + ) + + def add( + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, + ): + self.dimensions = self.dimensions or len(embeddings[0]) + self._init_collection() + + payloads = [ + {"sr_route": route, "sr_utterance": utterance} + for route, utterance in zip(routes, utterances) + ] + + # UUIDs are autogenerated by qdrant-client if not provided explicitly + self.client.upload_collection( + self.collection_name, + vectors=embeddings, + payload=payloads, + batch_size=batch_size, + ) + + def get_routes(self) -> List[Tuple]: + """ + Gets a list of route and utterance objects currently stored in the index. + + Returns: + List[Tuple]: A list of (route_name, utterance) objects. + """ + + import grpc + + results = [] + next_offset = None + stop_scrolling = False + while not stop_scrolling: + records, next_offset = self.client.scroll( + self.collection_name, + limit=SCROLL_SIZE, + offset=next_offset, + with_payload=True, + ) + stop_scrolling = next_offset is None or ( + isinstance(next_offset, grpc.PointId) + and next_offset.num == 0 + and next_offset.uuid == "" + ) + + results.extend(records) + + route_tuples = [ + (x.payload["sr_route"], x.payload["sr_utterance"]) for x in results + ] + return route_tuples + + def delete(self, route_name: str): + from qdrant_client import models + + self.client.delete( + self.collection_name, + points_selector=models.Filter( + must=[ + models.FieldCondition( + key="sr_route", + match=models.MatchText(text=route_name), + ) + ] + ), + ) + + def describe(self) -> dict: + collection_info = self.client.get_collection(self.collection_name) + + return { + "type": self.type, + "dimensions": collection_info.config.params.vectors.size, + "vectors": collection_info.points_count, + } + + def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]: + results = self.client.search( + self.collection_name, query_vector=vector, limit=top_k, with_payload=True + ) + scores = [result.score for result in results] + route_names = [result.payload["sr_route"] for result in results] + return np.array(scores), route_names + + def delete_index(self): + self.client.delete_collection(self.collection_name) + + def __len__(self): + return self.client.get_collection(self.collection_name).points_count diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 221de2be..851fc1b6 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -318,9 +318,9 @@ def from_yaml(cls, file_path: str): return cls(encoder=encoder, routes=config.routes) @classmethod - def from_config(cls, config: LayerConfig): + def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None): encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model - return cls(encoder=encoder, routes=config.routes) + return cls(encoder=encoder, routes=config.routes, index=index) def add(self, route: Route): 
logger.info(f"Adding `{route.name}` route") diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 4a55777b..f6b99bb5 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -1,3 +1,4 @@ +import importlib import os import tempfile from unittest.mock import mock_open, patch @@ -5,6 +6,8 @@ import pytest from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder +from semantic_router.index.local import LocalIndex +from semantic_router.index.qdrant import QdrantIndex from semantic_router.layer import LayerConfig, RouteLayer from semantic_router.llms.base import BaseLLM from semantic_router.route import Route @@ -19,7 +22,7 @@ def mock_encoder_call(utterances): "Bye": [1.0, 1.1, 1.2], "Au revoir": [1.3, 1.4, 1.5], } - return [mock_responses.get(u, [0, 0, 0]) for u in utterances] + return [mock_responses.get(u, [0.0, 0.0, 0.0]) for u in utterances] def layer_json(): @@ -118,9 +121,19 @@ def test_data(): ] +def get_test_indexes(): + indexes = [LocalIndex] + + if importlib.util.find_spec("qdrant_client") is not None: + indexes.append(QdrantIndex) + + return indexes + + +@pytest.mark.parametrize("index_cls", get_test_indexes()) class TestRouteLayer: - def test_initialization(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes, top_k=10) + def test_initialization(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, top_k=10, index=index_cls()) assert openai_encoder.score_threshold == 0.82 assert route_layer.score_threshold == 0.82 assert route_layer.top_k == 10 @@ -131,29 +144,35 @@ def test_initialization(self, openai_encoder, routes): else 0 == 2 ) - def test_initialization_different_encoders(self, cohere_encoder, openai_encoder): - route_layer_cohere = RouteLayer(encoder=cohere_encoder) + def test_initialization_different_encoders( + self, cohere_encoder, openai_encoder, index_cls + ): + route_layer_cohere = RouteLayer(encoder=cohere_encoder, index=index_cls()) assert cohere_encoder.score_threshold == 0.3 assert route_layer_cohere.score_threshold == 0.3 - route_layer_openai = RouteLayer(encoder=openai_encoder) + route_layer_openai = RouteLayer(encoder=openai_encoder, index=index_cls()) assert route_layer_openai.score_threshold == 0.82 - def test_initialization_no_encoder(self, openai_encoder): + def test_initialization_no_encoder(self, openai_encoder, index_cls): os.environ["OPENAI_API_KEY"] = "test_api_key" route_layer_none = RouteLayer(encoder=None) assert route_layer_none.score_threshold == openai_encoder.score_threshold def test_initialization_dynamic_route( - self, cohere_encoder, openai_encoder, dynamic_routes + self, cohere_encoder, openai_encoder, dynamic_routes, index_cls ): - route_layer_cohere = RouteLayer(encoder=cohere_encoder, routes=dynamic_routes) + route_layer_cohere = RouteLayer( + encoder=cohere_encoder, routes=dynamic_routes, index=index_cls() + ) assert route_layer_cohere.score_threshold == 0.3 - route_layer_openai = RouteLayer(encoder=openai_encoder, routes=dynamic_routes) + route_layer_openai = RouteLayer( + encoder=openai_encoder, routes=dynamic_routes, index=index_cls() + ) assert openai_encoder.score_threshold == 0.82 assert route_layer_openai.score_threshold == 0.82 - def test_add_route(self, openai_encoder): - route_layer = RouteLayer(encoder=openai_encoder) + def test_add_route(self, openai_encoder, index_cls): + route_layer = RouteLayer(encoder=openai_encoder, index=index_cls()) route1 = Route(name="Route 1", 
utterances=["Yes", "No"]) route2 = Route(name="Route 2", utterances=["Maybe", "Sure"]) @@ -172,15 +191,19 @@ def test_add_route(self, openai_encoder): assert route_layer.routes == [route1, route2] assert route_layer.index.describe()["vectors"] == 4 - def test_list_route_names(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_list_route_names(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) route_names = route_layer.list_route_names() assert set(route_names) == { route.name for route in routes }, "The list of route names should match the names of the routes added." - def test_delete_route(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_delete_route(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) # Delete a route by name route_to_delete = routes[0].name route_layer.delete(route_to_delete) @@ -194,8 +217,10 @@ def test_delete_route(self, openai_encoder, routes): utterance not in route_layer.index ), "The route's utterances should be deleted from the index." - def test_remove_route_not_found(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_remove_route_not_found(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) # Attempt to remove a route that does not exist non_existent_route = "non-existent-route" with pytest.raises(ValueError) as excinfo: @@ -204,35 +229,43 @@ def test_remove_route_not_found(self, openai_encoder, routes): str(excinfo.value) == f"Route `{non_existent_route}` not found" ), "Attempting to remove a non-existent route should raise a ValueError." 
- def test_add_multiple_routes(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder) + def test_add_multiple_routes(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer(encoder=openai_encoder, index=index_cls()) route_layer._add_routes(routes=routes) assert route_layer.index is not None assert route_layer.index.describe()["vectors"] == 5 - def test_query_and_classification(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_query_and_classification(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) query_result = route_layer(text="Hello").name assert query_result in ["Route 1", "Route 2"] - def test_query_with_no_index(self, openai_encoder): - route_layer = RouteLayer(encoder=openai_encoder) + def test_query_with_no_index(self, openai_encoder, index_cls): + route_layer = RouteLayer(encoder=openai_encoder, index=index_cls()) with pytest.raises(ValueError): assert route_layer(text="Anything").name is None - def test_query_with_vector(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_query_with_vector(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) vector = [0.1, 0.2, 0.3] query_result = route_layer(vector=vector).name assert query_result in ["Route 1", "Route 2"] - def test_query_with_no_text_or_vector(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_query_with_no_text_or_vector(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) with pytest.raises(ValueError): route_layer() - def test_semantic_classify(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_semantic_classify(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -242,8 +275,10 @@ def test_semantic_classify(self, openai_encoder, routes): assert classification == "Route 1" assert score == [0.9] - def test_semantic_classify_multiple_routes(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_semantic_classify_multiple_routes(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) classification, score = route_layer._semantic_classify( [ {"route": "Route 1", "score": 0.9}, @@ -254,28 +289,34 @@ def test_semantic_classify_multiple_routes(self, openai_encoder, routes): assert classification == "Route 1" assert score == [0.9, 0.8] - def test_query_no_text_dynamic_route(self, openai_encoder, dynamic_routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=dynamic_routes) + def test_query_no_text_dynamic_route( + self, openai_encoder, dynamic_routes, index_cls + ): + route_layer = RouteLayer( + encoder=openai_encoder, routes=dynamic_routes, index=index_cls() + ) vector = [0.1, 0.2, 0.3] with pytest.raises(ValueError): route_layer(vector=vector) - def test_pass_threshold(self, openai_encoder): - route_layer = RouteLayer(encoder=openai_encoder) + def test_pass_threshold(self, openai_encoder, index_cls): + route_layer 
= RouteLayer(encoder=openai_encoder, index=index_cls()) assert not route_layer._pass_threshold([], 0.5) assert route_layer._pass_threshold([0.6, 0.7], 0.5) - def test_failover_score_threshold(self, base_encoder): - route_layer = RouteLayer(encoder=base_encoder) + def test_failover_score_threshold(self, base_encoder, index_cls): + route_layer = RouteLayer(encoder=base_encoder, index=index_cls()) assert route_layer.score_threshold == 0.5 - def test_json(self, openai_encoder, routes): + def test_json(self, openai_encoder, routes, index_cls): temp = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) try: temp_path = temp.name # Save the temporary file's path temp.close() # Close the file to ensure it can be opened again on Windows os.environ["OPENAI_API_KEY"] = "test_api_key" - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) route_layer.to_json(temp_path) assert os.path.exists(temp_path) route_layer_from_file = RouteLayer.from_json(temp_path) @@ -286,13 +327,15 @@ def test_json(self, openai_encoder, routes): finally: os.remove(temp_path) # Ensure the file is deleted even if the test fails - def test_yaml(self, openai_encoder, routes): + def test_yaml(self, openai_encoder, routes, index_cls): temp = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) try: temp_path = temp.name # Save the temporary file's path temp.close() # Close the file to ensure it can be opened again on Windows os.environ["OPENAI_API_KEY"] = "test_api_key" - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) route_layer.to_yaml(temp_path) assert os.path.exists(temp_path) route_layer_from_file = RouteLayer.from_yaml(temp_path) @@ -303,7 +346,7 @@ def test_yaml(self, openai_encoder, routes): finally: os.remove(temp_path) # Ensure the file is deleted even if the test fails - def test_from_file_json(openai_encoder, tmp_path): + def test_from_file_json(openai_encoder, tmp_path, index_cls): # Create a temporary JSON file with layer configuration config_path = tmp_path / "config.json" config_path.write_text( @@ -319,7 +362,7 @@ def test_from_file_json(openai_encoder, tmp_path): assert len(layer_config.routes) == 2 assert layer_config.routes[0].name == "politics" - def test_from_file_yaml(openai_encoder, tmp_path): + def test_from_file_yaml(openai_encoder, tmp_path, index_cls): # Create a temporary YAML file with layer configuration config_path = tmp_path / "config.yaml" config_path.write_text( @@ -335,14 +378,14 @@ def test_from_file_yaml(openai_encoder, tmp_path): assert len(layer_config.routes) == 2 assert layer_config.routes[0].name == "politics" - def test_from_file_invalid_path(self): + def test_from_file_invalid_path(self, index_cls): with pytest.raises(FileNotFoundError) as excinfo: LayerConfig.from_file("nonexistent_path.json") assert "[Errno 2] No such file or directory: 'nonexistent_path.json'" in str( excinfo.value ) - def test_from_file_unsupported_type(self, tmp_path): + def test_from_file_unsupported_type(self, tmp_path, index_cls): # Create a temporary unsupported file config_path = tmp_path / "config.unsupported" config_path.write_text(layer_json()) @@ -351,7 +394,7 @@ def test_from_file_unsupported_type(self, tmp_path): LayerConfig.from_file(str(config_path)) assert "Unsupported file type" in str(excinfo.value) - def test_from_file_invalid_config(self, tmp_path): + def 
test_from_file_invalid_config(self, tmp_path, index_cls): # Define an invalid configuration JSON invalid_config_json = """ { @@ -375,7 +418,7 @@ def test_from_file_invalid_config(self, tmp_path): excinfo.value ), "Loading an invalid configuration should raise an exception." - def test_from_file_with_llm(self, tmp_path): + def test_from_file_with_llm(self, tmp_path, index_cls): llm_config_json = """ { "encoder_type": "cohere", @@ -409,22 +452,25 @@ def test_from_file_with_llm(self, tmp_path): layer_config.routes[0].llm.name == "fake-model-v1" ), "LLM instance should have the 'name' attribute set correctly" - def test_config(self, openai_encoder, routes): + def test_config(self, openai_encoder, routes, index_cls): os.environ["OPENAI_API_KEY"] = "test_api_key" - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) # confirm route creation functions as expected layer_config = route_layer.to_config() assert layer_config.routes == routes # now load from config and confirm it's the same - route_layer_from_config = RouteLayer.from_config(layer_config) - assert (route_layer_from_config.index.index == route_layer.index.index).all() + route_layer_from_config = RouteLayer.from_config(layer_config, index_cls()) assert ( route_layer_from_config._get_route_names() == route_layer._get_route_names() ) assert route_layer_from_config.score_threshold == route_layer.score_threshold - def test_get_thresholds(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + def test_get_thresholds(self, openai_encoder, routes, index_cls): + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) assert route_layer.get_thresholds() == {"Route 1": 0.82, "Route 2": 0.82}
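
Usage sketch (not part of the patch): a minimal example, based on the parametrized tests above, of wiring the new QdrantIndex into a RouteLayer. It assumes the qdrant extra is installed (`pip install 'semantic-router[qdrant]'`) and that OPENAI_API_KEY is set for the OpenAIEncoder; the route names and utterances below are illustrative only.

    from semantic_router.encoders import OpenAIEncoder
    from semantic_router.index import QdrantIndex
    from semantic_router.layer import RouteLayer
    from semantic_router.route import Route

    routes = [
        Route(name="chitchat", utterances=["How are you?", "What's up?"]),
        Route(name="weather", utterances=["Will it rain today?", "What's the forecast?"]),
    ]

    # QdrantIndex defaults to an in-memory instance (location=":memory:");
    # pass url/api_key instead to target Qdrant Cloud or a self-hosted server.
    index = QdrantIndex(collection_name="semantic-router-collection")

    layer = RouteLayer(encoder=OpenAIEncoder(), routes=routes, index=index)
    print(layer("Is it going to rain?").name)  # should resolve to "weather"

The in-memory default is also what makes QdrantIndex easy to swap in for LocalIndex, which is exactly what the parametrized test class above does via get_test_indexes().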