Skip to content

Commit

Permalink
Allow kwargs in body_func
Browse files Browse the repository at this point in the history
  • Loading branch information
gdahia committed Sep 23, 2024
1 parent 439efe2 commit c6a25bf
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions libs/elasticsearch/langchain_elasticsearch/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
import logging
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
from typing import (
Any,
Callable,
Coroutine,
Dict,
List,
Mapping,
Optional,
Sequence,
Union,
cast,
)

from elasticsearch import Elasticsearch
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.config import run_in_executor

from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client
Expand All @@ -29,6 +44,8 @@ class ElasticsearchRetriever(BaseRetriever):
document_mapper: Function to map Elasticsearch hits to LangChain Documents.
"""

_expects_other_args = True

es_client: Elasticsearch
index_name: Union[str, Sequence[str]]
body_func: Callable[[str], Dict]
Expand Down Expand Up @@ -94,12 +111,12 @@ def from_es_params(
)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
if not self.es_client or not self.document_mapper:
raise ValueError("faulty configuration") # should not happen

body = self.body_func(query)
body = self.body_func(query, **kwargs)
results = self.es_client.search(index=self.index_name, body=body)
return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

Expand All @@ -112,3 +129,18 @@ def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document:
field = self.content_field[hit["_index"]]
content = hit["_source"].pop(field)
return Document(page_content=content, metadata=hit)

async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> Coroutine[Any, Any, List[Document]]:
return await run_in_executor(
None,
self._get_relevant_documents,
query,
**kwargs,
run_manager=run_manager.get_sync(),
)

0 comments on commit c6a25bf

Please sign in to comment.