Skip to content

Commit

Permalink
test ElasticsearchStore with async parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
giacbrd committed Oct 7, 2024
1 parent 79b5441 commit be88280
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
6 changes: 4 additions & 2 deletions libs/elasticsearch/langchain_elasticsearch/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ def __init__(
elif isinstance(strategy, RetrievalStrategy) and es_use_async:
try:
async_strategy = _sync_to_async_strategy_map[type(strategy)](
**vars(strategy)
**{k: v for k, v in vars(strategy).items() if not k.startswith("_")}
)
except KeyError:
raise TypeError(
Expand All @@ -860,7 +860,9 @@ def __init__(
)
elif isinstance(strategy, AsyncRetrievalStrategy):
try:
strategy = _async_to_sync_strategy_map[type(strategy)](**vars(strategy))
strategy = _async_to_sync_strategy_map[type(strategy)](
**{k: v for k, v in vars(strategy).items() if not k.startswith("_")}
)
except KeyError:
raise TypeError(
f"Cannot find a proper sync counterpart "
Expand Down
39 changes: 38 additions & 1 deletion libs/elasticsearch/tests/unit_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
AsyncDenseVectorScriptScoreStrategy,
AsyncDenseVectorStrategy,
AsyncSparseVectorStrategy,
AsyncVectorStore,
)
from langchain_core.documents import Document

from langchain_elasticsearch.embeddings import Embeddings, EmbeddingServiceAdapter, AsyncEmbeddingServiceAdapter
from langchain_elasticsearch.embeddings import (
AsyncEmbeddingServiceAdapter,
Embeddings,
EmbeddingServiceAdapter,
)
from langchain_elasticsearch.vectorstores import (
ApproxRetrievalStrategy,
BM25RetrievalStrategy,
Expand Down Expand Up @@ -265,6 +270,38 @@ def test_agent_header(self, store: ElasticsearchStore) -> None:
is not None
), f"The string '{agent}' does not match the expected pattern."

def test_initialization(
self, hybrid_store: ElasticsearchStore, embeddings: Embeddings
) -> None:
assert isinstance(
hybrid_store._async_embedding_service, AsyncEmbeddingServiceAdapter
)
client = Elasticsearch(hosts=["http://dummy:9200"]) # never connected to
async_client = AsyncElasticsearch(
hosts=["http://dummy:9200"]
) # never connected to
store = ElasticsearchStore(
index_name="test_index",
es_connection=client,
es_async_connection=async_client,
strategy=SparseVectorStrategy(model_id="model_1"),
)
assert isinstance(store._async_store, AsyncVectorStore)
assert isinstance(
store._async_store.retrieval_strategy, AsyncSparseVectorStrategy
) # type: ignore
assert store._async_store.retrieval_strategy.model_id == "model_1" # type: ignore
store = ElasticsearchStore(
index_name="test_index",
es_connection=client,
es_use_async=True,
strategy=AsyncBM25Strategy(k1=20),
)
assert store._async_store is None
assert store._async_embedding_service is None
assert isinstance(store._store.retrieval_strategy, BM25Strategy)
assert store._store.retrieval_strategy.k1 == 20

def test_similarity_search(
self, store: ElasticsearchStore, static_hits: List[Dict]
) -> None:
Expand Down

0 comments on commit be88280

Please sign in to comment.