Skip to content

Commit

Permalink
Structured retriever results (#73)
Browse files Browse the repository at this point in the history
* Added result_formatter to all vector, hybrid, and t2c retrievers

* Added unit tests to test retrievers work with a format function

* Ruff formatting and fixed weaviate e2e tests

* Fixed Weaviate tests

* Typo in docs

* Added neo4j.Record variables in result_formatter in docstring

* Update CHANGELOG

---------

Co-authored-by: Will Tai <wtaisen@gmail.com>
  • Loading branch information
alexthomas93 and willtai authored Jul 2, 2024
1 parent 23fd585 commit 81b0209
Show file tree
Hide file tree
Showing 16 changed files with 324 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- Improved developer experience by copying the docstring from the `Retriever.get_search_results` method to the `Retriever.search` method
- Support for specifying database names in index handling methods and retrievers.
- User Guide in documentation.
- Introduced result_formatter argument to all retrievers, allowing custom formatting of retriever results.

### Changed
- Refactored import paths for retrievers to neo4j_genai.retrievers.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ Format the Results

.. warning::

This API is in beta mode and will be subject to change is the future.
This API is in beta mode and will be subject to change in the future.

For improved readability and ease in prompt-engineering, formatting the result to suit
specific needs involves providing a `record_formatter` function to the Cypher retrievers.
Expand Down
7 changes: 5 additions & 2 deletions src/neo4j_genai/retrievers/external/pinecone/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
EmbedderModel,
Neo4jDriverModel,
RawSearchResult,
RetrieverResultItem,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -78,7 +79,7 @@ class PineconeNeo4jRetriever(ExternalRetriever):
id_property_neo4j (str): The name of the Neo4j node property that's used as the identifier for relating matches from Weaviate to Neo4j nodes.
embedder (Optional[Embedder]): Embedder object to embed query text.
return_properties (Optional[list[str]]): List of node properties to return.
result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Raises:
Expand All @@ -94,7 +95,9 @@ def __init__(
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
retrieval_query: Optional[str] = None,
result_formatter: Optional[Callable[[Any], Any]] = None,
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
):
try:
Expand Down
9 changes: 7 additions & 2 deletions src/neo4j_genai/retrievers/external/pinecone/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
field_validator,
)

from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, VectorSearchModel
from neo4j_genai.types import (
EmbedderModel,
Neo4jDriverModel,
RetrieverResultItem,
VectorSearchModel,
)


class PineconeSearchModel(VectorSearchModel):
Expand Down Expand Up @@ -52,5 +57,5 @@ class PineconeNeo4jRetrieverModel(BaseModel):
embedder_model: Optional[EmbedderModel] = None
return_properties: Optional[list[str]] = None
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], str]] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None
9 changes: 7 additions & 2 deletions src/neo4j_genai/retrievers/external/weaviate/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from weaviate.client import WeaviateClient
from weaviate.collections.classes.filters import _Filters

from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, VectorSearchModel
from neo4j_genai.types import (
EmbedderModel,
Neo4jDriverModel,
RetrieverResultItem,
VectorSearchModel,
)


class WeaviateModel(BaseModel):
Expand All @@ -50,7 +55,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
embedder_model: Optional[EmbedderModel]
return_properties: Optional[list[str]] = None
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], str]] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None


Expand Down
13 changes: 10 additions & 3 deletions src/neo4j_genai/retrievers/external/weaviate/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
WeaviateNeo4jRetrieverModel,
WeaviateNeo4jSearchModel,
)
from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, RawSearchResult
from neo4j_genai.types import (
EmbedderModel,
Neo4jDriverModel,
RawSearchResult,
RetrieverResultItem,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,7 +74,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever):
id_property_neo4j (str): The name of the Neo4j node property that's used as the identifier for relating matches from Weaviate to Neo4j nodes.
embedder (Optional[Embedder]): Embedder object to embed query text.
return_properties (Optional[list[str]]): List of node properties to return.
result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Raises:
Expand All @@ -86,7 +91,9 @@ def __init__(
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
retrieval_query: Optional[str] = None,
result_formatter: Optional[Callable[[Any], Any]] = None,
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
):
try:
Expand Down
19 changes: 16 additions & 3 deletions src/neo4j_genai/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ class HybridRetriever(Retriever):
embedder (Optional[Embedder]): Embedder object to embed query text.
return_properties (Optional[list[str]]): List of node properties to return.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
Two variables are provided in the neo4j.Record:
- node: Represents the node retrieved from the vector index search.
- score: Denotes the similarity score.
"""

def __init__(
Expand All @@ -80,6 +85,9 @@ def __init__(
fulltext_index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
) -> None:
try:
Expand All @@ -91,6 +99,7 @@ def __init__(
fulltext_index_name=fulltext_index_name,
embedder_model=embedder_model,
return_properties=return_properties,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
)
except ValidationError as e:
Expand All @@ -107,6 +116,7 @@ def __init__(
if validated_data.embedder_model
else None
)
self.result_formatter = validated_data.result_formatter

def default_record_formatter(self, record: neo4j.Record) -> RetrieverResultItem:
"""
Expand Down Expand Up @@ -219,7 +229,7 @@ class HybridCypherRetriever(Retriever):
fulltext_index_name (str): Fulltext index name.
retrieval_query (str): Cypher query that gets appended.
embedder (Optional[Embedder]): Embedder object to embed query text.
result_formatter (Optional[Callable[[Any], Any]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Raises:
Expand All @@ -233,7 +243,9 @@ def __init__(
fulltext_index_name: str,
retrieval_query: str,
embedder: Optional[Embedder] = None,
result_formatter: Optional[Callable[[Any], Any]] = None,
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
) -> None:
try:
Expand All @@ -245,6 +257,7 @@ def __init__(
fulltext_index_name=fulltext_index_name,
retrieval_query=retrieval_query,
embedder_model=embedder_model,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
)
except ValidationError as e:
Expand All @@ -261,7 +274,7 @@ def __init__(
if validated_data.embedder_model
else None
)
self.result_formatter = result_formatter
self.result_formatter = validated_data.result_formatter

def get_search_results(
self,
Expand Down
8 changes: 7 additions & 1 deletion src/neo4j_genai/retrievers/text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import logging
from typing import Optional
from typing import Callable, Optional

import neo4j
from neo4j.exceptions import CypherSyntaxError, DriverError, Neo4jError
Expand All @@ -36,6 +36,7 @@
Neo4jDriverModel,
Neo4jSchemaModel,
RawSearchResult,
RetrieverResultItem,
Text2CypherRetrieverModel,
Text2CypherSearchModel,
)
Expand Down Expand Up @@ -65,6 +66,9 @@ def __init__(
llm: LLMInterface,
neo4j_schema: Optional[str] = None,
examples: Optional[list[str]] = None,
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
) -> None:
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -77,13 +81,15 @@ def __init__(
llm_model=llm_model,
neo4j_schema_model=neo4j_schema_model,
examples=examples,
result_formatter=result_formatter,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e

super().__init__(validated_data.driver_model.driver)
self.llm = validated_data.llm_model.llm
self.examples = validated_data.examples
self.result_formatter = validated_data.result_formatter
try:
self.neo4j_schema = (
validated_data.neo4j_schema_model.neo4j_schema
Expand Down
21 changes: 18 additions & 3 deletions src/neo4j_genai/retrievers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ class VectorRetriever(Retriever):
index_name (str): Vector index name.
embedder (Optional[Embedder]): Embedder object to embed query text.
return_properties (Optional[list[str]]): List of node properties to return.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
Two variables are provided in the neo4j.Record:
- node: Represents the node retrieved from the vector index search.
- score: Denotes the similarity score.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Raises:
Expand All @@ -77,6 +84,9 @@ def __init__(
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
) -> None:
try:
Expand All @@ -87,6 +97,7 @@ def __init__(
index_name=index_name,
embedder_model=embedder_model,
return_properties=return_properties,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
)
except ValidationError as e:
Expand All @@ -102,6 +113,7 @@ def __init__(
if validated_data.embedder_model
else None
)
self.result_formatter = validated_data.result_formatter
self._node_label = None
self._embedding_node_property = None
self._embedding_dimension = None
Expand Down Expand Up @@ -222,7 +234,7 @@ class VectorCypherRetriever(Retriever):
index_name (str): Vector index name.
retrieval_query (str): Cypher query that gets appended.
embedder (Optional[Embedder]): Embedder object to embed query text.
result_formatter (Optional[Callable[[Any], Any]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
"""
Expand All @@ -233,7 +245,9 @@ def __init__(
index_name: str,
retrieval_query: str,
embedder: Optional[Embedder] = None,
result_formatter: Optional[Callable[[Any], Any]] = None,
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
) -> None:
try:
Expand All @@ -244,6 +258,7 @@ def __init__(
index_name=index_name,
retrieval_query=retrieval_query,
embedder_model=embedder_model,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
)
except ValidationError as e:
Expand All @@ -259,7 +274,7 @@ def __init__(
if validated_data.embedder_model
else None
)
self.result_formatter = result_formatter
self.result_formatter = validated_data.result_formatter
self._node_label = None
self._node_embedding_property = None
self._embedding_dimension = None
Expand Down
7 changes: 6 additions & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Literal, Optional
from typing import Any, Callable, Literal, Optional

import neo4j
from pydantic import (
Expand Down Expand Up @@ -201,6 +201,7 @@ class VectorRetrieverModel(BaseModel):
index_name: str
embedder_model: Optional[EmbedderModel] = None
return_properties: Optional[list[str]] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None


Expand All @@ -209,6 +210,7 @@ class VectorCypherRetrieverModel(BaseModel):
index_name: str
retrieval_query: str
embedder_model: Optional[EmbedderModel] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None


Expand All @@ -218,6 +220,7 @@ class HybridRetrieverModel(BaseModel):
fulltext_index_name: str
embedder_model: Optional[EmbedderModel] = None
return_properties: Optional[list[str]] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None


Expand All @@ -227,6 +230,7 @@ class HybridCypherRetrieverModel(BaseModel):
fulltext_index_name: str
retrieval_query: str
embedder_model: Optional[EmbedderModel] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None


Expand All @@ -235,3 +239,4 @@ class Text2CypherRetrieverModel(BaseModel):
llm_model: LLMModel
neo4j_schema_model: Optional[Neo4jSchemaModel] = None
examples: Optional[list[str]] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
15 changes: 14 additions & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable
from unittest.mock import MagicMock, patch

import neo4j
Expand All @@ -25,6 +26,7 @@
VectorCypherRetriever,
VectorRetriever,
)
from neo4j_genai.types import RetrieverResultItem


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -84,4 +86,15 @@ def t2c_retriever(

@pytest.fixture(scope="function")
def neo4j_record() -> neo4j.Record:
return neo4j.Record({"node": "dummy-node", "score": 1.0})
return neo4j.Record({"node": "dummy-node", "score": 1.0, "node_id": 123})


@pytest.fixture(scope="function")
def result_formatter() -> Callable[[neo4j.Record], RetrieverResultItem]:
def format_function(record: neo4j.Record) -> RetrieverResultItem:
return RetrieverResultItem(
content=record.get("node"),
metadata={"score": record.get("score"), "node_id": record.get("node_id")},
)

return format_function
Loading

0 comments on commit 81b0209

Please sign in to comment.