@register("/search/queries/vector/", methods=["GET"])
def vector_search_query(
    environment_id,
    keywords,
    filters=None,
):
    """Search queries through the vector store, scoped to one environment.

    Args:
        environment_id: Environment the caller is searching in. Permission
            for it is verified before any search is issued.
        keywords: Search text forwarded to the vector store logic.
        filters: Optional list of ``[name, value]`` filters. It is never
            modified; an ``environment_id`` filter is added to a copy.

    Returns:
        Whatever ``logic.vector_store.search_query`` returns for the
        keywords and the environment-scoped filters.
    """
    # Imported lazily, as in the original endpoint.
    from logic import vector_store as vs_logic

    verify_environment_permission([environment_id])

    # Build a new list instead of appending in place: appending to a shared
    # default list (or to the caller's list) would accumulate one
    # environment filter per call and leak filters between requests.
    scoped_filters = [*(filters or []), ["environment_id", environment_id]]

    return vs_logic.search_query(keywords, scoped_filters)
def construct_query_search_by_query_cell_ids(ids, filters, limit):
    """Build an Elasticsearch request body that selects documents by id.

    Args:
        ids: Ids to match with a ``terms`` clause on the ``id`` field.
        filters: Optional filters; when truthy they are passed to
            ``match_filters`` together with ``FILTERS_TO_AND``.
        limit: Value used for the request ``size``.

    Returns:
        A dict with ``query`` and ``size``. When ``ids`` is empty the
        request is a ``match_all`` with ``size`` 0, i.e. it asks for no hits.
    """
    if not ids:
        return {"query": {"match_all": {}}, "size": 0}

    bool_clause = {"must": [{"terms": {"id": ids}}]}

    # Only consult match_filters when there is something to filter on.
    parsed_filters = (
        match_filters(filters, and_filter_names=FILTERS_TO_AND)
        if filters
        else None
    )
    if parsed_filters:
        bool_clause["filter"] = parsed_filters["filter"]

    return {"query": {"bool": bool_clause}, "size": limit}
def search_query(keywords, filters=None, limit=DEFAULT_QUERY_SEARCH_LIMIT):
    """Find query cells related to a natural-language search text.

    The vector store supplies a ranked list of query cell ids; the
    ``query_cells`` Elasticsearch index is then queried for those ids
    (restricted by ``filters``) and the hits are returned in the vector
    store's order.

    Args:
        keywords: Search text handed to the vector store.
        filters: Filters forwarded to the Elasticsearch query builder.
        limit: Used both as the vector store ``k`` and as the Elasticsearch
            result size.

    Returns:
        A dict with ``count`` (number of returned documents) and
        ``results`` (the documents, ordered by vector store ranking).
    """
    ranked_matches = get_vector_store().search_query(keywords, k=limit)
    ranked_ids = [match[0] for match in ranked_matches]
    if not ranked_ids:
        return {"count": 0, "results": []}

    es_hits = get_matching_objects(
        construct_query_search_by_query_cell_ids(
            ids=ranked_ids, filters=filters, limit=limit
        ),
        ES_CONFIG["query_cells"]["index_name"],
    )

    # Elasticsearch does not preserve the vector ranking, so re-apply it;
    # ids that Elasticsearch did not return are dropped.
    hit_by_id = {hit["id"]: hit for hit in es_hits}
    ordered_hits = []
    for cell_id in ranked_ids:
        if cell_id in hit_by_id:
            ordered_hits.append(hit_by_id[cell_id])

    return {"count": len(ordered_hits), "results": ordered_hits}
= ({ placeholder={placeholder} autoFocus /> - {searchType === SearchType.Table && - isAIFeatureEnabled('table_vector_search') && ( -
- - Natural Language Search - - updateUseVectorSearch(val)} - /> -
- )} + {showVectorSearch && ( +
+ + Natural Language Search + + updateUseVectorSearch(val)} + /> +
+ )}
); }; diff --git a/querybook/webapp/config.d.ts b/querybook/webapp/config.d.ts index 9961d9686..c7d14e141 100644 --- a/querybook/webapp/config.d.ts +++ b/querybook/webapp/config.d.ts @@ -103,6 +103,10 @@ declare module 'config/querybook_public_config.yaml' { enabled: boolean; }; + query_vector_search: { + enabled: boolean; + }; + sql_complete: { enabled: boolean; }; diff --git a/querybook/webapp/lib/public-config.ts b/querybook/webapp/lib/public-config.ts index 59b26f61e..0c3a6803e 100644 --- a/querybook/webapp/lib/public-config.ts +++ b/querybook/webapp/lib/public-config.ts @@ -6,6 +6,7 @@ export const isAIFeatureEnabled = ( | 'query_generation' | 'query_auto_fix' | 'table_vector_search' + | 'query_vector_search' | 'sql_complete' ): boolean => { const aiAssistantConfig = PublicConfig.ai_assistant; diff --git a/querybook/webapp/redux/search/action.ts b/querybook/webapp/redux/search/action.ts index b1fa700ea..fa8a24cec 100644 --- a/querybook/webapp/redux/search/action.ts +++ b/querybook/webapp/redux/search/action.ts @@ -125,10 +125,20 @@ export function performSearch(): ThunkResult> { }>; switch (searchType) { case SearchType.Query: - searchRequest = SearchQueryResource.search({ - ...searchParams, - environment_id: state.environment.currentEnvironmentId, - }); + if (useVectorSearch) { + searchRequest = SearchQueryResource.vectorSearch({ + environment_id: + state.environment.currentEnvironmentId, + keywords: searchString, + filters: searchParams.filters, + }); + } else { + searchRequest = SearchQueryResource.search({ + ...searchParams, + environment_id: + state.environment.currentEnvironmentId, + }); + } break; case SearchType.DataDoc: searchRequest = SearchDataDocResource.search({ diff --git a/querybook/webapp/resource/search.ts b/querybook/webapp/resource/search.ts index c2e93884f..d68cd9d65 100644 --- a/querybook/webapp/resource/search.ts +++ b/querybook/webapp/resource/search.ts @@ -81,6 +81,12 @@ export const SearchQueryResource = { results: IQueryPreview[]; 
count: number; }>('/search/queries/', params as unknown as Record), + + vectorSearch: (params: ISearchQueryParams) => + ds.fetch<{ + results: IQueryPreview[]; + count: number; + }>('/search/queries/vector/', { ...params }), }; export const SearchDataDocResource = {