From c6a25bf1f6e7cbb0319782b8ea8b900973ebb6e4 Mon Sep 17 00:00:00 2001 From: Gabriel Dahia Date: Mon, 23 Sep 2024 12:02:00 -0300 Subject: [PATCH] Allow kwargs in body_func --- .../langchain_elasticsearch/retrievers.py | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/libs/elasticsearch/langchain_elasticsearch/retrievers.py b/libs/elasticsearch/langchain_elasticsearch/retrievers.py index a5a12f5..7ee0900 100644 --- a/libs/elasticsearch/langchain_elasticsearch/retrievers.py +++ b/libs/elasticsearch/langchain_elasticsearch/retrievers.py @@ -1,10 +1,25 @@ import logging -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast +from typing import ( + Any, + Callable, + Coroutine, + Dict, + List, + Mapping, + Optional, + Sequence, + Union, + cast, +) from elasticsearch import Elasticsearch -from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever +from langchain_core.runnables.config import run_in_executor from langchain_elasticsearch._utilities import with_user_agent_header from langchain_elasticsearch.client import create_elasticsearch_client @@ -29,6 +44,8 @@ class ElasticsearchRetriever(BaseRetriever): document_mapper: Function to map Elasticsearch hits to LangChain Documents. 
""" + _expects_other_args = True + es_client: Elasticsearch index_name: Union[str, Sequence[str]] body_func: Callable[[str], Dict] @@ -94,12 +111,12 @@ def from_es_params( ) def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> List[Document]: if not self.es_client or not self.document_mapper: raise ValueError("faulty configuration") # should not happen - body = self.body_func(query) + body = self.body_func(query, **kwargs) results = self.es_client.search(index=self.index_name, body=body) return [self.document_mapper(hit) for hit in results["hits"]["hits"]] @@ -112,3 +129,18 @@ def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document: field = self.content_field[hit["_index"]] content = hit["_source"].pop(field) return Document(page_content=content, metadata=hit) + + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + **kwargs: Any, + ) -> List[Document]: + return await run_in_executor( + None, + self._get_relevant_documents, + query, + **kwargs, + run_manager=run_manager.get_sync(), + )