Allow kwargs in body_func #46

Merged
merged 3 commits into from
Sep 26, 2024
Changes from 1 commit
40 changes: 36 additions & 4 deletions libs/elasticsearch/langchain_elasticsearch/retrievers.py
@@ -1,10 +1,25 @@
 import logging
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+    cast,
+)

 from elasticsearch import Elasticsearch
-from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForRetrieverRun,
+    CallbackManagerForRetrieverRun,
+)
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables.config import run_in_executor

 from langchain_elasticsearch._utilities import with_user_agent_header
 from langchain_elasticsearch.client import create_elasticsearch_client
@@ -29,6 +44,8 @@ class ElasticsearchRetriever(BaseRetriever):
         document_mapper: Function to map Elasticsearch hits to LangChain Documents.
     """

+    _expects_other_args = True
+
     es_client: Elasticsearch
     index_name: Union[str, Sequence[str]]
     body_func: Callable[[str], Dict]
@@ -94,12 +111,12 @@ def from_es_params(
         )

     def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
     ) -> List[Document]:
         if not self.es_client or not self.document_mapper:
             raise ValueError("faulty configuration")  # should not happen

-        body = self.body_func(query)
+        body = self.body_func(query, **kwargs)
         results = self.es_client.search(index=self.index_name, body=body)
         return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

@@ -112,3 +129,18 @@ def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document:
         field = self.content_field[hit["_index"]]
         content = hit["_source"].pop(field)
         return Document(page_content=content, metadata=hit)
+
+    async def _aget_relevant_documents(
Collaborator

Is this method necessary? We do not currently have support for async in any other part of this library, and we would most likely not use an executor when we add it, since the Elasticsearch client does support async natively.

Contributor Author

I didn't realize that, sorry!

Anyway, these ainvoke calls work even though the underlying implementation is fully synchronous.

Leaving the retriever without this _aget_relevant_documents override may confuse users: they can still call the retriever asynchronously (even though it just runs synchronously under the hood), and they will get either a different result than they expect, if the body_func tolerates the omission of its keyword arguments, or a hard-to-debug error about the missing arguments.

Does that make sense to you?
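
The two failure modes described in this comment can be sketched without LangChain or Elasticsearch. This is only an illustration: `call_dropping_kwargs` is a hypothetical stand-in for any code path that does not forward keyword arguments, and the `text` field and `size` parameter are assumptions, not names from the library.

```python
from typing import Any, Callable, Dict, Optional


def body_with_default(query: str, size: int = 10) -> Dict[str, Any]:
    # Tolerates a missing `size`, so a dropped kwarg goes unnoticed.
    return {"size": size, "query": {"match": {"text": query}}}


def body_with_required(query: str, *, size: int) -> Dict[str, Any]:
    # Requires `size`, so a dropped kwarg raises TypeError.
    return {"size": size, "query": {"match": {"text": query}}}


def call_dropping_kwargs(
    body_func: Callable[..., Dict[str, Any]], query: str, **kwargs: Any
) -> Dict[str, Any]:
    # Hypothetical stand-in: accepts kwargs but never passes them on.
    return body_func(query)


# Case 1: the caller asked for size=3 but silently gets the default.
silently_different = call_dropping_kwargs(body_with_default, "foo", size=3)

# Case 2: the required keyword argument never arrives.
error: Optional[Exception] = None
try:
    call_dropping_kwargs(body_with_required, "foo", size=3)
except TypeError as exc:
    error = exc
```

In the first case the request body is built with the default `size`; in the second the failure surfaces as a `TypeError` raised far from the original call site.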

Collaborator

@miguelgrinberg, Sep 23, 2024

Yes, I do realize that this works, but it is not a great solution. When we add async support to this library we are going to do it properly. Adding this hack seems out of place, considering that no other function in the library does it. Of course you can subclass our implementation and add async methods in a subclass if this solution works for you.
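
A rough sketch of the subclassing route suggested here, using stand-in classes so that it runs without LangChain or Elasticsearch. `SyncRetriever`, `AsyncCapableRetriever`, `get_documents`, and `aget_documents` are hypothetical names; with the real library the subclass would presumably extend `ElasticsearchRetriever` and override `_aget_relevant_documents` instead.

```python
import asyncio
from typing import Any, Callable, Dict, List


class SyncRetriever:
    """Hypothetical stand-in for a synchronous retriever (not the library class)."""

    def __init__(self, body_func: Callable[..., Dict[str, Any]]) -> None:
        self.body_func = body_func

    def get_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
        # A real retriever would send this body to Elasticsearch;
        # here the body is returned as-is so the example stays self-contained.
        return [self.body_func(query, **kwargs)]


class AsyncCapableRetriever(SyncRetriever):
    """Subclass that adds an async entry point on top of the sync method."""

    async def aget_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
        # asyncio.to_thread runs the sync method in a worker thread and
        # forwards both positional and keyword arguments to it.
        return await asyncio.to_thread(self.get_documents, query, **kwargs)


def body_func(query: str, size: int = 10) -> Dict[str, Any]:
    # `text` and `size` are assumed names for illustration only.
    return {"size": size, "query": {"match": {"text": query}}}


result = asyncio.run(AsyncCapableRetriever(body_func).aget_documents("foo", size=3))
```

This keeps the thread-based workaround in user code, leaving the library free to add native async support later.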

Contributor Author

That makes sense! I will just remove the async bit, then. Thanks!

+        self,
+        query: str,
+        *,
+        run_manager: AsyncCallbackManagerForRetrieverRun,
+        **kwargs: Any,
+    ) -> Coroutine[Any, Any, List[Document]]:
+        return await run_in_executor(
+            None,
+            self._get_relevant_documents,
+            query,
+            **kwargs,
+            run_manager=run_manager.get_sync(),
+        )
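
For context, a minimal sketch of the kind of `body_func` this change is meant to allow: one that takes extra keyword arguments in addition to the query string. The `category` and `size` parameters and the `text` and `category` field names are assumptions made for illustration; they do not come from the PR.

```python
from typing import Any, Dict, List, Optional


def body_func(
    query: str, *, category: Optional[str] = None, size: int = 10
) -> Dict[str, Any]:
    """Build an Elasticsearch search body from a query plus optional kwargs."""
    filters: List[Dict[str, Any]] = []
    if category is not None:
        # Restrict hits to one category only when the caller asks for it.
        filters.append({"term": {"category": category}})
    return {
        "size": size,
        "query": {
            "bool": {
                "must": [{"match": {"text": query}}],
                "filter": filters,
            }
        },
    }


filtered_body = body_func("foo", category="news", size=3)
default_body = body_func("foo")
```

With the change in this diff, a call such as `retriever.invoke("foo", category="news")` would be expected to pass `category` through to `body_func`. That call is not exercised here because it needs a configured Elasticsearch client.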