Skip to content

Commit

Permalink
add support for rrf=False back
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Oct 3, 2024
1 parent 203f678 commit b9bf952
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 118 deletions.
76 changes: 41 additions & 35 deletions elasticsearch/helpers/vectorstore/_async/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,44 +285,50 @@ def _hybrid(
# RRF is used to even the score from the knn query and text query
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"retriever": {
"rrf": {
"retrievers": [
standard_query = {
"query": {
"bool": {
"must": [
{
"standard": {
"query": {
"bool": {
"must": [
{
"match": {
self.text_field: {
"query": query,
}
}
}
],
"filter": filter,
}
},
},
},
{"knn": knn},
"match": {
self.text_field: {
"query": query,
}
}
}
],
**rrf_options,
},
},
"filter": filter,
}
}
}

if self.rrf is False:
query_body = {
"knn": knn,
**standard_query,
}
else:
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"retriever": {
"rrf": {
"retrievers": [
{"standard": standard_query},
{"knn": knn},
],
**rrf_options,
},
},
}
return query_body

def needs_inference(self) -> bool:
Expand Down
76 changes: 41 additions & 35 deletions elasticsearch/helpers/vectorstore/_sync/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,44 +285,50 @@ def _hybrid(
# RRF is used to even the score from the knn query and text query
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"retriever": {
"rrf": {
"retrievers": [
standard_query = {
"query": {
"bool": {
"must": [
{
"standard": {
"query": {
"bool": {
"must": [
{
"match": {
self.text_field: {
"query": query,
}
}
}
],
"filter": filter,
}
},
},
},
{"knn": knn},
"match": {
self.text_field: {
"query": query,
}
}
}
],
**rrf_options,
},
},
"filter": filter,
}
}
}

if self.rrf is False:
query_body = {
"knn": knn,
**standard_query,
}
else:
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"retriever": {
"rrf": {
"retrievers": [
{"standard": standard_query},
{"knn": knn},
],
**rrf_options,
},
},
}
return query_body

def needs_inference(self) -> bool:
Expand Down
99 changes: 51 additions & 48 deletions test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,64 +415,67 @@ def assert_query(
query: Optional[str],
expected_rrf: Union[dict, bool],
) -> dict:
cmp_query_body = {
"retriever": {
"rrf": {
"retrievers": [
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{
"match": {
"text_field": {"query": "foo"}
}
}
],
}
},
},
},
{
"knn": {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
},
],
standard_query = {
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text_field": {"query": "foo"}}}],
}
}
}
knn_query = {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
}

if isinstance(expected_rrf, dict):
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
if expected_rrf is not False:
cmp_query_body = {
"retriever": {
"rrf": {
"retrievers": [
{"standard": standard_query},
{"knn": knn_query},
],
}
}
}
if isinstance(expected_rrf, dict):
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
else:
cmp_query_body = {
"knn": knn_query,
**standard_query,
}

assert query_body == cmp_query_body

return query_body

# 1. check query_body is okay
rrf_test_cases: List[Union[dict, bool]] = [
True,
False,
{"rank_constant": 1, "rank_window_size": 5},
]
if es_version(sync_client) >= (8, 14):
rrf_test_cases: List[Union[dict, bool]] = [
True,
False,
{"rank_constant": 1, "rank_window_size": 5},
]
else:
# for 8.13.x and older there is no retriever query, so we can only
# run hybrid searches with rrf=False
rrf_test_cases: List[Union[dict, bool]] = [False]
for rrf_test_case in rrf_test_cases:
store = VectorStore(
index=index,
Expand Down

0 comments on commit b9bf952

Please sign in to comment.