Skip to content

Commit

Permalink
[NOID] Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb…
Browse files Browse the repository at this point in the history
….weaviate.query* procedures should get the fields config from metadataKey if present (#4241)

* Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb.weaviate.query* procedures should get the fields config from metadataKey if present

* test fixes and changes review

* fix tests
  • Loading branch information
vga91 committed Dec 18, 2024
1 parent d394eab commit 0b223e2
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 10 deletions.
98 changes: 98 additions & 0 deletions full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package apoc.full.it.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.TestUtil;
import apoc.util.Util;
import apoc.vectordb.Milvus;
Expand All @@ -21,6 +23,7 @@
import java.util.List;
import java.util.Map;

import static apoc.ml.Prompt.API_KEY_CONF;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand All @@ -41,8 +44,11 @@
import static apoc.vectordb.VectorMappingConfig.ENTITY_KEY;
import static apoc.vectordb.VectorMappingConfig.METADATA_KEY;
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
import static apoc.vectordb.VectorMappingConfig.MappingMode;
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;
import static apoc.vectordb.VectorMappingConfig.NODE_LABEL;
import static apoc.vectordb.VectorMappingConfig.REL_TYPE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
Expand Down Expand Up @@ -406,4 +412,96 @@ public void queryVectorsWithSystemDbStorage() {
assertNodesCreated(db);
}

/**
 * Reads vectors 1 and 2 from {@code test_collection} through
 * {@code apoc.vectordb.milvus.getAndUpdate}, collects the yielded nodes and passes them,
 * together with the {@code city} and {@code foo} attributes, to {@code apoc.ml.rag}.
 * The outcome is checked by {@code VectorDbTestUtil.assertRagWithVectors}.
 */
@Test
public void queryVectorsWithRag() {
    // ragSetup returns the key that is forwarded to apoc.ml.rag below.
    final String apiKey = ragSetup(db);

    final Map<String, Object> mapping = map(
            NODE_LABEL, "Rag",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");
    final Map<String, Object> conf = map(
            FIELDS_KEY, FIELDS,
            ALL_RESULTS_KEY, true,
            HEADERS_KEY, READONLY_AUTHORIZATION,
            MAPPING_KEY, mapping);
    final Map<String, Object> params = map(
            "host", HOST,
            "conf", conf,
            "confPrompt", map(API_KEY_CONF, apiKey),
            "attributes", List.of("city", "foo"));

    final String query =
            "CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n"
                    + "WITH collect(node) as paths\n"
                    + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n"
                    + "RETURN value";

    testResult(db, query, params, VectorDbTestUtil::assertRagWithVectors);
}

/**
 * Runs {@code apoc.vectordb.milvus.query} with a mapping that sets a metadata key
 * ({@code foo}) but with no fields entry in the configuration, and checks the rows
 * with {@code VectorDbTestUtil.assertMetadataFooResult}.
 */
@Test
public void queryVectorsWithMetadataKeyNoFields() {
    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");
    final Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            MAPPING_KEY, mapping);
    final Map<String, Object> params = map("host", HOST, "conf", conf);

    final String query =
            "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";

    testResult(db, query, params, VectorDbTestUtil::assertMetadataFooResult);
}

/**
 * Runs {@code apoc.vectordb.milvus.query} with neither a fields entry nor a metadata key
 * in the configuration and passes {@code NO_FIELDS_ERROR_MSG} to
 * {@code ExtendedTestUtil.assertFails} as the expected failure.
 */
@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID");
    final Map<String, Object> params = map(
            "host", HOST,
            "conf", Map.of(
                    ALL_RESULTS_KEY, true,
                    MAPPING_KEY, mapping));

    final String cypher =
            "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";

    ExtendedTestUtil.assertFails(db, cypher, params, NO_FIELDS_ERROR_MSG);
}

/**
 * Runs {@code apoc.vectordb.milvus.queryAndUpdate} with a mapping that sets a metadata key
 * ({@code foo}) but with no fields entry in the configuration, and checks the rows
 * with {@code VectorDbTestUtil.assertMetadataFooResult}.
 */
@Test
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");
    final Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            MAPPING_KEY, mapping);
    final Map<String, Object> params = map("host", HOST, "conf", conf);

    final String cypher =
            "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";

    testResult(db, cypher, params, VectorDbTestUtil::assertMetadataFooResult);
}

/**
 * Runs {@code apoc.vectordb.milvus.queryAndUpdate} with neither a fields entry nor a
 * metadata key in the configuration, after creating two {@code :Test} nodes, and passes
 * {@code NO_FIELDS_ERROR_MSG} to {@code ExtendedTestUtil.assertFails} as the expected failure.
 */
@Test
public void queryAndUpdateWithNoMetadataKeyNoFields() {
    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            REL_TYPE, "TEST",
            ENTITY_KEY, "readID");
    final Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            MAPPING_KEY, mapping);

    db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

    final Map<String, Object> params = Util.map(
            "host", HOST,
            "conf", conf);
    final String cypher =
            "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";

    ExtendedTestUtil.assertFails(db, cypher, params, NO_FIELDS_ERROR_MSG);
}
}
77 changes: 77 additions & 0 deletions full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
package apoc.full.it.vectordb;

import apoc.ml.Prompt;
import apoc.util.ExtendedTestUtil;
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.test.TestDatabaseManagementServiceBuilder;
import org.testcontainers.weaviate.WeaviateContainer;

import java.util.List;
import java.util.Map;

import static apoc.ml.Prompt.API_KEY_CONF;
import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.ExtendedTestUtil.assertFails;
Expand Down Expand Up @@ -658,4 +678,61 @@ public void queryVectorsWithRag() {
"attributes", List.of("city", "foo")),
VectorDbTestUtil::assertRagWithVectors);
}

/**
 * Runs {@code apoc.vectordb.weaviate.query} with a mapping that sets a metadata key
 * ({@code foo}) but with no fields entry in the configuration, ordering rows by id, and
 * checks them with {@code VectorDbTestUtil.assertMetadataFooResult}.
 */
@Test
public void queryVectorsWithMetadataKeyNoFields() {
    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId",
            METADATA_KEY, "foo");
    final Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            MAPPING_KEY, mapping,
            HEADERS_KEY, ADMIN_AUTHORIZATION);
    final Map<String, Object> params = map("host", HOST, "conf", conf);

    final String cypher =
            "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) "
                    + " YIELD score, vector, id, metadata RETURN * ORDER BY id";

    testResult(db, cypher, params, VectorDbTestUtil::assertMetadataFooResult);
}

/**
 * Runs {@code apoc.vectordb.weaviate.query} with neither a fields entry nor a metadata key
 * in the configuration and passes {@code NO_FIELDS_ERROR_MSG} to
 * {@code ExtendedTestUtil.assertFails} as the expected failure.
 */
@Test
public void queryVectorsWithNoMetadataKeyNoFields() {
    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId");
    final Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            MAPPING_KEY, mapping,
            HEADERS_KEY, ADMIN_AUTHORIZATION);
    final Map<String, Object> params = map("host", HOST, "conf", conf);

    final String cypher =
            "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) "
                    + " YIELD score, vector, id, metadata RETURN * ORDER BY id";

    ExtendedTestUtil.assertFails(db, cypher, params, NO_FIELDS_ERROR_MSG);
}

/**
 * Creates two {@code :Test} nodes, then runs {@code apoc.vectordb.weaviate.queryAndUpdate}
 * with a mapping that sets a metadata key ({@code foo}) but with no fields entry in the
 * configuration; rows are checked with {@code VectorDbTestUtil.assertMetadataFooResult}.
 */
@Test
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
    db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId",
            METADATA_KEY, "foo");
    final Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            HEADERS_KEY, ADMIN_AUTHORIZATION,
            MAPPING_KEY, mapping);
    final Map<String, Object> params = map("host", HOST, "conf", conf);

    final String cypher =
            "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) "
                    + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id";

    testResult(db, cypher, params, VectorDbTestUtil::assertMetadataFooResult);
}

/**
 * Creates two {@code :Test} nodes, then runs {@code apoc.vectordb.weaviate.queryAndUpdate}
 * with neither a fields entry nor a metadata key in the configuration and passes
 * {@code NO_FIELDS_ERROR_MSG} to {@code ExtendedTestUtil.assertFails} as the expected failure.
 */
@Test
public void queryAndUpdateWithCreateNodeUsingExistingNodeFailWithNoMetadataKeyAndNoFields() {
    db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

    final Map<String, Object> mapping = map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Test",
            ENTITY_KEY, "myId");
    final Map<String, Object> params = map(
            "host", HOST,
            "conf", Map.of(
                    ALL_RESULTS_KEY, true,
                    HEADERS_KEY, ADMIN_AUTHORIZATION,
                    MAPPING_KEY, mapping));

    final String cypher =
            "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) YIELD score, vector, id, metadata, node RETURN * ORDER BY id";

    ExtendedTestUtil.assertFails(db, cypher, params, NO_FIELDS_ERROR_MSG);
}
}
8 changes: 3 additions & 5 deletions full/src/main/java/apoc/vectordb/MilvusHandler.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package apoc.vectordb;

import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.META_AS_SUBKEY_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.SCORE_KEY;

Expand Down Expand Up @@ -66,10 +66,8 @@ private VectorEmbeddingConfig getVectorEmbeddingConfig(
Map<String, Object> additionalBodies) {
config.putIfAbsent(META_AS_SUBKEY_KEY, false);

List listFields = (List) config.get(FIELDS_KEY);
if (listFields == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
}
List listFields = addMetadataKeyToFields(config);

if (procFields.contains("vector") && !listFields.contains("vector")) {
listFields.add("vector");
}
Expand Down
8 changes: 3 additions & 5 deletions full/src/main/java/apoc/vectordb/WeaviateHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import static apoc.ml.RestAPIConfig.BODY_KEY;
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY;

Expand Down Expand Up @@ -53,10 +53,8 @@ public VectorEmbeddingConfig fromQuery(
config.putIfAbsent(METHOD_KEY, "POST");
VectorEmbeddingConfig vectorEmbeddingConfig = getVectorEmbeddingConfig(config);

List list = (List) config.get(FIELDS_KEY);
if (list == null) {
throw new RuntimeException("You have to define `field` list of parameter to be returned");
}
List list = addMetadataKeyToFields(config);

Object fieldList = String.join("\n", list);

filter = filter == null ? "" : ", where: " + filter;
Expand Down
20 changes: 20 additions & 0 deletions full/src/test/java/apoc/vectordb/VectorDbTestUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
package apoc.vectordb;

import apoc.util.MapUtil;
import org.junit.Assume;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.graphdb.Result;

import java.util.Map;

import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_METADATA;
import static apoc.util.TestUtil.testResult;
import static apoc.util.Util.map;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -122,4 +132,14 @@ public static String ragSetup(GraphDatabaseService db) {
db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");
return openAIKey;
}

/**
 * Asserts that the result yields exactly two rows and that the map stored under
 * {@code DEFAULT_METADATA} in each row has {@code "foo"} equal to {@code "one"} for the
 * first row and {@code "two"} for the second.
 *
 * @param r the result to consume; it is advanced by two rows and must then be exhausted
 */
public static void assertMetadataFooResult(Result r) {
    for (String expectedFoo : new String[] {"one", "two"}) {
        Map<String, Object> currentRow = r.next();
        Map<String, Object> meta = (Map<String, Object>) currentRow.get(DEFAULT_METADATA);
        assertEquals(expectedFoo, meta.get("foo"));
    }
    // No third row is allowed.
    assertFalse(r.hasNext());
}
}

0 comments on commit 0b223e2

Please sign in to comment.