diff --git a/LICENSES.txt b/LICENSES.txt index 71dc3837a7..52fa317330 100644 --- a/LICENSES.txt +++ b/LICENSES.txt @@ -3061,6 +3061,7 @@ MIT jnr-x86asm-1.0.2.jar jsoup-1.15.3.jar localstack-1.17.6.jar + milvus-1.19.7.jar mockito-core-3.12.4.jar mssql-jdbc-6.2.1.jre7.jar mysql-1.17.6.jar diff --git a/NOTICE.txt b/NOTICE.txt index 67f7472324..d4c048525d 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -462,6 +462,7 @@ MIT jnr-x86asm-1.0.2.jar jsoup-1.15.3.jar localstack-1.17.6.jar + milvus-1.19.7.jar mockito-core-3.12.4.jar mssql-jdbc-6.2.1.jre7.jar mysql-1.17.6.jar diff --git a/docs/asciidoc/modules/ROOT/images/pinecone-index.png b/docs/asciidoc/modules/ROOT/images/pinecone-index.png new file mode 100644 index 0000000000..5ef736a177 Binary files /dev/null and b/docs/asciidoc/modules/ROOT/images/pinecone-index.png differ diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/index.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/index.adoc index 7f818d2bff..d649284cca 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/index.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/index.adoc @@ -49,15 +49,17 @@ See the following pages for more details on specific vector db procedures - xref:./qdrant.adoc[Qdrant] - xref:./chroma.adoc[ChromaDB] - xref:./weaviate.adoc[Weaviate] +- xref:./pinecone.adoc[Pinecone] +- xref:./milvus.adoc[Milvus] -== Store Vector db info (i.e. `apoc.vectordb.configure`) +== Store Vector db info (i.e. `apoc.vectordb.configure`) We can save some info in the System Database to be reused later, that is the host, login credentials, and mapping, to be used in `*.get` and `.*query` procedures, except for the `apoc.vectordb.custom.get` one. Therefore, to store the vector info, we can execute the `CALL apoc.vectordb.configure(vectorName, keyConfig, databaseName, $configMap)`, -where `vectorName` can be "QDRANT", "CHROMA" or "WEAVIATE", +where `vectorName` can be "QDRANT", "CHROMA", "PINECONE", "MILVUS" or "WEAVIATE", that indicates info to be reused respectively by `apoc.vectordb.qdrant.*`, `apoc.vectordb.chroma.*` and `apoc.vectordb.weaviate.*`. Then `keyConfig` is the configuration name, `databaseName` is the database where the config will be set, diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc new file mode 100644 index 0000000000..4bf59e6322 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc @@ -0,0 +1,225 @@ + +== Pinecone + +Here is a list of all available Pinecone procedures: + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.pinecone.createCollection(hostOrKey, index, similarity, size, $config) | + Creates an index, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`. + The default endpoint is `/indexes`. +| apoc.vectordb.pinecone.deleteCollection(hostOrKey, index, $config) | + Deletes an index with the name specified in the 2nd parameter. + The default endpoint is `/indexes/`. +| apoc.vectordb.pinecone.upsert(hostOrKey, index, vectors, $config) | + Upserts, in the index with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', medatada: ''}]. + The default endpoint is `/vectors/upsert`. +| apoc.vectordb.pinecone.delete(hostOrKey, index, ids, $config) | + Delete the vectors with the specified `ids`. + The default endpoint is `/indexes/`. +| apoc.vectordb.pinecone.get(hostOrKey, index, ids, $config) | + Get the vectors with the specified `ids`. + The default endpoint is `/vectors/fetch`. +| apoc.vectordb.pinecone.getAndUpdate(hostOrKey, index, ids, $config) | + Get the vectors with the specified `ids`, and optionally creates/updates neo4j entities. + The default endpoint is `/vectors/fetch`. +| apoc.vectordb.pinecone.query(hostOrKey, index, vector, filter, limit, $config) | + Retrieve closest vectors the the defined `vector`, `limit` of results, in the index with the name specified in the 2nd parameter. + The default endpoint is `/query`. +| apoc.vectordb.pinecone.queryAndUpdate(hostOrKey, index, vector, filter, limit, $config) | + Retrieve closest vectors the the defined `vector`, `limit` of results, in the index with the name specified in the 2nd parameter, and optionally creates/updates neo4j entities. + The default endpoint is `/query`. +|=== + +where the 1st parameter can be a key defined by the apoc config `apoc.pinecone..host=myHost`. + +[NOTE] +==== +The procedures create/drop/handle an index, instead of a collection like the other vectordb procedures, +since in Pinecone a collection is a static and non-queryable copy of an index. + +Anyway, the create / delete index procedures are named `.createCollection` and `.deleteCollection` to be consistent with the other. +==== + + +The default `hostOrKey` is `"https://api.pinecone.io"`, +therefore in general can be null with the `createCollection` and `deleteCollection` procedures, +and equal to the host name, with the other ones, that is, the one indicated in the Pinecone dashboard: + +image::pinecone-index.png[width=800] + + +=== Examples + +The following example assume we want to create and manage an index called `test-index`. + +.Create an index (it leverages https://docs.pinecone.io/reference/api/control-plane/create_index[this API]) +[source,cypher] +---- +CALL apoc.vectordb.pinecone.createCollection(null, 'test-index', 'cosine', 4, {}) +---- + + +.Delete an index (it leverages https://docs.pinecone.io/reference/api/control-plane/delete_index[this API]) +[source,cypher] +---- +CALL apoc.vectordb.pinecone.deleteCollection(null, 'test-index', {}) +---- + + +.Upsert vectors (it leverages https://docs.pinecone.io/reference/api/data-plane/upsert[this API]) +[source,cypher] +---- +CALL apoc.vectordb.pinecone.upsert('https://test-index-ilx67g5.svc.aped-4627-b74a.pinecone.io', + 'test-index', + [ + {id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}}, + {id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}} + ], + {}) +---- + + +.Get vectors (it leverages https://docs.pinecone.io/reference/api/data-plane/fetch[this API]) + +[source,cypher] +---- +CALL apoc.vectordb.pinecone.get($host, 'test-index', [1,2], {}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | null | null | null | null +| null | {city: "Berlin", foo: "two"} | null | null | null | null +| ... +|=== + +.Get vectors with `{allResults: true}` +[source,cypher] +---- +CALL apoc.vectordb.pinecone.get($host, 'test-index', ['1','2'], {allResults: true, }) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | 1 | [...] | null | null +| null | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| ... +|=== + +.Query vectors (it leverages https://docs.pinecone.io/reference/api/data-plane/query[this API]) +[source,cypher] +---- +CALL apoc.vectordb.pinecone.query($host, + 'test-index', + [0.2, 0.1, 0.9, 0.7], + { city: { `$eq`: "London" } }, + 5, + {allResults: true, }) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null | null +| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| ... +|=== + + +We can define a mapping, to auto-create one/multiple nodes and relationships, by leveraging the vector metadata. + +For example, if we have created 2 vectors with the above upsert procedures, +we can populate some existing nodes (i.e. `(:Test {myId: 'one'})` and `(:Test {myId: 'two'})`): + + +[source,cypher] +---- +CALL apoc.vectordb.pinecone.queryAndUpdate($host, 'test-index', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which populates the two nodes as: `(:Test {myId: 'one', city: 'Berlin', vect: [vector1]})` and `(:Test {myId: 'two', city: 'London', vect: [vector2]})`, +which will be returned in the `entity` column result. + + +Or else, we can create a node if not exists, via `create: true`: + +[source,cypher] +---- +CALL apoc.vectordb.pinecone.queryAndUpdate($host, 'test-index', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + create: true, + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which creates and 2 new nodes as above. + +Or, we can populate an existing relationship (i.e. `(:Start)-[:TEST {myId: 'one'}]->(:End)` and `(:Start)-[:TEST {myId: 'two'}]->(:End)`): + + +[source,cypher] +---- +CALL apoc.vectordb.pinecone.queryAndUpdate($host, 'test-index', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin', vect: [vector1]}]-()` +and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, +which will be returned in the `entity` column result. + +[NOTE] +==== +We can use mapping with `apoc.vectordb.pinecone.getAndUpdate` procedure as well +==== + +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the `apoc.vectordb.pinecone.query*` and the `apoc.vectordb.pinecone.get*` procedures. + +For example, by executing a `CALL apoc.vectordb.pinecone.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"with_payload": false, "with_vectors": false}, +so that we do not return the other values that we do not need. +==== + + + +.Delete vectors (it leverages https://docs.pinecone.io/reference/api/data-plane/delete[this API]) +[source,cypher] +---- +CALL apoc.vectordb.pinecone.delete($host, 'test-index', ['1','2'], {}) +---- diff --git a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java index 7e1154563a..8701ebd7fd 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java @@ -16,7 +16,6 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; -import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; import static apoc.vectordb.VectorDbTestUtil.ragSetup; @@ -284,8 +283,8 @@ public void queryVectorsWithCreateNode() { "myId", METADATA_KEY, "foo", - CREATE_KEY, - true)); + MODE_KEY, + MappingMode.CREATE_IF_MISSING.toString())); testResult( db, diff --git a/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java b/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java new file mode 100644 index 0000000000..893f5064d8 --- /dev/null +++ b/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java @@ -0,0 +1,409 @@ +package apoc.full.it.vectordb; + +import apoc.util.TestUtil; +import apoc.util.Util; +import apoc.vectordb.Milvus; +import apoc.vectordb.VectorDb; +import apoc.vectordb.VectorDbTestUtil; +import apoc.vectordb.VectorMappingConfig; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import org.testcontainers.milvus.MilvusContainer; + +import java.util.List; +import java.util.Map; + +import static apoc.util.MapUtil.map; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.vectordb.VectorDbHandler.Type.MILVUS; +import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; +import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; +import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static apoc.vectordb.VectorMappingConfig.EMBEDDING_KEY; +import static apoc.vectordb.VectorMappingConfig.ENTITY_KEY; +import static apoc.vectordb.VectorMappingConfig.METADATA_KEY; +import static apoc.vectordb.VectorMappingConfig.MODE_KEY; +import static apoc.vectordb.VectorMappingConfig.NODE_LABEL; +import static apoc.vectordb.VectorMappingConfig.REL_TYPE; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; +import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; + + +public class MilvusTest { + private static final List FIELDS = List.of("city", "foo"); + private static final MilvusContainer MILVUS_CONTAINER = new MilvusContainer("milvusdb/milvus:v2.4.0"); + + private static String HOST; + + @ClassRule + public static TemporaryFolder storeDir = new TemporaryFolder(); + + private static GraphDatabaseService sysDb; + private static GraphDatabaseService db; + private static DatabaseManagementService databaseManagementService; + + @BeforeClass + public static void setUp() throws Exception { + databaseManagementService = new TestDatabaseManagementServiceBuilder(storeDir.getRoot().toPath()) + .build(); + db = databaseManagementService.database(DEFAULT_DATABASE_NAME); + sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); + + MILVUS_CONTAINER.start(); + + HOST = MILVUS_CONTAINER.getEndpoint(); + TestUtil.registerProcedure(db, Milvus.class, VectorDb.class); + + testCall(db, "CALL apoc.vectordb.milvus.createCollection($host, 'test_collection', 'COSINE', 4)", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals(200L, value.get("code")); + }); + + testCall(db, "CALL apoc.vectordb.milvus.upsert($host, 'test_collection',\n" + + "[\n" + + " {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" + + " {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}}\n" + + "])", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals(200L, value.get("code")); + }); + } + + @AfterClass + public static void tearDown() throws Exception { + testCall(db, "CALL apoc.vectordb.milvus.deleteCollection($host, 'test_collection')", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals(200L, value.get("code")); + }); + + databaseManagementService.shutdown(); + MILVUS_CONTAINER.stop(); + } + + @Before + public void before() { + dropAndDeleteAll(db); + } + + @Test + public void getVectorsWithoutVectorResult() { + testResult(db, "CALL apoc.vectordb.milvus.get($host, 'test_collection', [1], $conf) ", + map("host", HOST, "conf", map(FIELDS_KEY, FIELDS)), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void deleteVector() { + testCall(db, "CALL apoc.vectordb.milvus.upsert($host, 'test_collection',\n" + + "[\n" + + " {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" + + " {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "])", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals(200L, value.get("code")); + }); + + testCall(db, "CALL apoc.vectordb.milvus.delete($host, 'test_collection', [3, 4]) ", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals(200L, value.get("code")); + }); + + Util.sleep(2000); + } + + @Test + public void queryVectors() { + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void queryVectorsWithoutVectorResult() { + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", map(FIELDS_KEY, FIELDS)), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + + row = r.next(); + assertEquals(Map.of("city", "London", "foo", "two"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void queryVectorsWithYield() { + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) YIELD metadata, id", + map("host", HOST, + "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true) + ), + r -> { + assertBerlinResult(r.next(), FALSE); + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithFilter() { + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],\n" + + "'city == \"London\"',\n" + + "5, $conf) YIELD metadata, id", + map("host", HOST, + "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true) + ), + r -> { + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithLimit() { + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 1, $conf) YIELD metadata, id", + map("host", HOST, + "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true) + ), + r -> { + assertBerlinResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithCreateNode() { + + Map conf = map(FIELDS_KEY, FIELDS, + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo", + MODE_KEY, VectorMappingConfig.MappingMode.CREATE_IF_MISSING.toString() + ) + ); + testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + + testResult(db, "MATCH (n:Test) RETURN properties(n) AS props ORDER BY n.myId", + VectorDbTestUtil::vectorEntityAssertions); + + testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + FIELDS_KEY, FIELDS, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo")); + + testResult(db, "CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void queryVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map(FIELDS_KEY, FIELDS, + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo")); + + testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void queryVectorsWithCreateRel() { + + db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); + + Map conf = map(FIELDS_KEY, FIELDS, + ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + REL_TYPE, "TEST", + ENTITY_KEY, "myId", + METADATA_KEY, "foo")); + testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertRelsCreated(db); + } + + @Test + public void queryReadOnlyVectorsWithMapping() { + Map conf = map(ALL_RESULTS_KEY, true, + MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + + try { + testCall(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> fail() + ); + } catch (RuntimeException e) { + Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); + } + } + + @Test + public void queryVectorsWithSystemDbStorage() { + String keyConfig = "milvus-config-foo"; + Map mapping = map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo"); + + sysDb.executeTransactionally("CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)", + map("vectorName", MILVUS.toString(), + "keyConfig", keyConfig, + "databaseName", DEFAULT_DATABASE_NAME, + "conf", map( + "host", HOST + "/v2/vectordb", + "mapping", mapping + ) + ) + ); + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", keyConfig, "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + +} diff --git a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java index 21912c49ff..cc2c3443c9 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java @@ -360,7 +360,8 @@ public void queryVectorsWithCreateNode() { NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo", - CREATE_KEY, true)); + MODE_KEY, + MappingMode.CREATE_IF_MISSING.toString())); testResult( db, "CALL apoc.vectordb.qdrant.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java index ec2218551f..5f7cff3858 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java @@ -2,6 +2,7 @@ import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.util.ExtendedTestUtil.assertFails; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testCallEmpty; import static apoc.util.TestUtil.testResult; @@ -31,6 +32,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; + import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; @@ -245,7 +247,7 @@ public void queryVectors() { "host", HOST, "conf", - map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + map(ALL_RESULTS_KEY, true, FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), r -> { Map row = r.next(); assertBerlinResult(row, ID_1, FALSE); @@ -265,7 +267,7 @@ public void queryVectorsWithoutVectorResult() { db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id", - map("host", HOST, "conf", map("fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + map("host", HOST, "conf", map(FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), r -> { Map row = r.next(); assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); @@ -291,7 +293,7 @@ public void queryVectorsWithYield() { "host", HOST, "conf", - map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + map(ALL_RESULTS_KEY, true, FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), r -> { assertBerlinResult(r.next(), ID_1, FALSE); assertLondonResult(r.next(), ID_2, FALSE); @@ -309,7 +311,7 @@ public void queryVectorsWithFilter() { "host", HOST, "conf", - map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + map(ALL_RESULTS_KEY, true, FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), r -> { assertLondonResult(r.next(), ID_2, FALSE); }); @@ -324,7 +326,7 @@ public void queryVectorsWithLimit() { "host", HOST, "conf", - map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + map(ALL_RESULTS_KEY, true, FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), r -> { assertBerlinResult(r.next(), ID_1, FALSE); }); @@ -336,7 +338,7 @@ public void queryVectorsWithCreateNode() { Map conf = map( ALL_RESULTS_KEY, true, - "fields", + FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION, @@ -350,8 +352,8 @@ public void queryVectorsWithCreateNode() { "myId", METADATA_KEY, "foo", - CREATE_KEY, - true)); + MODE_KEY, + MappingMode.CREATE_IF_MISSING.toString())); testResult( db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " @@ -404,7 +406,7 @@ public void queryVectorsWithCreateNodeUsingExistingNode() { Map conf = map( ALL_RESULTS_KEY, true, - "fields", + FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION, @@ -484,7 +486,7 @@ public void queryVectorsWithCreateRel() { Map conf = map( ALL_RESULTS_KEY, true, - "fields", + FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION, @@ -532,7 +534,7 @@ public void queryVectorsWithCreateRelWithoutVectorResult() { "CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); Map conf = map( - "fields", + FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION, diff --git a/full/src/main/java/apoc/vectordb/ChromaDb.java b/full/src/main/java/apoc/vectordb/ChromaDb.java index e8b6947ed2..48e2766ef0 100644 --- a/full/src/main/java/apoc/vectordb/ChromaDb.java +++ b/full/src/main/java/apoc/vectordb/ChromaDb.java @@ -115,7 +115,7 @@ public Stream delete( Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); VectorEmbeddingConfig apiConfig = - DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, getStringIds(ids)); + DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, getStringIds(ids), collection); return executeRequest(apiConfig.getApiConfig()).map(v -> (List) v).map(ListResult::new); } @@ -153,7 +153,8 @@ private Stream getCommon( checkMappingConf(configuration, "apoc.vectordb.chroma.getAndUpdate"); } - VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids); + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); return getEmbeddingResultStream(apiConfig, procedureCallContext, tx, v -> listToMap((Map) v).stream()); } diff --git a/full/src/main/java/apoc/vectordb/ChromaHandler.java b/full/src/main/java/apoc/vectordb/ChromaHandler.java index 6442f5ee99..d4159344ca 100644 --- a/full/src/main/java/apoc/vectordb/ChromaHandler.java +++ b/full/src/main/java/apoc/vectordb/ChromaHandler.java @@ -30,7 +30,7 @@ static class ChromaEmbeddingHandler implements VectorEmbeddingHandler { @Override public VectorEmbeddingConfig fromGet( - Map config, ProcedureCallContext procedureCallContext, List ids) { + Map config, ProcedureCallContext procedureCallContext, List ids, String collection) { List fields = procedureCallContext.outputFields().collect(Collectors.toList()); diff --git a/full/src/main/java/apoc/vectordb/Milvus.java b/full/src/main/java/apoc/vectordb/Milvus.java new file mode 100644 index 0000000000..1be5df7351 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/Milvus.java @@ -0,0 +1,233 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.vectordb.VectorDb.executeRequest; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorDbHandler.Type.MILVUS; +import static apoc.vectordb.VectorDbUtil.*; + +import apoc.Extended; +import apoc.ml.RestAPIConfig; +import apoc.result.MapResult; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +@Extended +public class Milvus { + public static final VectorDbHandler DB_HANDLER = MILVUS.get(); + + @Context + public ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + @Procedure("apoc.vectordb.milvus.createCollection") + @Description( + "apoc.vectordb.milvus.createCollection(hostOrKey, collection, similarity, size, $configuration) - Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`") + public Stream createCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("similarity") String similarity, + @Name("size") Long size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + String url = "%s/collections/create"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = + Map.of("collectionName", collection, "dimension", size, "metricType", similarity); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.milvus.deleteCollection") + @Description( + "apoc.vectordb.milvus.deleteCollection(hostOrKey, collection, $configuration) - Deletes a collection with the name specified in the 2nd parameter") + public Stream deleteCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/collections/drop"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + Map additionalBodies = Map.of("collectionName", collection); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.milvus.upsert") + @Description( + "apoc.vectordb.milvus.upsert(hostOrKey, collection, vectors, $configuration) - Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', medatada: ''}]") + public Stream upsert( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List> vectors, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/entities/upsert"; + + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + List> data = vectors.stream() + .map(i -> { + Map map = new HashMap<>(i); + map.putAll((Map) map.remove("metadata")); + return map; + }) + .collect(Collectors.toList()); + Map additionalBodies = Map.of("data", data, "collectionName", collection); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.milvus.delete") + @Description( + "apoc.vectordb.milvus.delete(hostOrKey, collection, ids, $configuration) - Delete the vectors with the specified `ids`") + public Stream delete( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/entities/delete"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + String filter = "id in " + ids; + Map additionalBodies = Map.of("collectionName", collection, "filter", filter); + RestAPIConfig apiConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(apiConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure(value = "apoc.vectordb.milvus.get", mode = Mode.WRITE) + @Description( + "apoc.vectordb.milvus.get(hostOrKey, collection, ids, $configuration) - Get the vectors with the specified `ids`") + public Stream get( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, true); + } + + @Procedure(value = "apoc.vectordb.milvus.getAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.milvus.getAndUpdate(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`, and optionally creates/updates neo4j entities") + public Stream getAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, false); + } + + private Stream getCommon( + String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) + throws Exception { + String url = "%s/entities/get"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.milvus.getAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx, v -> getMapStream((Map) v)); + } + + @Procedure(value = "apoc.vectordb.milvus.query") + @Description( + "apoc.vectordb.milvus.query(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieve closest vectors the the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream query( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "null") Object filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + } + + @Procedure(value = "apoc.vectordb.milvus.queryAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.milvus.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieve closest vectors the the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream queryAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "null") Object filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + } + + private Stream getMapStream(Map v) { + var data = v.get("data"); + + return ((List) data).stream().map(i -> { + var metadata = new HashMap<>(i); + metadata.remove("id"); + metadata.remove("vector"); + metadata.remove("distance"); + + i.put("metadata", metadata); + + return i; + }); + } + + private Stream queryCommon( + String hostOrKey, + String collection, + List vector, + Object filter, + long limit, + Map configuration, + boolean readOnly) + throws Exception { + String url = "%s/entities/search"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.milvus.queryAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx, v -> getMapStream((Map) v)); + } + + private Map getVectorDbInfo( + String hostOrKey, String collection, Map configuration, String templateUrl) { + return getCommonVectorDbInfo(hostOrKey, collection, configuration, templateUrl, DB_HANDLER); + } +} diff --git a/full/src/main/java/apoc/vectordb/MilvusHandler.java b/full/src/main/java/apoc/vectordb/MilvusHandler.java new file mode 100644 index 0000000000..2ea22544b1 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/MilvusHandler.java @@ -0,0 +1,83 @@ +package apoc.vectordb; + +import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.META_AS_SUBKEY_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.SCORE_KEY; + +import apoc.util.UrlResolver; +import java.util.List; +import java.util.Map; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; + +public class MilvusHandler implements VectorDbHandler { + + @Override + public String getUrl(String hostOrKey) { + String url = new UrlResolver("http", "localhost", 19530).getUrl("milvus", hostOrKey); + return url + "/v2/vectordb"; + } + + @Override + public VectorEmbeddingHandler getEmbedding() { + return new MilvusEmbeddingHandler(); + } + + @Override + public String getLabel() { + return "Milvus"; + } + + // -- embedding handler + static class MilvusEmbeddingHandler implements VectorEmbeddingHandler { + + @Override + public VectorEmbeddingConfig fromGet( + Map config, ProcedureCallContext procedureCallContext, List ids, String collection) { + List fields = procedureCallContext.outputFields().toList(); + Map additionalBodies = map("id", ids); + + return getVectorEmbeddingConfig(config, fields, collection, additionalBodies); + } + + @Override + public VectorEmbeddingConfig fromQuery( + Map config, + ProcedureCallContext procedureCallContext, + List vector, + Object filter, + long limit, + String collection) { + config.putIfAbsent(SCORE_KEY, "distance"); + + List fields = procedureCallContext.outputFields().toList(); + Map additionalBodies = map("data", List.of(vector), "limit", limit); + if (filter != null) { + additionalBodies.put("filter", filter); + } + + return getVectorEmbeddingConfig(config, fields, collection, additionalBodies); + } + + private VectorEmbeddingConfig getVectorEmbeddingConfig( + Map config, + List procFields, + String collection, + Map additionalBodies) { + config.putIfAbsent(META_AS_SUBKEY_KEY, false); + + List listFields = (List) config.get(FIELDS_KEY); + if (listFields == null) { + throw new RuntimeException("You have to define `field` list of parameter to be returned"); + } + if (procFields.contains("vector") && !listFields.contains("vector")) { + listFields.add("vector"); + } + VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config); + additionalBodies.put("collectionName", collection); + additionalBodies.put("outputFields", listFields); + + return VectorEmbeddingHandler.populateApiBodyRequest(conf, additionalBodies); + } + } +} diff --git a/full/src/main/java/apoc/vectordb/Pinecone.java b/full/src/main/java/apoc/vectordb/Pinecone.java new file mode 100644 index 0000000000..eeb49fb01c --- /dev/null +++ b/full/src/main/java/apoc/vectordb/Pinecone.java @@ -0,0 +1,224 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.vectordb.VectorDb.executeRequest; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorDbHandler.Type.PINECONE; +import static apoc.vectordb.VectorDbUtil.checkMappingConf; +import static apoc.vectordb.VectorDbUtil.getCommonVectorDbInfo; + +import apoc.Extended; +import apoc.ml.RestAPIConfig; +import apoc.result.MapResult; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +@Extended +public class Pinecone { + public static final VectorDbHandler DB_HANDLER = PINECONE.get(); + + @Context + public ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + @Procedure("apoc.vectordb.pinecone.createCollection") + @Description( + "apoc.vectordb.pinecone.createCollection(hostOrKey, collection, similarity, size, $configuration) - Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`") + public Stream createCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("similarity") String similarity, + @Name("size") Long size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + String url = "%s/indexes"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = Map.of( + "name", collection, + "dimension", size, + "metric", similarity); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.pinecone.deleteCollection") + @Description( + "apoc.vectordb.pinecone.deleteCollection(hostOrKey, collection, $configuration) - Deletes a collection with the name specified in the 2nd parameter") + public Stream deleteCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/indexes/%s"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "DELETE"); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.pinecone.upsert") + @Description( + "apoc.vectordb.pinecone.upsert(hostOrKey, collection, vectors, $configuration) - Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', medatada: ''}]") + public Stream upsert( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List> vectors, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/vectors/upsert"; + + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + vectors = vectors.stream() + .map(i -> { + Map map = new HashMap<>(i); + map.putIfAbsent("values", map.remove("vector")); + return map; + }) + .collect(Collectors.toList()); + + Map additionalBodies = Map.of("vectors", vectors); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.pinecone.delete") + @Description( + "apoc.vectordb.pinecone.delete(hostOrKey, collection, ids, $configuration) - Delete the vectors with the specified `ids`") + public Stream delete( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/vectors/delete"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = Map.of("ids", ids); + RestAPIConfig apiConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(apiConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure(value = "apoc.vectordb.pinecone.get") + @Description( + "apoc.vectordb.pinecone.get(hostOrKey, collection, ids, $configuration) - Get the vectors with the specified `ids`") + public Stream get( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, true); + } + + @Procedure(value = "apoc.vectordb.pinecone.getAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.pinecone.getAndUpdate(hostOrKey, collection, ids, $configuration) - Get the vectors with the specified `ids`") + public Stream getAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, false); + } + + private Stream getCommon( + String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) + throws Exception { + String url = "%s/vectors/fetch"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.pinecone.getAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx, v -> { + Object vectors = ((Map) v).get("vectors"); + return ((Map) vectors).values().stream(); + }); + } + + @Procedure(value = "apoc.vectordb.pinecone.query") + @Description( + "apoc.vectordb.pinecone.query(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieve closest vectors the the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream query( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + } + + @Procedure(value = "apoc.vectordb.pinecone.queryAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.pinecone.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieve closest vectors the the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream queryAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + } + + private Stream queryCommon( + String hostOrKey, + String collection, + List vector, + Map filter, + long limit, + Map configuration, + boolean readOnly) + throws Exception { + String url = "%s/query"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.pinecone.queryAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx, v -> { + Map map = (Map) v; + return ((List) map.get("matches")).stream(); + }); + } + + private Map getVectorDbInfo( + String hostOrKey, String collection, Map configuration, String templateUrl) { + return getCommonVectorDbInfo(hostOrKey, collection, configuration, templateUrl, DB_HANDLER); + } +} diff --git a/full/src/main/java/apoc/vectordb/PineconeHandler.java b/full/src/main/java/apoc/vectordb/PineconeHandler.java new file mode 100644 index 0000000000..d4089610cf --- /dev/null +++ b/full/src/main/java/apoc/vectordb/PineconeHandler.java @@ -0,0 +1,106 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.BODY_KEY; +import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; + +import apoc.ml.RestAPIConfig; +import java.net.URL; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; + +public class PineconeHandler implements VectorDbHandler { + + @Override + public String getUrl(String hostOrKey) { + return StringUtils.isBlank(hostOrKey) ? "https://api.pinecone.io" : hostOrKey; + } + + @Override + public VectorEmbeddingHandler getEmbedding() { + return new PineconeEmbeddingHandler(); + } + + @Override + public String getLabel() { + return "Pinecone"; + } + + @Override + public Map getCredentials(Object credentialsObj, Map config) { + Map headers = (Map) config.getOrDefault(HEADERS_KEY, new HashMap<>()); + headers.putIfAbsent("Api-Key", credentialsObj); + config.put(HEADERS_KEY, headers); + return config; + } + + // -- embedding handler + static class PineconeEmbeddingHandler implements VectorEmbeddingHandler { + + /** + * "method" should be "GET", but is null as a workaround. + * Since with `method: POST` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode` + * that makes the request to respond 200 OK, but returns an empty result + */ + @Override + public VectorEmbeddingConfig fromGet( + Map config, ProcedureCallContext procedureCallContext, List ids, String collection) { + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + + config.put(BODY_KEY, null); + + String endpoint = (String) config.get(ENDPOINT_KEY); + if (!endpoint.contains("ids=")) { + String idsQueryUrl = ids.stream().map(i -> "ids=" + i).collect(Collectors.joining("&")); + + if (endpoint.contains("?")) { + endpoint += "&" + idsQueryUrl; + } else { + endpoint += "?" + idsQueryUrl; + } + } + + config.put(ENDPOINT_KEY, endpoint); + return getVectorEmbeddingConfig(config, fields, map()); + } + + @Override + public VectorEmbeddingConfig fromQuery( + Map config, + ProcedureCallContext procedureCallContext, + List vector, + Object filter, + long limit, + String collection) { + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + + Map additionalBodies = map("vector", vector, "filter", filter, "topK", limit); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + private VectorEmbeddingConfig getVectorEmbeddingConfig( + Map config, List fields, Map additionalBodies) { + config.putIfAbsent(VECTOR_KEY, "values"); + + VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config); + + additionalBodies.put("includeMetadata", fields.contains("metadata")); + additionalBodies.put("includeValues", fields.contains("vector") && conf.isAllResults()); + + RestAPIConfig apiConfig = conf.getApiConfig(); + Map headers = apiConfig.getHeaders(); + headers.remove(METHOD_KEY); + apiConfig.setHeaders(headers); + + return VectorEmbeddingHandler.populateApiBodyRequest(conf, additionalBodies); + } + } +} diff --git a/full/src/main/java/apoc/vectordb/Qdrant.java b/full/src/main/java/apoc/vectordb/Qdrant.java index b697e299a4..f927adfc7b 100644 --- a/full/src/main/java/apoc/vectordb/Qdrant.java +++ b/full/src/main/java/apoc/vectordb/Qdrant.java @@ -157,7 +157,8 @@ private Stream getCommon( checkMappingConf(configuration, "apoc.vectordb.qdrant.getAndUpdate"); } - VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids); + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); return getEmbeddingResultStream(apiConfig, procedureCallContext, tx); } diff --git a/full/src/main/java/apoc/vectordb/QdrantHandler.java b/full/src/main/java/apoc/vectordb/QdrantHandler.java index 1a9f9954d1..6ba0ccc217 100644 --- a/full/src/main/java/apoc/vectordb/QdrantHandler.java +++ b/full/src/main/java/apoc/vectordb/QdrantHandler.java @@ -34,7 +34,7 @@ static class QdrantEmbeddingHandler implements VectorEmbeddingHandler { @Override public VectorEmbeddingConfig fromGet( - Map config, ProcedureCallContext procedureCallContext, List ids) { + Map config, ProcedureCallContext procedureCallContext, List ids, String collection) { List fields = procedureCallContext.outputFields().collect(Collectors.toList()); config.putIfAbsent(METHOD_KEY, "POST"); diff --git a/full/src/main/java/apoc/vectordb/VectorDb.java b/full/src/main/java/apoc/vectordb/VectorDb.java index 59c9933bd5..f9ce2f8850 100644 --- a/full/src/main/java/apoc/vectordb/VectorDb.java +++ b/full/src/main/java/apoc/vectordb/VectorDb.java @@ -166,15 +166,20 @@ private static Entity handleMappingNode( Node node; Object propValue = metaProps.get(mapping.getMetadataKey()); node = transaction.findNode(Label.label(mapping.getNodeLabel()), mapping.getEntityKey(), propValue); - if (node == null && mapping.isCreate()) { - node = transaction.createNode(Label.label(mapping.getNodeLabel())); - node.setProperty(mapping.getEntityKey(), propValue); + switch (mapping.getMode()) { + case READ_ONLY: + // do nothing, just return the entity + break; + case UPDATE_EXISTING: + setPropsIfEntityExists(mapping, metaProps, embedding, node); + break; + default: + if (node == null) { + node = transaction.createNode(Label.label(mapping.getNodeLabel())); + node.setProperty(mapping.getEntityKey(), propValue); + } + setPropsIfEntityExists(mapping, metaProps, embedding, node); } - if (node != null) { - setProperties(node, metaProps); - setVectorProp(mapping, embedding, node); - } - return node; } catch (MultipleFoundException e) { throw new RuntimeException("Multiple nodes found"); @@ -192,9 +197,12 @@ private static Entity handleMappingRel( Object propValue = metaProps.get(mapping.getMetadataKey()); rel = transaction.findRelationship( RelationshipType.withName(mapping.getRelType()), mapping.getEntityKey(), propValue); - if (rel != null) { - setProperties(rel, metaProps); - setVectorProp(mapping, embedding, rel); + switch (mapping.getMode()) { + case READ_ONLY: + // do nothing, just return the entity + break; + default: + setPropsIfEntityExists(mapping, metaProps, embedding, rel); } return rel; @@ -203,6 +211,14 @@ private static Entity handleMappingRel( } } + private static void setPropsIfEntityExists( + VectorMappingConfig mapping, Map metaProps, List embedding, Entity entity) { + if (entity != null) { + setProperties(entity, metaProps); + setVectorProp(mapping, embedding, entity); + } + } + private static void setVectorProp( VectorMappingConfig mapping, List embedding, T entity) { if (mapping.getEmbeddingKey() == null) { diff --git a/full/src/main/java/apoc/vectordb/VectorDbHandler.java b/full/src/main/java/apoc/vectordb/VectorDbHandler.java index 894b805646..aed4658ad5 100644 --- a/full/src/main/java/apoc/vectordb/VectorDbHandler.java +++ b/full/src/main/java/apoc/vectordb/VectorDbHandler.java @@ -22,6 +22,8 @@ default Map getCredentials(Object credentialsObj, Map config) { this.textKey = (String) config.getOrDefault(TEXT_KEY, DEFAULT_TEXT); this.allResults = Util.toBoolean(config.get(ALL_RESULTS_KEY)); this.mapping = new VectorMappingConfig((Map) config.getOrDefault(MAPPING_KEY, Map.of())); + this.metaAsSubKey = Util.toBoolean(config.getOrDefault(META_AS_SUBKEY_KEY, true)); this.apiConfig = new RestAPIConfig(config); } @@ -66,6 +70,10 @@ public boolean isAllResults() { return allResults; } + public boolean isMetaAsSubKey() { + return metaAsSubKey; + } + public VectorMappingConfig getMapping() { return mapping; } diff --git a/full/src/main/java/apoc/vectordb/VectorEmbeddingHandler.java b/full/src/main/java/apoc/vectordb/VectorEmbeddingHandler.java index e58eb45453..f029b10a6e 100644 --- a/full/src/main/java/apoc/vectordb/VectorEmbeddingHandler.java +++ b/full/src/main/java/apoc/vectordb/VectorEmbeddingHandler.java @@ -1,7 +1,5 @@ package apoc.vectordb; -import static apoc.vectordb.VectorEmbeddingConfig.*; - import apoc.ml.RestAPIConfig; import java.util.List; import java.util.Map; @@ -10,7 +8,7 @@ public interface VectorEmbeddingHandler { VectorEmbeddingConfig fromGet( - Map config, ProcedureCallContext procedureCallContext, List ids); + Map config, ProcedureCallContext procedureCallContext, List ids, String collection); VectorEmbeddingConfig fromQuery( Map config, diff --git a/full/src/main/java/apoc/vectordb/VectorMappingConfig.java b/full/src/main/java/apoc/vectordb/VectorMappingConfig.java index d925517a80..bfec3b94d8 100644 --- a/full/src/main/java/apoc/vectordb/VectorMappingConfig.java +++ b/full/src/main/java/apoc/vectordb/VectorMappingConfig.java @@ -1,17 +1,24 @@ package apoc.vectordb; -import apoc.util.Util; import java.util.Collections; import java.util.Map; public class VectorMappingConfig { + public enum MappingMode { + READ_ONLY, + UPDATE_EXISTING, + CREATE_IF_MISSING + } + public static final String METADATA_KEY = "metadataKey"; public static final String ENTITY_KEY = "entityKey"; public static final String NODE_LABEL = "nodeLabel"; public static final String REL_TYPE = "relType"; public static final String EMBEDDING_KEY = "embeddingKey"; public static final String SIMILARITY_KEY = "similarity"; - public static final String CREATE_KEY = "create"; + public static final String MODE_KEY = "mode"; + public static final String NO_FIELDS_ERROR_MSG = + "You need to define either the 'field' list parameter, or the 'metadataKey' string parameter within the `embeddingConfig` parameter"; private final String metadataKey; private final String entityKey; @@ -21,7 +28,7 @@ public class VectorMappingConfig { private final String embeddingKey; private final String similarity; - private final boolean create; + private MappingMode mode; public VectorMappingConfig(Map mapping) { if (mapping == null) { @@ -36,7 +43,8 @@ public VectorMappingConfig(Map mapping) { this.similarity = (String) mapping.getOrDefault(SIMILARITY_KEY, "cosine"); - this.create = Util.toBoolean(mapping.get(CREATE_KEY)); + String modeValue = (String) mapping.getOrDefault(MODE_KEY, MappingMode.UPDATE_EXISTING.toString()); + this.mode = MappingMode.valueOf(modeValue.toUpperCase()); } public String getMetadataKey() { @@ -59,11 +67,11 @@ public String getEmbeddingKey() { return embeddingKey; } - public boolean isCreate() { - return create; - } - public String getSimilarity() { return similarity; } + + public MappingMode getMode() { + return mode; + } } diff --git a/full/src/main/java/apoc/vectordb/Weaviate.java b/full/src/main/java/apoc/vectordb/Weaviate.java index 21c87db209..b47bb45d02 100644 --- a/full/src/main/java/apoc/vectordb/Weaviate.java +++ b/full/src/main/java/apoc/vectordb/Weaviate.java @@ -182,7 +182,7 @@ private Stream getCommon( config.putIfAbsent(METHOD_KEY, null); List fields = procedureCallContext.outputFields().collect(Collectors.toList()); - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids); + VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); boolean hasEmbedding = fields.contains("vector") && conf.isAllResults(); boolean hasMetadata = fields.contains("metadata"); VectorMappingConfig mapping = conf.getMapping(); diff --git a/full/src/main/java/apoc/vectordb/WeaviateHandler.java b/full/src/main/java/apoc/vectordb/WeaviateHandler.java index b58ffd79af..0cf7d856c0 100644 --- a/full/src/main/java/apoc/vectordb/WeaviateHandler.java +++ b/full/src/main/java/apoc/vectordb/WeaviateHandler.java @@ -3,6 +3,7 @@ import static apoc.ml.RestAPIConfig.BODY_KEY; import static apoc.ml.RestAPIConfig.METHOD_KEY; import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY; import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; @@ -35,7 +36,7 @@ static class WeaviateEmbeddingHandler implements VectorEmbeddingHandler { @Override public VectorEmbeddingConfig fromGet( - Map config, ProcedureCallContext procedureCallContext, List ids) { + Map config, ProcedureCallContext procedureCallContext, List ids, String collection) { config.putIfAbsent(BODY_KEY, null); return VectorEmbeddingHandler.populateApiBodyRequest(getVectorEmbeddingConfig(config), Map.of()); } @@ -52,7 +53,7 @@ public VectorEmbeddingConfig fromQuery( config.putIfAbsent(METHOD_KEY, "POST"); VectorEmbeddingConfig vectorEmbeddingConfig = getVectorEmbeddingConfig(config); - List list = (List) config.get("fields"); + List list = (List) config.get(FIELDS_KEY); if (list == null) { throw new RuntimeException("You have to define `field` list of parameter to be returned"); } diff --git a/full/src/main/resources/extended.txt b/full/src/main/resources/extended.txt index 0fbd4c83d5..38e3096141 100644 --- a/full/src/main/resources/extended.txt +++ b/full/src/main/resources/extended.txt @@ -238,5 +238,21 @@ apoc.vectordb.weaviate.queryAndUpdate apoc.vectordb.weaviate.info apoc.vectordb.pinecone.info apoc.vectordb.milvus.info -apoc.vectordb.custom +apoc.vectordb.pinecone.createCollection +apoc.vectordb.pinecone.deleteCollection +apoc.vectordb.pinecone.upsert +apoc.vectordb.pinecone.delete +apoc.vectordb.pinecone.get +apoc.vectordb.pinecone.getAndUpdate +apoc.vectordb.pinecone.query +apoc.vectordb.pinecone.queryAndUpdate +apoc.vectordb.milvus.createCollection +apoc.vectordb.milvus.deleteCollection +apoc.vectordb.milvus.upsert +apoc.vectordb.milvus.delete +apoc.vectordb.milvus.get +apoc.vectordb.milvus.getAndUpdate +apoc.vectordb.milvus.query +apoc.vectordb.milvus.queryAndUpdate +apoc.vectordb.custom.get apoc.vectordb.configure \ No newline at end of file diff --git a/full/src/test/java/apoc/vectordb/PineconeCustomTest.java b/full/src/test/java/apoc/vectordb/PineconeCustomTest.java new file mode 100644 index 0000000000..a4185ce703 --- /dev/null +++ b/full/src/test/java/apoc/vectordb/PineconeCustomTest.java @@ -0,0 +1,98 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.BODY_KEY; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.ml.RestAPIConfig.JSON_PATH_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.util.Util.map; +import static apoc.util.UtilsExtendedTest.checkEnvVar; +import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import apoc.util.TestUtil; +import java.net.URL; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +/** + * It leverages `apoc.vectordb.custom*` procedures + * * + * * + * Example of Pinecone RestAPI: + * PINECONE_HOST: `https://INDEX-ID.svc.gcp-starter.pinecone.io` + * PINECONE_KEY: `API Key` + * PINECONE_NAMESPACE: `the one to be specified in body: {.. "ns": NAMESPACE}` + * PINECONE_DIMENSION: vector dimension + */ +public class PineconeCustomTest { + private static String apiKey; + private static String host; + private static String size; + private static String namespace; + + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + @BeforeClass + public static void setUp() throws Exception { + apiKey = checkEnvVar("PINECONE_KEY"); + host = checkEnvVar("PINECONE_HOST"); + size = checkEnvVar("PINECONE_DIMENSION"); + namespace = checkEnvVar("PINECONE_NAMESPACE"); + + TestUtil.registerProcedure(db, VectorDb.class); + } + + @Test + public void callQueryEndpointViaCustomGetProc() { + + Map conf = getConf(); + conf.put(VECTOR_KEY, "values"); + + testResult(db, "CALL apoc.vectordb.custom.get($host, $conf)", map("host", host + "/query", "conf", conf), r -> { + r.forEachRemaining(i -> { + assertNotNull(i.get("score")); + assertNotNull(i.get("metadata")); + assertNotNull(i.get("id")); + assertNotNull(i.get("vector")); + }); + }); + } + + @Test + public void callQueryEndpointViaCustomProc() { + testCall(db, "CALL apoc.vectordb.custom($host, $conf)", map("host", host + "/query", "conf", getConf()), r -> { + List value = (List) r.get("value"); + value.forEach(i -> { + assertTrue(i.containsKey("score")); + assertTrue(i.containsKey("metadata")); + assertTrue(i.containsKey("id")); + }); + }); + } + + /** + * TODO: "method" is null as a workaround. + * Since with `method: POST` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode` + * that makes the request to respond 200 OK, but returns an empty result + */ + private static Map getConf() { + List vector = Collections.nCopies(Integer.parseInt(size), 0.1); + + Map body = map( + "namespace", namespace, "vector", vector, "topK", 3, "includeValues", true, "includeMetadata", true); + + Map header = map("Api-Key", apiKey); + + return map(BODY_KEY, body, HEADERS_KEY, header, METHOD_KEY, null, JSON_PATH_KEY, "matches"); + } +} diff --git a/full/src/test/java/apoc/vectordb/PineconeTest.java b/full/src/test/java/apoc/vectordb/PineconeTest.java index b12fbf5403..313be0c19e 100644 --- a/full/src/test/java/apoc/vectordb/PineconeTest.java +++ b/full/src/test/java/apoc/vectordb/PineconeTest.java @@ -1,98 +1,618 @@ package apoc.vectordb; -import static apoc.ml.RestAPIConfig.BODY_KEY; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; -import static apoc.ml.RestAPIConfig.JSON_PATH_KEY; -import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.ExtendedTestUtil.assertFails; +import static apoc.util.ExtendedTestUtil.testRetryCallEventually; +import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testCallEmpty; import static apoc.util.TestUtil.testResult; -import static apoc.util.Util.map; import static apoc.util.UtilsExtendedTest.checkEnvVar; -import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; +import static apoc.vectordb.VectorDbHandler.Type.PINECONE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; +import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; +import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; +import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static apoc.vectordb.VectorMappingConfig.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertNull; +import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; +import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; +import apoc.ml.Prompt; +import apoc.util.MapUtil; import apoc.util.TestUtil; -import java.net.URL; -import java.util.Collections; +import apoc.util.Util; import java.util.List; import java.util.Map; +import java.util.UUID; +import org.junit.AfterClass; +import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; -import org.neo4j.test.rule.DbmsRule; -import org.neo4j.test.rule.ImpermanentDbmsRule; - -/** - * It leverages `apoc.vectordb.custom*` procedures - * * - * * - * Example of Pinecone RestAPI: - * PINECONE_HOST: `https://INDEX-ID.svc.gcp-starter.pinecone.io` - * PINECONE_KEY: `API Key` - * PINECONE_NAMESPACE: `the one to be specified in body: {.. "ns": NAMESPACE}` - * PINECONE_DIMENSION: vector dimension - */ +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; + public class PineconeTest { - private static String apiKey; - private static String host; - private static String size; - private static String namespace; + private static String API_KEY; + private static String HOST; + + private static final String collName = UUID.randomUUID().toString(); @ClassRule - public static DbmsRule db = new ImpermanentDbmsRule(); + public static TemporaryFolder storeDir = new TemporaryFolder(); + + private static GraphDatabaseService sysDb; + private static GraphDatabaseService db; + private static DatabaseManagementService databaseManagementService; + + private static Map ADMIN_AUTHORIZATION; + private static Map ADMIN_HEADER_CONF; @BeforeClass - public static void setUp() throws Exception { - apiKey = checkEnvVar("PINECONE_KEY"); - host = checkEnvVar("PINECONE_HOST"); - size = checkEnvVar("PINECONE_DIMENSION"); - namespace = checkEnvVar("PINECONE_NAMESPACE"); + public static void setUp() { + API_KEY = checkEnvVar("PINECONE_KEY"); + HOST = checkEnvVar("PINECONE_HOST"); + + databaseManagementService = + new TestDatabaseManagementServiceBuilder(storeDir.getRoot().toPath()).build(); + db = databaseManagementService.database(DEFAULT_DATABASE_NAME); + sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); + + TestUtil.registerProcedure(db, VectorDb.class, Pinecone.class, Prompt.class); + + ADMIN_AUTHORIZATION = map("Api-Key", API_KEY); + ADMIN_HEADER_CONF = map(HEADERS_KEY, ADMIN_AUTHORIZATION); + + testRetryCallEventually( + db, + "CALL apoc.vectordb.pinecone.createCollection($host, $coll, 'cosine', 4, $conf)", + map( + "host", + HOST, + "coll", + collName, + "conf", + map( + HEADERS_KEY, + ADMIN_AUTHORIZATION, + "body", + map("spec", map("serverless", map("cloud", "aws", "region", "us-east-1"))))), + r -> { + Map value = (Map) r.get("value"); + assertEquals(map("ready", false, "state", "Initializing"), value.get("status")); + HOST = "https://" + value.get("host"); + }, + 5L); + + // the upsert takes a while + Util.sleep(5000); + + testResult( + db, + "CALL apoc.vectordb.pinecone.upsert($host, $coll,\n" + "[\n" + + " {id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" + + " {id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}}\n" + + "],\n" + + "$conf)", + map("host", HOST, "coll", collName, "conf", ADMIN_HEADER_CONF), + r -> { + Map row = r.next(); + Map value = (Map) row.get("value"); + assertEquals(2L, value.get("upsertedCount")); + }); + + // the upsert takes a while + Util.sleep(20000); + } + + @AfterClass + public static void tearDown() { + if (API_KEY == null || HOST == null) { + return; + } + + Util.sleep(2000); + + testCallEmpty( + db, + "CALL apoc.vectordb.pinecone.deleteCollection($host, $coll, $conf)", + map("host", "", "coll", collName, "conf", ADMIN_HEADER_CONF)); + + databaseManagementService.shutdown(); + } + + @Before + public void before() { + dropAndDeleteAll(db); + } + + @Test + public void getInfo() { + testResult( + db, + "CALL apoc.vectordb.pinecone.info($host, $coll, $conf) ", + map( + "host", + null, + "coll", + collName, + "conf", + map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + Map value = (Map) row.get("value"); + assertEquals(collName, value.get("name")); + }); + } + + @Test + public void getInfoNotExistentCollection() { + String wrongCollection = "wrong_collection"; + assertFails( + db, + "CALL apoc.vectordb.pinecone.info($host, $coll, $conf)", + map( + "host", + null, + "coll", + wrongCollection, + "conf", + map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + "java.io.FileNotFoundException: https://api.pinecone.io/indexes/" + wrongCollection); + } + + @Test + public void getVectors() { + testResult( + db, + "CALL apoc.vectordb.pinecone.get($host, $coll, ['1', '2'], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map( + "host", + HOST, + "coll", + collName, + "conf", + map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, FALSE); + assertNotNull(row.get("vector")); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.pinecone.get($host, $coll, ['1'], $conf) ", + map("host", HOST, "coll", collName, "conf", ADMIN_HEADER_CONF), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void deleteVector() { + testCall( + db, + "CALL apoc.vectordb.pinecone.upsert($host, $coll,\n" + "[\n" + + " {id: '3', vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" + + " {id: '4', vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "],\n" + + "$conf)", + map("host", HOST, "coll", collName, "conf", ADMIN_HEADER_CONF), + r -> { + Map value = (Map) r.get("value"); + assertEquals(2L, value.get("upsertedCount")); + }); + + // the upsert takes a while + Util.sleep(10000); + + testCall( + db, + "CALL apoc.vectordb.pinecone.delete($host, $coll, ['3', '4'], $conf) ", + map("host", HOST, "coll", collName, "conf", ADMIN_HEADER_CONF), + r -> { + assertEquals(Map.of(), r.get("value")); + }); + } + + @Test + public void queryVectors() { + testResult( + db, + "CALL apoc.vectordb.pinecone.query($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map( + "host", + HOST, + "coll", + collName, + "conf", + map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void queryVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.pinecone.queryAndUpdate($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", map(HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + + row = r.next(); + assertEquals(Map.of("city", "London", "foo", "two"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void queryVectorsWithYield() { + testResult( + db, + "CALL apoc.vectordb.pinecone.query($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf) YIELD metadata, id", + map( + "host", + HOST, + "coll", + collName, + "conf", + map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertBerlinResult(r.next(), FALSE); + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithFilter() { + testResult( + db, + "CALL apoc.vectordb.pinecone.query($host, $coll, [0.2, 0.1, 0.9, 0.7],\n" + + "{ city: { `$eq`: \"London\" } },\n" + + "5, $conf) YIELD metadata, id", + map( + "host", + HOST, + "coll", + collName, + "conf", + map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithLimit() { + testResult( + db, + "CALL apoc.vectordb.pinecone.query($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id", + map( + "host", + HOST, + "coll", + collName, + "conf", + map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertBerlinResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithCreateNode() { + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map( + EMBEDDING_KEY, + "vect", + NODE_LABEL, + "Test", + ENTITY_KEY, + "myId", + METADATA_KEY, + "foo", + MODE_KEY, + MappingMode.CREATE_IF_MISSING.toString())); + testResult( + db, + "CALL apoc.vectordb.pinecone.queryAndUpdate($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + + testResult( + db, + "MATCH (n:Test) RETURN properties(n) AS props ORDER BY n.myId", + VectorDbTestUtil::vectorEntityAssertions); + + testResult( + db, + "CALL apoc.vectordb.pinecone.queryAndUpdate($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void queryVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.pinecone.queryAndUpdate($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = MapUtil.map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + MapUtil.map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.pinecone.getAndUpdate($host, 'TestCollection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + Util.map("host", HOST, "coll", collName, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map( + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo")); - TestUtil.registerProcedure(db, VectorDb.class); + testResult( + db, + "CALL apoc.vectordb.pinecone.get($host, 'TestCollection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + Util.map("host", HOST, "coll", collName, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node")); } @Test - public void callQueryEndpointViaCustomGetProc() { + public void queryVectorsWithCreateRel() { - Map conf = getConf(); - conf.put(VECTOR_KEY, "values"); + db.executeTransactionally( + "CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); - testResult(db, "CALL apoc.vectordb.custom.get($host, $conf)", map("host", host + "/query", "conf", conf), r -> { - r.forEachRemaining(i -> { - assertNotNull(i.get("score")); - assertNotNull(i.get("metadata")); - assertNotNull(i.get("id")); - assertNotNull(i.get("vector")); - }); - }); + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", REL_TYPE, "TEST", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + testResult( + db, + "CALL apoc.vectordb.pinecone.queryAndUpdate($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertRelsCreated(db); } @Test - public void callQueryEndpointViaCustomProc() { - testCall(db, "CALL apoc.vectordb.custom($host, $conf)", map("host", host + "/query", "conf", getConf()), r -> { - List value = (List) r.get("value"); - value.forEach(i -> { - assertTrue(i.containsKey("score")); - assertTrue(i.containsKey("metadata")); - assertTrue(i.containsKey("id")); - }); - }); + public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally( + "CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.pinecone.query($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel")); } - /** - * TODO: "method" is null as a workaround. - * Since with `method: POST` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode` - * that makes the request to respond 200 OK, but returns an empty result - */ - private static Map getConf() { - List vector = Collections.nCopies(Integer.parseInt(size), 0.1); + @Test + public void queryVectorsWithSystemDbStorage() { + String keyConfig = "pinecone-config-foo"; + Map mapping = + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo"); + + sysDb.executeTransactionally( + "CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)", + map( + "vectorName", + PINECONE.toString(), + "keyConfig", + keyConfig, + "databaseName", + DEFAULT_DATABASE_NAME, + "conf", + map( + "host", HOST, + "credentials", API_KEY, + "mapping", mapping))); - Map body = map( - "namespace", namespace, "vector", vector, "topK", 3, "includeValues", true, "includeMetadata", true); + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + testResult( + db, + "CALL apoc.vectordb.pinecone.queryAndUpdate($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", keyConfig, "coll", collName, "conf", map(ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); - Map header = map("Api-Key", apiKey); + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo")); - return map(BODY_KEY, body, HEADERS_KEY, header, METHOD_KEY, null, JSON_PATH_KEY, "matches"); + testResult( + db, + "CALL apoc.vectordb.pinecone.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value", + map( + "host", HOST, + "conf", conf, + "collection", collName, + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo")), + VectorDbTestUtil::assertRagWithVectors); } } diff --git a/test-utils/build.gradle b/test-utils/build.gradle index dde9f628ec..b05d3c70bd 100644 --- a/test-utils/build.gradle +++ b/test-utils/build.gradle @@ -45,6 +45,7 @@ dependencies { api group: 'org.testcontainers', name: 'qdrant', version: '1.19.7' api group: 'org.testcontainers', name: 'chromadb', version: '1.19.7' api group: 'org.testcontainers', name: 'weaviate', version: '1.19.7' + api group: 'org.testcontainers', name: 'milvus', version: '1.19.7' api group: 'org.apache.arrow', name: 'arrow-vector', version: '16.1.0' api group: 'org.apache.arrow', name: 'arrow-memory-netty', version: '16.1.0' }