From 2d5699083e16dbe79933cced71d869844a1821c2 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 27 Nov 2024 18:29:10 -0800 Subject: [PATCH 1/3] When requesting metrics SQL, respect the desired engine --- .../datajunction_server/api/helpers.py | 22 ++++-- datajunction-server/tests/api/cubes_test.py | 75 +++++++++++++++++++ .../tests/construction/build_test.py | 1 - 3 files changed, 90 insertions(+), 8 deletions(-) diff --git a/datajunction-server/datajunction_server/api/helpers.py b/datajunction-server/datajunction_server/api/helpers.py index 7e9619b04..9bed2bd4c 100644 --- a/datajunction-server/datajunction_server/api/helpers.py +++ b/datajunction-server/datajunction_server/api/helpers.py @@ -646,9 +646,12 @@ async def build_sql_for_multiple_metrics( # pylint: disable=too-many-arguments, dimension_columns, materialized=True, ) + materialized_cube_catalog = None if cube: - catalog = await get_catalog_by_name(session, cube.availability.catalog) # type: ignore - available_engines = catalog.engines + available_engines + materialized_cube_catalog = await get_catalog_by_name( + session, + cube.availability.catalog, # type: ignore + ) # Check if selected engine is available engine = ( @@ -662,17 +665,22 @@ async def build_sql_for_multiple_metrics( # pylint: disable=too-many-arguments, f"Available engines include: {', '.join(engine.name for engine in available_engines)}", ) + # Do not use the materialized cube if the chosen engine is not available for + # the materialized cube's catalog + if ( + cube + and materialized_cube_catalog + and engine.name not in [eng.name for eng in materialized_cube_catalog.engines] + ): + cube = None + validate_orderby(orderby, metrics, dimensions) - if cube and cube.materializations and cube.availability and use_materialized: + if cube and cube.availability and use_materialized and materialized_cube_catalog: if access_control: # pragma: no cover access_control.add_request_by_node(cube) access_control.state = access.AccessControlState.INDIRECT access_control.raise_if_invalid_requests() - materialized_cube_catalog = await get_catalog_by_name( - session, - cube.availability.catalog, - ) query_ast = build_materialized_cube_node( # pylint: disable=E1121 metric_columns, dimension_columns, diff --git a/datajunction-server/tests/api/cubes_test.py b/datajunction-server/tests/api/cubes_test.py index 02edfa68d..3fb1f1802 100644 --- a/datajunction-server/tests/api/cubes_test.py +++ b/datajunction-server/tests/api/cubes_test.py @@ -1199,6 +1199,81 @@ async def test_druid_cube_agg_materialization( assert druid_materialization["schedule"] == "@daily" +@pytest.mark.asyncio +async def test_materialized_cube_sql( + client_with_repairs_cube: AsyncClient, +): + """ + Test generating SQL for a materialized cube with two cases: + (1) the materialized table's catalog is compatible with the desired engine. + (2) the materialized table's catalog is not compatible with the desired engine. + """ + await make_a_test_cube( + client_with_repairs_cube, + "default.mini_repairs_cube", + with_materialization=True, + ) + await client_with_repairs_cube.post( + "/data/default.mini_repairs_cube/availability/", + json={ + "catalog": "draft", + "schema_": "roads", + "table": "mini_repairs_cube", + "valid_through_ts": 1010129120, + }, + ) + # Ask for SQL with metrics, dimensions, filters, order by, and limit + response = await client_with_repairs_cube.get( + "/sql/", + params={ + "metrics": ["default.avg_repair_price", "default.num_repair_orders"], + "dimensions": ["default.hard_hat.state", "default.dispatcher.company_name"], + "limit": 100, + }, + ) + results = response.json() + assert "default_DOT_hard_hat AS (" in results["sql"] + + response = await client_with_repairs_cube.post( + "/data/default.mini_repairs_cube/availability/", + json={ + "catalog": "default", + "schema_": "roads", + "table": "mini_repairs_cube", + "valid_through_ts": 1010129120, + }, + ) + assert response.status_code == 200 + + response = await client_with_repairs_cube.get( + "/sql/", + params={ + "metrics": ["default.avg_repair_price", "default.num_repair_orders"], + "dimensions": ["default.hard_hat.state", "default.dispatcher.company_name"], + "limit": 100, + }, + ) + response = await client_with_repairs_cube.get( + "/sql/", + params={ + "metrics": ["default.avg_repair_price", "default.num_repair_orders"], + "dimensions": ["default.hard_hat.state", "default.dispatcher.company_name"], + "limit": 100, + }, + ) + results = response.json() + expected_sql = """ + SELECT + SUM(default_DOT_avg_repair_price), + SUM(default_DOT_num_repair_orders), + default_DOT_hard_hat_DOT_state, + default_DOT_dispatcher_DOT_company_name + FROM mini_repairs_cube + GROUP BY default_DOT_hard_hat_DOT_state, default_DOT_dispatcher_DOT_company_name + LIMIT 100""" + assert str(parse(results["sql"])) == str(parse(expected_sql)) + + @pytest.mark.asyncio async def test_cube_sql_generation_with_availability( client_with_repairs_cube: AsyncClient, diff --git a/datajunction-server/tests/construction/build_test.py b/datajunction-server/tests/construction/build_test.py index 499b475a9..9c4887f54 100644 --- a/datajunction-server/tests/construction/build_test.py +++ b/datajunction-server/tests/construction/build_test.py @@ -136,7 +136,6 @@ async def test_build_metric_with_required_dimensions( basic_DOT_source_DOT_comments_metrics.basic_DOT_num_comments_bnd FROM basic_DOT_source_DOT_comments_metrics """ - print("blahhhh", str(query)) assert str(parse(str(query))) == str(parse(str(expected))) From 8d1b618d5f8c832bb21cde39415169bbfcc3ab53 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 27 Nov 2024 20:27:51 -0800 Subject: [PATCH 2/3] Fix tests --- .../construction/build_v2.py | 4 +- datajunction-server/tests/api/cubes_test.py | 276 ------------------ datajunction-server/tests/api/helpers_test.py | 7 +- 3 files changed, 6 insertions(+), 281 deletions(-) diff --git a/datajunction-server/datajunction_server/construction/build_v2.py b/datajunction-server/datajunction_server/construction/build_v2.py index ac674d44b..828124831 100644 --- a/datajunction-server/datajunction_server/construction/build_v2.py +++ b/datajunction-server/datajunction_server/construction/build_v2.py @@ -877,6 +877,7 @@ async def build(self) -> ast.Query: Builds SQL for multiple metrics with the requested set of dimensions, filter expressions, order by, and limit clauses. """ + print("self.metric_nodes", self.metric_nodes) measures_queries = await self.build_measures_queries() # Join together the transforms on the shared dimensions and select all @@ -1020,6 +1021,7 @@ async def build_metric_agg(self, metric_node, parent_node): """ if self._access_control: self._access_control.add_request_by_node(metric_node) # type: ignore + print("metric_node", metric_node) metric_query_builder = await QueryBuilder.create(self.session, metric_node) if self._ignore_errors: metric_query_builder = ( # pragma: no cover @@ -1031,7 +1033,7 @@ async def build_metric_agg(self, metric_node, parent_node): .build() ) self.errors.extend(metric_query_builder.errors) - + print("metric_query", metric_query) metric_query.ctes[-1].select.projection[0].set_semantic_entity( # type: ignore f"{metric_node.name}.{amenable_name(metric_node.name)}", ) diff --git a/datajunction-server/tests/api/cubes_test.py b/datajunction-server/tests/api/cubes_test.py index 3fb1f1802..9e365f504 100644 --- a/datajunction-server/tests/api/cubes_test.py +++ b/datajunction-server/tests/api/cubes_test.py @@ -1274,282 +1274,6 @@ async def test_materialized_cube_sql( assert str(parse(results["sql"])) == str(parse(expected_sql)) -@pytest.mark.asyncio -async def test_cube_sql_generation_with_availability( - client_with_repairs_cube: AsyncClient, -): - """ - Test generating SQL for metrics + dimensions in a cube after adding a cube materialization - """ - await client_with_repairs_cube.post( - "/data/default.repairs_cube/availability/", - json={ - "catalog": "default", - "schema_": "roads", - "table": "repairs_cube", - "valid_through_ts": 1010129120, - }, - ) - - # Ask for SQL with metrics, dimensions, filters, order by, and limit - response = await client_with_repairs_cube.get( - "/sql/", - params={ - "metrics": [ - "default.discounted_orders_rate", - "default.num_repair_orders", - "default.avg_repair_price", - ], - "dimensions": [ - "default.hard_hat.country", - "default.hard_hat.postal_code", - "default.hard_hat.hire_date", - ], - "filters": ["default.hard_hat.country='NZ'"], - "orderby": ["default.hard_hat.country ASC"], - "limit": 100, - }, - ) - data = response.json() - assert data["columns"] == [ - { - "name": "default_DOT_hard_hat_DOT_country", - "type": "string", - "column": "country", - "node": "default.hard_hat", - "semantic_entity": "default.hard_hat.country", - "semantic_type": "dimension", - }, - { - "name": "default_DOT_hard_hat_DOT_postal_code", - "type": "string", - "column": "postal_code", - "node": "default.hard_hat", - "semantic_entity": "default.hard_hat.postal_code", - "semantic_type": "dimension", - }, - { - "name": "default_DOT_hard_hat_DOT_hire_date", - "type": "timestamp", - "column": "hire_date", - "node": "default.hard_hat", - "semantic_entity": "default.hard_hat.hire_date", - "semantic_type": "dimension", - }, - { - "name": "default_DOT_discounted_orders_rate", - "type": "double", - "column": "default_DOT_discounted_orders_rate", - "node": "default.discounted_orders_rate", - "semantic_entity": "default.discounted_orders_rate.default_DOT_discounted_orders_rate", - "semantic_type": "metric", - }, - { - "name": "default_DOT_num_repair_orders", - "type": "bigint", - "column": "default_DOT_num_repair_orders", - "node": "default.num_repair_orders", - "semantic_entity": "default.num_repair_orders.default_DOT_num_repair_orders", - "semantic_type": "metric", - }, - { - "name": "default_DOT_avg_repair_price", - "type": "double", - "column": "default_DOT_avg_repair_price", - "node": "default.avg_repair_price", - "semantic_entity": "default.avg_repair_price.default_DOT_avg_repair_price", - "semantic_type": "metric", - }, - ] - assert str(parse(data["sql"])) == str( - parse( - """ - WITH - default_DOT_repair_orders_fact AS ( - SELECT repair_orders.repair_order_id, - repair_orders.municipality_id, - repair_orders.hard_hat_id, - repair_orders.dispatcher_id, - repair_orders.order_date, - repair_orders.dispatched_date, - repair_orders.required_date, - repair_order_details.discount, - repair_order_details.price, - repair_order_details.quantity, - repair_order_details.repair_type_id, - repair_order_details.price * repair_order_details.quantity AS total_repair_cost, - repair_orders.dispatched_date - repair_orders.order_date AS time_to_dispatch, - repair_orders.dispatched_date - repair_orders.required_date AS dispatch_delay - FROM roads.repair_orders AS repair_orders JOIN roads.repair_order_details AS repair_order_details ON repair_orders.repair_order_id = repair_order_details.repair_order_id - ), - default_DOT_hard_hat AS ( - SELECT default_DOT_hard_hats.hard_hat_id, - default_DOT_hard_hats.last_name, - default_DOT_hard_hats.first_name, - default_DOT_hard_hats.title, - default_DOT_hard_hats.birth_date, - default_DOT_hard_hats.hire_date, - default_DOT_hard_hats.address, - default_DOT_hard_hats.city, - default_DOT_hard_hats.state, - default_DOT_hard_hats.postal_code, - default_DOT_hard_hats.country, - default_DOT_hard_hats.manager, - default_DOT_hard_hats.contractor_id - FROM roads.hard_hats AS default_DOT_hard_hats - WHERE default_DOT_hard_hats.country = 'NZ' - ), default_DOT_repair_orders_fact_metrics AS ( - SELECT - default_DOT_hard_hat.country default_DOT_hard_hat_DOT_country, - default_DOT_hard_hat.postal_code default_DOT_hard_hat_DOT_postal_code, - default_DOT_hard_hat.hire_date default_DOT_hard_hat_DOT_hire_date, - CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) AS default_DOT_discounted_orders_rate, - count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders, - avg(default_DOT_repair_orders_fact.price) default_DOT_avg_repair_price - FROM default_DOT_repair_orders_fact INNER JOIN default_DOT_hard_hat ON default_DOT_repair_orders_fact.hard_hat_id = default_DOT_hard_hat.hard_hat_id - GROUP BY default_DOT_hard_hat.country, default_DOT_hard_hat.postal_code, default_DOT_hard_hat.hire_date - ) - SELECT - default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_country, - default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_postal_code, - default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_hire_date, - default_DOT_repair_orders_fact_metrics.default_DOT_discounted_orders_rate, - default_DOT_repair_orders_fact_metrics.default_DOT_num_repair_orders, - default_DOT_repair_orders_fact_metrics.default_DOT_avg_repair_price - FROM default_DOT_repair_orders_fact_metrics - ORDER BY default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_country ASC - LIMIT 100 - """, - ), - ) # noqa: W191,E101 - - # Ask for SQL with only metrics and dimensions - response = await client_with_repairs_cube.get( - "/sql/", - params={ - "metrics": [ - "default.discounted_orders_rate", - "default.num_repair_orders", - "default.avg_repair_price", - ], - "dimensions": [ - "default.hard_hat.country", - "default.hard_hat.postal_code", - "default.hard_hat.hire_date", - ], - }, - ) - data = response.json() - assert data["columns"] == [ - { - "name": "default_DOT_hard_hat_DOT_country", - "type": "string", - "column": "country", - "node": "default.hard_hat", - "semantic_entity": "default.hard_hat.country", - "semantic_type": "dimension", - }, - { - "name": "default_DOT_hard_hat_DOT_postal_code", - "type": "string", - "column": "postal_code", - "node": "default.hard_hat", - "semantic_entity": "default.hard_hat.postal_code", - "semantic_type": "dimension", - }, - { - "name": "default_DOT_hard_hat_DOT_hire_date", - "type": "timestamp", - "column": "hire_date", - "node": "default.hard_hat", - "semantic_entity": "default.hard_hat.hire_date", - "semantic_type": "dimension", - }, - { - "name": "default_DOT_discounted_orders_rate", - "type": "double", - "column": "default_DOT_discounted_orders_rate", - "node": "default.discounted_orders_rate", - "semantic_entity": "default.discounted_orders_rate.default_DOT_discounted_orders_rate", - "semantic_type": "metric", - }, - { - "name": "default_DOT_num_repair_orders", - "type": "bigint", - "column": "default_DOT_num_repair_orders", - "node": "default.num_repair_orders", - "semantic_entity": "default.num_repair_orders.default_DOT_num_repair_orders", - "semantic_type": "metric", - }, - { - "name": "default_DOT_avg_repair_price", - "type": "double", - "column": "default_DOT_avg_repair_price", - "node": "default.avg_repair_price", - "semantic_entity": "default.avg_repair_price.default_DOT_avg_repair_price", - "semantic_type": "metric", - }, - ] - assert str(parse(data["sql"])) == str( - parse( - """ - WITH default_DOT_repair_orders_fact AS ( - SELECT - repair_orders.repair_order_id, - repair_orders.municipality_id, - repair_orders.hard_hat_id, - repair_orders.dispatcher_id, - repair_orders.order_date, - repair_orders.dispatched_date, - repair_orders.required_date, - repair_order_details.discount, - repair_order_details.price, - repair_order_details.quantity, - repair_order_details.repair_type_id, - repair_order_details.price * repair_order_details.quantity AS total_repair_cost, - repair_orders.dispatched_date - repair_orders.order_date AS time_to_dispatch, - repair_orders.dispatched_date - repair_orders.required_date AS dispatch_delay - FROM roads.repair_orders AS repair_orders JOIN roads.repair_order_details AS repair_order_details ON repair_orders.repair_order_id = repair_order_details.repair_order_id - ), - default_DOT_hard_hat AS ( - SELECT default_DOT_hard_hats.hard_hat_id, - default_DOT_hard_hats.last_name, - default_DOT_hard_hats.first_name, - default_DOT_hard_hats.title, - default_DOT_hard_hats.birth_date, - default_DOT_hard_hats.hire_date, - default_DOT_hard_hats.address, - default_DOT_hard_hats.city, - default_DOT_hard_hats.state, - default_DOT_hard_hats.postal_code, - default_DOT_hard_hats.country, - default_DOT_hard_hats.manager, - default_DOT_hard_hats.contractor_id - FROM roads.hard_hats AS default_DOT_hard_hats - ), - default_DOT_repair_orders_fact_metrics AS ( - SELECT default_DOT_hard_hat.country default_DOT_hard_hat_DOT_country, - default_DOT_hard_hat.postal_code default_DOT_hard_hat_DOT_postal_code, - default_DOT_hard_hat.hire_date default_DOT_hard_hat_DOT_hire_date, - CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) AS default_DOT_discounted_orders_rate, - count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders, - avg(default_DOT_repair_orders_fact.price) default_DOT_avg_repair_price - FROM default_DOT_repair_orders_fact INNER JOIN default_DOT_hard_hat ON default_DOT_repair_orders_fact.hard_hat_id = default_DOT_hard_hat.hard_hat_id - GROUP BY default_DOT_hard_hat.country, default_DOT_hard_hat.postal_code, default_DOT_hard_hat.hire_date - ) - SELECT - default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_country, - default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_postal_code, - default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_hire_date, - default_DOT_repair_orders_fact_metrics.default_DOT_discounted_orders_rate, - default_DOT_repair_orders_fact_metrics.default_DOT_num_repair_orders, - default_DOT_repair_orders_fact_metrics.default_DOT_avg_repair_price - FROM default_DOT_repair_orders_fact_metrics - """, - ), - ) # noqa: W191,E101 - - @pytest.mark.asyncio async def test_remove_dimension_link_invalidate_cube( client_with_repairs_cube: AsyncClient, # pylint: disable=redefined-outer-name diff --git a/datajunction-server/tests/api/helpers_test.py b/datajunction-server/tests/api/helpers_test.py index d0688c328..f556475cc 100644 --- a/datajunction-server/tests/api/helpers_test.py +++ b/datajunction-server/tests/api/helpers_test.py @@ -138,14 +138,13 @@ async def test_build_sql_for_multiple_metrics( Test building SQL for multiple metrics """ mock_build_materialized_cube_node.return_value = MagicMock() - mock_get_catalog_by_name.return_value = MagicMock( - engines=[MagicMock(), MagicMock()], - ) + mock_engines = [MagicMock(name="eng1"), MagicMock(name="eng2")] + mock_get_catalog_by_name.return_value = MagicMock(engines=mock_engines) mock_find_existing_cube.return_value = MagicMock( availability=MagicMock(catalog="cata-foo"), ) mock_get_by_name.return_value = MagicMock( - current=MagicMock(catalog=MagicMock(engines=["eng1", "eng2"])), + current=MagicMock(catalog=MagicMock(name="cata-foo", engines=mock_engines)), ) mock_metric_columns = [MagicMock(name="col1"), MagicMock(name="col2")] mock_metric_nodes = ["mnode1", "mnode2"] From 05adbe663c5362f0d5df1b758c2b4331967510b3 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 27 Nov 2024 20:31:14 -0800 Subject: [PATCH 3/3] Cleanup --- .../datajunction_server/construction/build_v2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/datajunction-server/datajunction_server/construction/build_v2.py b/datajunction-server/datajunction_server/construction/build_v2.py index 828124831..43fb6d1a7 100644 --- a/datajunction-server/datajunction_server/construction/build_v2.py +++ b/datajunction-server/datajunction_server/construction/build_v2.py @@ -877,7 +877,6 @@ async def build(self) -> ast.Query: Builds SQL for multiple metrics with the requested set of dimensions, filter expressions, order by, and limit clauses. """ - print("self.metric_nodes", self.metric_nodes) measures_queries = await self.build_measures_queries() # Join together the transforms on the shared dimensions and select all @@ -1021,7 +1020,6 @@ async def build_metric_agg(self, metric_node, parent_node): """ if self._access_control: self._access_control.add_request_by_node(metric_node) # type: ignore - print("metric_node", metric_node) metric_query_builder = await QueryBuilder.create(self.session, metric_node) if self._ignore_errors: metric_query_builder = ( # pragma: no cover @@ -1033,7 +1031,6 @@ async def build_metric_agg(self, metric_node, parent_node): .build() ) self.errors.extend(metric_query_builder.errors) - print("metric_query", metric_query) metric_query.ctes[-1].select.projection[0].set_semantic_entity( # type: ignore f"{metric_node.name}.{amenable_name(metric_node.name)}", )