Skip to content

Commit

Permalink
integration tests and more init checks
Browse files Browse the repository at this point in the history
  • Loading branch information
giacbrd committed Oct 8, 2024
1 parent 1fa0969 commit a76f74d
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 13 deletions.
27 changes: 17 additions & 10 deletions libs/elasticsearch/langchain_elasticsearch/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ class ElasticsearchStore(VectorStore):
Pre-existing Elasticsearch connection.
es_async_connection: Optional[AsyncElasticsearch]
Pre-existing Elasticsearch async connection.
Must be provided together with either es_connection or es_url / es_cloud_id.
es_url: Optional[str]
URL of the Elasticsearch instance to connect to.
es_cloud_id: Optional[str]
Expand All @@ -623,8 +624,9 @@ class ElasticsearchStore(VectorStore):
Password to use when connecting to Elasticsearch.
es_api_key: Optional[str]
API key to use when connecting to Elasticsearch.
es_use_async: bool
True if async IO methods will be called in an event loop.
es_use_async_client: bool
Set to True to call the async IO methods in an event loop through an
async Elasticsearch client.
Defaults to False; implicitly True when an es_async_connection is set.
Instantiate:
Expand Down Expand Up @@ -819,7 +821,7 @@ def __init__(
es_user: Optional[str] = None,
es_api_key: Optional[str] = None,
es_password: Optional[str] = None,
es_use_async: bool = False,
es_use_async_client: bool = False,
vector_query_field: str = "vector",
query_field: str = "text",
distance_strategy: Optional[
Expand All @@ -835,20 +837,25 @@ def __init__(
] = ApproxRetrievalStrategy(),
es_params: Optional[Dict[str, Any]] = None,
):
if es_connection and es_use_async:
es_use_async = False
if es_connection and es_use_async_client:
es_use_async_client = False
logger.warning(
"It is not possible to use Async IO if only an Elasticsearch"
" sync client is set, and not its async equivalent."
)
if not es_connection and not (es_url or es_cloud_id) and es_async_connection:
raise ValueError(
"It is not possible to provide only an Async IO client"
"for Elasticsearch"
)
if es_async_connection is not None:
es_use_async = True
es_use_async_client = True
async_strategy = None
if isinstance(strategy, BaseRetrievalStrategy):
strategy, async_strategy = _convert_retrieval_strategy(
strategy, distance=distance_strategy or DistanceStrategy.COSINE
)
elif isinstance(strategy, RetrievalStrategy) and es_use_async:
elif isinstance(strategy, RetrievalStrategy) and es_use_async_client:
try:
async_strategy = _sync_to_async_strategy_map[type(strategy)](
**{k: v for k, v in vars(strategy).items() if not k.startswith("_")}
Expand Down Expand Up @@ -882,7 +889,7 @@ def __init__(
password=es_password,
params=es_params,
)
if not es_async_connection and es_use_async:
if not es_async_connection and es_use_async_client:
es_async_connection = create_elasticsearch_async_client(
url=es_url,
cloud_id=es_cloud_id,
Expand All @@ -904,9 +911,9 @@ def __init__(

self._async_store = None
self._async_embedding_service = None
# es_async_connection and es_use_async should
# es_async_connection and es_use_async_client should
# always have the same truth value at this point
if es_async_connection is not None and es_use_async:
if es_async_connection is not None and es_use_async_client:
async_embedding_service = None
if embedding:
async_embedding_service = AsyncEmbeddingServiceAdapter(embedding)
Expand Down
Loading

0 comments on commit a76f74d

Please sign in to comment.