diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b0c8eb0..e42fc08d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - Improved developer experience by copying the docstring from the `Retriever.get_search_results` method to the `Retriever.search` method - Support for specifying database names in index handling methods and retrievers. - User Guide in documentation. +- Introduced result_formatter argument to all retrievers, allowing custom formatting of retriever results. ### Changed - Refactored import paths for retrievers to neo4j_genai.retrievers. diff --git a/docs/source/user_guide.rst b/docs/source/user_guide.rst index 66264053..c1d85bb5 100644 --- a/docs/source/user_guide.rst +++ b/docs/source/user_guide.rst @@ -430,7 +430,7 @@ Format the Results .. warning:: - This API is in beta mode and will be subject to change is the future. + This API is in beta mode and will be subject to change in the future. For improved readability and ease in prompt-engineering, formatting the result to suit specific needs involves providing a `record_formatter` function to the Cypher retrievers. diff --git a/src/neo4j_genai/retrievers/external/pinecone/pinecone.py b/src/neo4j_genai/retrievers/external/pinecone/pinecone.py index 3aafb73d..90c91cc8 100644 --- a/src/neo4j_genai/retrievers/external/pinecone/pinecone.py +++ b/src/neo4j_genai/retrievers/external/pinecone/pinecone.py @@ -38,6 +38,7 @@ EmbedderModel, Neo4jDriverModel, RawSearchResult, + RetrieverResultItem, ) logger = logging.getLogger(__name__) @@ -78,7 +79,7 @@ class PineconeNeo4jRetriever(ExternalRetriever): id_property_neo4j (str): The name of the Neo4j node property that's used as the identifier for relating matches from Weaviate to Neo4j nodes. embedder (Optional[Embedder]): Embedder object to embed query text. return_properties (Optional[list[str]]): List of node properties to return. - result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem. + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem. neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). Raises: @@ -94,7 +95,9 @@ def __init__( embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None, - result_formatter: Optional[Callable[[Any], Any]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, neo4j_database: Optional[str] = None, ): try: diff --git a/src/neo4j_genai/retrievers/external/pinecone/types.py b/src/neo4j_genai/retrievers/external/pinecone/types.py index b96148ac..7ee8930b 100644 --- a/src/neo4j_genai/retrievers/external/pinecone/types.py +++ b/src/neo4j_genai/retrievers/external/pinecone/types.py @@ -24,7 +24,12 @@ field_validator, ) -from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, VectorSearchModel +from neo4j_genai.types import ( + EmbedderModel, + Neo4jDriverModel, + RetrieverResultItem, + VectorSearchModel, +) class PineconeSearchModel(VectorSearchModel): @@ -52,5 +57,5 @@ class PineconeNeo4jRetrieverModel(BaseModel): embedder_model: Optional[EmbedderModel] = None return_properties: Optional[list[str]] = None retrieval_query: Optional[str] = None - result_formatter: Optional[Callable[[neo4j.Record], str]] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None diff --git a/src/neo4j_genai/retrievers/external/weaviate/types.py b/src/neo4j_genai/retrievers/external/weaviate/types.py index 21674f4c..39cc4571 100644 --- a/src/neo4j_genai/retrievers/external/weaviate/types.py +++ b/src/neo4j_genai/retrievers/external/weaviate/types.py @@ -25,7 +25,12 @@ from weaviate.client import WeaviateClient from weaviate.collections.classes.filters import _Filters -from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, VectorSearchModel +from neo4j_genai.types import ( + EmbedderModel, + Neo4jDriverModel, + RetrieverResultItem, + VectorSearchModel, +) class WeaviateModel(BaseModel): @@ -50,7 +55,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel): embedder_model: Optional[EmbedderModel] return_properties: Optional[list[str]] = None retrieval_query: Optional[str] = None - result_formatter: Optional[Callable[[neo4j.Record], str]] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None diff --git a/src/neo4j_genai/retrievers/external/weaviate/weaviate.py b/src/neo4j_genai/retrievers/external/weaviate/weaviate.py index 45ea56f3..0eaa226a 100644 --- a/src/neo4j_genai/retrievers/external/weaviate/weaviate.py +++ b/src/neo4j_genai/retrievers/external/weaviate/weaviate.py @@ -31,7 +31,12 @@ WeaviateNeo4jRetrieverModel, WeaviateNeo4jSearchModel, ) -from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, RawSearchResult +from neo4j_genai.types import ( + EmbedderModel, + Neo4jDriverModel, + RawSearchResult, + RetrieverResultItem, +) logger = logging.getLogger(__name__) @@ -69,7 +74,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever): id_property_neo4j (str): The name of the Neo4j node property that's used as the identifier for relating matches from Weaviate to Neo4j nodes. embedder (Optional[Embedder]): Embedder object to embed query text. return_properties (Optional[list[str]]): List of node properties to return. - result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem. + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem. neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). Raises: @@ -86,7 +91,9 @@ def __init__( embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None, - result_formatter: Optional[Callable[[Any], Any]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, neo4j_database: Optional[str] = None, ): try: diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index 17f37c76..59c26df6 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -70,7 +70,12 @@ class HybridRetriever(Retriever): embedder (Optional[Embedder]): Embedder object to embed query text. return_properties (Optional[list[str]]): List of node properties to return. neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. + Two variables are provided in the neo4j.Record: + + - node: Represents the node retrieved from the vector index search. + - score: Denotes the similarity score. """ def __init__( @@ -80,6 +85,9 @@ def __init__( fulltext_index_name: str, embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, neo4j_database: Optional[str] = None, ) -> None: try: @@ -91,6 +99,7 @@ def __init__( fulltext_index_name=fulltext_index_name, embedder_model=embedder_model, return_properties=return_properties, + result_formatter=result_formatter, neo4j_database=neo4j_database, ) except ValidationError as e: @@ -107,6 +116,7 @@ def __init__( if validated_data.embedder_model else None ) + self.result_formatter = validated_data.result_formatter def default_record_formatter(self, record: neo4j.Record) -> RetrieverResultItem: """ @@ -219,7 +229,7 @@ class HybridCypherRetriever(Retriever): fulltext_index_name (str): Fulltext index name. retrieval_query (str): Cypher query that gets appended. embedder (Optional[Embedder]): Embedder object to embed query text. - result_formatter (Optional[Callable[[Any], Any]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). Raises: @@ -233,7 +243,9 @@ def __init__( fulltext_index_name: str, retrieval_query: str, embedder: Optional[Embedder] = None, - result_formatter: Optional[Callable[[Any], Any]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, neo4j_database: Optional[str] = None, ) -> None: try: @@ -245,6 +257,7 @@ def __init__( fulltext_index_name=fulltext_index_name, retrieval_query=retrieval_query, embedder_model=embedder_model, + result_formatter=result_formatter, neo4j_database=neo4j_database, ) except ValidationError as e: @@ -261,7 +274,7 @@ def __init__( if validated_data.embedder_model else None ) - self.result_formatter = result_formatter + self.result_formatter = validated_data.result_formatter def get_search_results( self, diff --git a/src/neo4j_genai/retrievers/text2cypher.py b/src/neo4j_genai/retrievers/text2cypher.py index abef14c5..3384c597 100644 --- a/src/neo4j_genai/retrievers/text2cypher.py +++ b/src/neo4j_genai/retrievers/text2cypher.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Optional +from typing import Callable, Optional import neo4j from neo4j.exceptions import CypherSyntaxError, DriverError, Neo4jError @@ -36,6 +36,7 @@ Neo4jDriverModel, Neo4jSchemaModel, RawSearchResult, + RetrieverResultItem, Text2CypherRetrieverModel, Text2CypherSearchModel, ) @@ -65,6 +66,9 @@ def __init__( llm: LLMInterface, neo4j_schema: Optional[str] = None, examples: Optional[list[str]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, ) -> None: try: driver_model = Neo4jDriverModel(driver=driver) @@ -77,6 +81,7 @@ def __init__( llm_model=llm_model, neo4j_schema_model=neo4j_schema_model, examples=examples, + result_formatter=result_formatter, ) except ValidationError as e: raise RetrieverInitializationError(e.errors()) from e @@ -84,6 +89,7 @@ def __init__( super().__init__(validated_data.driver_model.driver) self.llm = validated_data.llm_model.llm self.examples = validated_data.examples + self.result_formatter = validated_data.result_formatter try: self.neo4j_schema = ( validated_data.neo4j_schema_model.neo4j_schema diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index becf3c40..17bc3088 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -65,6 +65,13 @@ class VectorRetriever(Retriever): index_name (str): Vector index name. embedder (Optional[Embedder]): Embedder object to embed query text. return_properties (Optional[list[str]]): List of node properties to return. + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. + + Two variables are provided in the neo4j.Record: + + - node: Represents the node retrieved from the vector index search. + - score: Denotes the similarity score. + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). Raises: @@ -77,6 +84,9 @@ def __init__( index_name: str, embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, neo4j_database: Optional[str] = None, ) -> None: try: @@ -87,6 +97,7 @@ def __init__( index_name=index_name, embedder_model=embedder_model, return_properties=return_properties, + result_formatter=result_formatter, neo4j_database=neo4j_database, ) except ValidationError as e: @@ -102,6 +113,7 @@ def __init__( if validated_data.embedder_model else None ) + self.result_formatter = validated_data.result_formatter self._node_label = None self._embedding_node_property = None self._embedding_dimension = None @@ -222,7 +234,7 @@ class VectorCypherRetriever(Retriever): index_name (str): Vector index name. retrieval_query (str): Cypher query that gets appended. embedder (Optional[Embedder]): Embedder object to embed query text. - result_formatter (Optional[Callable[[Any], Any]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). """ @@ -233,7 +245,9 @@ def __init__( index_name: str, retrieval_query: str, embedder: Optional[Embedder] = None, - result_formatter: Optional[Callable[[Any], Any]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, neo4j_database: Optional[str] = None, ) -> None: try: @@ -244,6 +258,7 @@ def __init__( index_name=index_name, retrieval_query=retrieval_query, embedder_model=embedder_model, + result_formatter=result_formatter, neo4j_database=neo4j_database, ) except ValidationError as e: @@ -259,7 +274,7 @@ def __init__( if validated_data.embedder_model else None ) - self.result_formatter = result_formatter + self.result_formatter = validated_data.result_formatter self._node_label = None self._node_embedding_property = None self._embedding_dimension = None diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 42dc482d..45826dd5 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -15,7 +15,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Literal, Optional +from typing import Any, Callable, Literal, Optional import neo4j from pydantic import ( @@ -201,6 +201,7 @@ class VectorRetrieverModel(BaseModel): index_name: str embedder_model: Optional[EmbedderModel] = None return_properties: Optional[list[str]] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None @@ -209,6 +210,7 @@ class VectorCypherRetrieverModel(BaseModel): index_name: str retrieval_query: str embedder_model: Optional[EmbedderModel] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None @@ -218,6 +220,7 @@ class HybridRetrieverModel(BaseModel): fulltext_index_name: str embedder_model: Optional[EmbedderModel] = None return_properties: Optional[list[str]] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None @@ -227,6 +230,7 @@ class HybridCypherRetrieverModel(BaseModel): fulltext_index_name: str retrieval_query: str embedder_model: Optional[EmbedderModel] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None @@ -235,3 +239,4 @@ class Text2CypherRetrieverModel(BaseModel): llm_model: LLMModel neo4j_schema_model: Optional[Neo4jSchemaModel] = None examples: Optional[list[str]] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4721ace0..3d792e1e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable from unittest.mock import MagicMock, patch import neo4j @@ -25,6 +26,7 @@ VectorCypherRetriever, VectorRetriever, ) +from neo4j_genai.types import RetrieverResultItem @pytest.fixture(scope="function") @@ -84,4 +86,15 @@ def t2c_retriever( @pytest.fixture(scope="function") def neo4j_record() -> neo4j.Record: - return neo4j.Record({"node": "dummy-node", "score": 1.0}) + return neo4j.Record({"node": "dummy-node", "score": 1.0, "node_id": 123}) + + +@pytest.fixture(scope="function") +def result_formatter() -> Callable[[neo4j.Record], RetrieverResultItem]: + def format_function(record: neo4j.Record) -> RetrieverResultItem: + return RetrieverResultItem( + content=record.get("node"), + metadata={"score": record.get("score"), "node_id": record.get("node_id")}, + ) + + return format_function diff --git a/tests/unit/retrievers/external/test_pinecone.py b/tests/unit/retrievers/external/test_pinecone.py index 1b37eea7..5e1b65e9 100644 --- a/tests/unit/retrievers/external/test_pinecone.py +++ b/tests/unit/retrievers/external/test_pinecone.py @@ -245,3 +245,35 @@ def test_pinecone_retriever_search_retrieval_query( ], metadata={"__retriever": "PineconeNeo4jRetriever"}, ) + + +def test_pinecone_retriever_with_result_format_function( + driver: MagicMock, + client: MagicMock, + neo4j_record: MagicMock, + result_formatter: MagicMock, +) -> None: + retriever = PineconeNeo4jRetriever( + driver=driver, + client=client, + index_name="dummy-text", + id_property_neo4j="sync_id", + result_formatter=result_formatter, + ) + with mock.patch.object(retriever, "index"): + driver.execute_query.return_value = ( + [neo4j_record], + None, + None, + ) + query_vector = [1.0 for _ in range(1536)] + records = retriever.search(query_vector=query_vector) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="dummy-node", metadata={"score": 1.0, "node_id": 123} + ), + ], + metadata={"__retriever": "PineconeNeo4jRetriever"}, + ) diff --git a/tests/unit/retrievers/external/test_weaviate.py b/tests/unit/retrievers/external/test_weaviate.py index 9ce2f8ef..b6bd3815 100644 --- a/tests/unit/retrievers/external/test_weaviate.py +++ b/tests/unit/retrievers/external/test_weaviate.py @@ -255,3 +255,39 @@ def test_match_query_with_both_return_properties_and_retrieval_query() -> None: "WHERE node[$id_property] = match_id_value " + retrieval_query ) assert match_query.strip() == expected.strip() + + +def test_weaviate_retriever_with_result_format_function( + driver: MagicMock, neo4j_record: MagicMock, result_formatter: MagicMock +) -> None: + query_text = "may thy knife chip and shatter" + top_k = 5 + node_id_value = "node-test-id" + node_match_score = 0.9 + + wc = WClient(node_id_value=node_id_value, node_match_score=node_match_score) + + retriever = WeaviateNeo4jRetriever( + driver=driver, + client=wc, + collection="dummy-collection", + id_property_neo4j="sync_id", + id_property_external="neo4j_id", + result_formatter=result_formatter, + ) + driver.execute_query.return_value = [ + [neo4j_record], + None, + None, + ] + + records = retriever.search(query_text=query_text, top_k=top_k) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="dummy-node", metadata={"score": 1.0, "node_id": 123} + ), + ], + metadata={"__retriever": "WeaviateNeo4jRetriever"}, + ) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index da1d3e74..9737ffa1 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -58,6 +58,46 @@ def test_hybrid_retriever_invalid_fulltext_index_name( assert "Input should be a valid string" in str(exc_info.value) +@patch("neo4j_genai.retrievers.HybridRetriever._verify_version") +def test_hybrid_retriever_with_result_format_function( + _verify_version_mock: MagicMock, + driver: MagicMock, + embedder: MagicMock, + neo4j_record: MagicMock, + result_formatter: MagicMock, +) -> None: + embed_query_vector = [1.0 for _ in range(1536)] + embedder.embed_query.return_value = embed_query_vector + vector_index_name = "vector-index" + fulltext_index_name = "fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + + retriever = HybridRetriever( + driver, + vector_index_name, + fulltext_index_name, + embedder, + result_formatter=result_formatter, + ) + retriever.driver.execute_query.return_value = [ # type: ignore + [neo4j_record], + None, + None, + ] + + records = retriever.search(query_text=query_text, top_k=top_k) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="dummy-node", metadata={"score": 1.0, "node_id": 123} + ), + ], + metadata={"__retriever": "HybridRetriever"}, + ) + + @patch("neo4j_genai.retrievers.HybridRetriever._verify_version") def test_hybrid_retriever_invalid_database_name( _verify_version_mock: MagicMock, driver: MagicMock @@ -341,7 +381,49 @@ def test_hybrid_cypher_retrieval_query_with_params( assert records == RetrieverResult( items=[ RetrieverResultItem( - content="", metadata=None + content="", + metadata=None, + ), + ], + metadata={"__retriever": "HybridCypherRetriever"}, + ) + + +@patch("neo4j_genai.retrievers.HybridCypherRetriever._verify_version") +def test_hybrid_cypher_retriever_with_result_format_function( + _verify_version_mock: MagicMock, + driver: MagicMock, + embedder: MagicMock, + neo4j_record: MagicMock, + result_formatter: MagicMock, +) -> None: + embed_query_vector = [1.0 for _ in range(1536)] + embedder.embed_query.return_value = embed_query_vector + vector_index_name = "vector-index" + fulltext_index_name = "fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + + retriever = HybridCypherRetriever( + driver, + vector_index_name, + fulltext_index_name, + "", + embedder, + result_formatter=result_formatter, + ) + retriever.driver.execute_query.return_value = [ # type: ignore + [neo4j_record], + None, + None, + ] + + records = retriever.search(query_text=query_text, top_k=top_k) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="dummy-node", metadata={"score": 1.0, "node_id": 123} ), ], metadata={"__retriever": "HybridCypherRetriever"}, diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 5ee17e07..17a217e1 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -24,6 +24,7 @@ from neo4j_genai.generation.prompts import Text2CypherTemplate from neo4j_genai.llm import LLMResponse from neo4j_genai.retrievers import Text2CypherRetriever +from neo4j_genai.types import RetrieverResult, RetrieverResultItem def test_t2c_retriever_initialization(driver: MagicMock, llm: MagicMock) -> None: @@ -145,3 +146,35 @@ def test_t2c_retriever_cypher_error( with pytest.raises(Text2CypherRetrievalError) as e: retriever.search(query_text=query_text) assert "Failed to get search result" in str(e) + + +@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version") +def test_t2c_retriever_with_result_format_function( + _verify_version_mock: MagicMock, + driver: MagicMock, + llm: MagicMock, + neo4j_record: MagicMock, + result_formatter: MagicMock, +) -> None: + retriever = Text2CypherRetriever( + driver=driver, llm=llm, result_formatter=result_formatter + ) + t2c_query = "MATCH (n) RETURN n;" + retriever.llm.invoke.return_value = LLMResponse(content=t2c_query) + query_text = "may thy knife chip and shatter" + driver.execute_query.return_value = [ + [neo4j_record], + None, + None, + ] + + records = retriever.search(query_text=query_text) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="dummy-node", metadata={"score": 1.0, "node_id": 123} + ), + ], + metadata={"cypher": t2c_query, "__retriever": "Text2CypherRetriever"}, + ) diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 26a95949..60bdc4a3 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -14,7 +14,6 @@ # limitations under the License. from __future__ import annotations -from typing import Any from unittest.mock import MagicMock, patch import neo4j @@ -273,6 +272,49 @@ def test_vector_retriever_search_both_text_and_vector( ) +@patch("neo4j_genai.retrievers.VectorRetriever._fetch_index_infos") +@patch("neo4j_genai.retrievers.VectorRetriever._verify_version") +def test_vector_retriever_with_result_format_function( + _verify_version_mock: MagicMock, + _fetch_index_infos: MagicMock, + driver: MagicMock, + embedder: MagicMock, + neo4j_record: MagicMock, + result_formatter: MagicMock, +) -> None: + embed_query_vector = [1.0 for _ in range(1536)] + embedder.embed_query.return_value = embed_query_vector + index_name = "my-index" + + retriever = VectorRetriever( + driver, + index_name, + embedder=embedder, + result_formatter=result_formatter, + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.return_value = [ + [neo4j_record], + None, + None, + ] + + records = retriever.search( + query_text=query_text, + top_k=top_k, + ) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="dummy-node", metadata={"score": 1.0, "node_id": 123} + ), + ], + metadata={"__retriever": "VectorRetriever"}, + ) + + def test_vector_cypher_retriever_search_missing_embedder_for_text( vector_cypher_retriever: VectorCypherRetriever, ) -> None: @@ -370,6 +412,8 @@ def test_retrieval_query_with_result_format_function( _fetch_index_infos: MagicMock, driver: MagicMock, embedder: MagicMock, + neo4j_record: MagicMock, + result_formatter: MagicMock, ) -> None: embed_query_vector = [1.0 for _ in range(1536)] embedder.embed_query.return_value = embed_query_vector @@ -378,24 +422,17 @@ def test_retrieval_query_with_result_format_function( RETURN node.id AS node_id, node.text AS text, score """ - def format_function(record: dict[str, Any]) -> RetrieverResultItem: - return RetrieverResultItem( - content=record.get("text"), - metadata={"score": record.get("score"), "node_id": record.get("node_id")}, - ) - retriever = VectorCypherRetriever( driver, index_name, retrieval_query, embedder=embedder, - result_formatter=format_function, + result_formatter=result_formatter, ) query_text = "may thy knife chip and shatter" top_k = 5 - record = neo4j.Record({"node_id": 123, "text": "dummy-text", "score": 1.0}) driver.execute_query.return_value = [ - [record], + [neo4j_record], None, None, ] @@ -421,7 +458,7 @@ def format_function(record: dict[str, Any]) -> RetrieverResultItem: assert records == RetrieverResult( items=[ RetrieverResultItem( - content="dummy-text", metadata={"score": 1.0, "node_id": 123} + content="dummy-node", metadata={"score": 1.0, "node_id": 123} ), ], metadata={"__retriever": "VectorCypherRetriever"},