Skip to content

Commit

Permalink
Merge pull request #1231 from shangyian/respect-engine
Browse files Browse the repository at this point in the history
When requesting metrics SQL, respect the desired engine
  • Loading branch information
shangyian authored Dec 2, 2024
2 parents 690a289 + 05adbe6 commit 65bb2fe
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 261 deletions.
22 changes: 15 additions & 7 deletions datajunction-server/datajunction_server/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,12 @@ async def build_sql_for_multiple_metrics( # pylint: disable=too-many-arguments,
dimension_columns,
materialized=True,
)
materialized_cube_catalog = None
if cube:
catalog = await get_catalog_by_name(session, cube.availability.catalog) # type: ignore
available_engines = catalog.engines + available_engines
materialized_cube_catalog = await get_catalog_by_name(
session,
cube.availability.catalog, # type: ignore
)

# Check if selected engine is available
engine = (
Expand All @@ -662,17 +665,22 @@ async def build_sql_for_multiple_metrics( # pylint: disable=too-many-arguments,
f"Available engines include: {', '.join(engine.name for engine in available_engines)}",
)

# Do not use the materialized cube if the chosen engine is not available for
# the materialized cube's catalog
if (
cube
and materialized_cube_catalog
and engine.name not in [eng.name for eng in materialized_cube_catalog.engines]
):
cube = None

validate_orderby(orderby, metrics, dimensions)

if cube and cube.materializations and cube.availability and use_materialized:
if cube and cube.availability and use_materialized and materialized_cube_catalog:
if access_control: # pragma: no cover
access_control.add_request_by_node(cube)
access_control.state = access.AccessControlState.INDIRECT
access_control.raise_if_invalid_requests()
materialized_cube_catalog = await get_catalog_by_name(
session,
cube.availability.catalog,
)
query_ast = build_materialized_cube_node( # pylint: disable=E1121
metric_columns,
dimension_columns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,6 @@ async def build_metric_agg(self, metric_node, parent_node):
.build()
)
self.errors.extend(metric_query_builder.errors)

metric_query.ctes[-1].select.projection[0].set_semantic_entity( # type: ignore
f"{metric_node.name}.{amenable_name(metric_node.name)}",
)
Expand Down
295 changes: 47 additions & 248 deletions datajunction-server/tests/api/cubes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,279 +1200,78 @@ async def test_druid_cube_agg_materialization(


@pytest.mark.asyncio
async def test_cube_sql_generation_with_availability(
async def test_materialized_cube_sql(
client_with_repairs_cube: AsyncClient,
):
"""
Test generating SQL for metrics + dimensions in a cube after adding a cube materialization
Test generating SQL for a materialized cube with two cases:
(1) the materialized table's catalog is compatible with the desired engine.
(2) the materialized table's catalog is not compatible with the desired engine.
"""
await make_a_test_cube(
client_with_repairs_cube,
"default.mini_repairs_cube",
with_materialization=True,
)
await client_with_repairs_cube.post(
"/data/default.repairs_cube/availability/",
"/data/default.mini_repairs_cube/availability/",
json={
"catalog": "default",
"catalog": "draft",
"schema_": "roads",
"table": "repairs_cube",
"table": "mini_repairs_cube",
"valid_through_ts": 1010129120,
},
)

# Ask for SQL with metrics, dimensions, filters, order by, and limit
response = await client_with_repairs_cube.get(
"/sql/",
params={
"metrics": [
"default.discounted_orders_rate",
"default.num_repair_orders",
"default.avg_repair_price",
],
"dimensions": [
"default.hard_hat.country",
"default.hard_hat.postal_code",
"default.hard_hat.hire_date",
],
"filters": ["default.hard_hat.country='NZ'"],
"orderby": ["default.hard_hat.country ASC"],
"metrics": ["default.avg_repair_price", "default.num_repair_orders"],
"dimensions": ["default.hard_hat.state", "default.dispatcher.company_name"],
"limit": 100,
},
)
data = response.json()
assert data["columns"] == [
{
"name": "default_DOT_hard_hat_DOT_country",
"type": "string",
"column": "country",
"node": "default.hard_hat",
"semantic_entity": "default.hard_hat.country",
"semantic_type": "dimension",
},
{
"name": "default_DOT_hard_hat_DOT_postal_code",
"type": "string",
"column": "postal_code",
"node": "default.hard_hat",
"semantic_entity": "default.hard_hat.postal_code",
"semantic_type": "dimension",
},
{
"name": "default_DOT_hard_hat_DOT_hire_date",
"type": "timestamp",
"column": "hire_date",
"node": "default.hard_hat",
"semantic_entity": "default.hard_hat.hire_date",
"semantic_type": "dimension",
},
{
"name": "default_DOT_discounted_orders_rate",
"type": "double",
"column": "default_DOT_discounted_orders_rate",
"node": "default.discounted_orders_rate",
"semantic_entity": "default.discounted_orders_rate.default_DOT_discounted_orders_rate",
"semantic_type": "metric",
},
{
"name": "default_DOT_num_repair_orders",
"type": "bigint",
"column": "default_DOT_num_repair_orders",
"node": "default.num_repair_orders",
"semantic_entity": "default.num_repair_orders.default_DOT_num_repair_orders",
"semantic_type": "metric",
},
{
"name": "default_DOT_avg_repair_price",
"type": "double",
"column": "default_DOT_avg_repair_price",
"node": "default.avg_repair_price",
"semantic_entity": "default.avg_repair_price.default_DOT_avg_repair_price",
"semantic_type": "metric",
results = response.json()
assert "default_DOT_hard_hat AS (" in results["sql"]

response = await client_with_repairs_cube.post(
"/data/default.mini_repairs_cube/availability/",
json={
"catalog": "default",
"schema_": "roads",
"table": "mini_repairs_cube",
"valid_through_ts": 1010129120,
},
]
assert str(parse(data["sql"])) == str(
parse(
"""
WITH
default_DOT_repair_orders_fact AS (
SELECT repair_orders.repair_order_id,
repair_orders.municipality_id,
repair_orders.hard_hat_id,
repair_orders.dispatcher_id,
repair_orders.order_date,
repair_orders.dispatched_date,
repair_orders.required_date,
repair_order_details.discount,
repair_order_details.price,
repair_order_details.quantity,
repair_order_details.repair_type_id,
repair_order_details.price * repair_order_details.quantity AS total_repair_cost,
repair_orders.dispatched_date - repair_orders.order_date AS time_to_dispatch,
repair_orders.dispatched_date - repair_orders.required_date AS dispatch_delay
FROM roads.repair_orders AS repair_orders JOIN roads.repair_order_details AS repair_order_details ON repair_orders.repair_order_id = repair_order_details.repair_order_id
),
default_DOT_hard_hat AS (
SELECT default_DOT_hard_hats.hard_hat_id,
default_DOT_hard_hats.last_name,
default_DOT_hard_hats.first_name,
default_DOT_hard_hats.title,
default_DOT_hard_hats.birth_date,
default_DOT_hard_hats.hire_date,
default_DOT_hard_hats.address,
default_DOT_hard_hats.city,
default_DOT_hard_hats.state,
default_DOT_hard_hats.postal_code,
default_DOT_hard_hats.country,
default_DOT_hard_hats.manager,
default_DOT_hard_hats.contractor_id
FROM roads.hard_hats AS default_DOT_hard_hats
WHERE default_DOT_hard_hats.country = 'NZ'
), default_DOT_repair_orders_fact_metrics AS (
SELECT
default_DOT_hard_hat.country default_DOT_hard_hat_DOT_country,
default_DOT_hard_hat.postal_code default_DOT_hard_hat_DOT_postal_code,
default_DOT_hard_hat.hire_date default_DOT_hard_hat_DOT_hire_date,
CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) AS default_DOT_discounted_orders_rate,
count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders,
avg(default_DOT_repair_orders_fact.price) default_DOT_avg_repair_price
FROM default_DOT_repair_orders_fact INNER JOIN default_DOT_hard_hat ON default_DOT_repair_orders_fact.hard_hat_id = default_DOT_hard_hat.hard_hat_id
GROUP BY default_DOT_hard_hat.country, default_DOT_hard_hat.postal_code, default_DOT_hard_hat.hire_date
)
SELECT
default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_country,
default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_postal_code,
default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_hire_date,
default_DOT_repair_orders_fact_metrics.default_DOT_discounted_orders_rate,
default_DOT_repair_orders_fact_metrics.default_DOT_num_repair_orders,
default_DOT_repair_orders_fact_metrics.default_DOT_avg_repair_price
FROM default_DOT_repair_orders_fact_metrics
ORDER BY default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_country ASC
LIMIT 100
""",
),
) # noqa: W191,E101
)
assert response.status_code == 200

# Ask for SQL with only metrics and dimensions
response = await client_with_repairs_cube.get(
"/sql/",
params={
"metrics": [
"default.discounted_orders_rate",
"default.num_repair_orders",
"default.avg_repair_price",
],
"dimensions": [
"default.hard_hat.country",
"default.hard_hat.postal_code",
"default.hard_hat.hire_date",
],
"metrics": ["default.avg_repair_price", "default.num_repair_orders"],
"dimensions": ["default.hard_hat.state", "default.dispatcher.company_name"],
"limit": 100,
},
)
data = response.json()
assert data["columns"] == [
{
"name": "default_DOT_hard_hat_DOT_country",
"type": "string",
"column": "country",
"node": "default.hard_hat",
"semantic_entity": "default.hard_hat.country",
"semantic_type": "dimension",
},
{
"name": "default_DOT_hard_hat_DOT_postal_code",
"type": "string",
"column": "postal_code",
"node": "default.hard_hat",
"semantic_entity": "default.hard_hat.postal_code",
"semantic_type": "dimension",
},
{
"name": "default_DOT_hard_hat_DOT_hire_date",
"type": "timestamp",
"column": "hire_date",
"node": "default.hard_hat",
"semantic_entity": "default.hard_hat.hire_date",
"semantic_type": "dimension",
},
{
"name": "default_DOT_discounted_orders_rate",
"type": "double",
"column": "default_DOT_discounted_orders_rate",
"node": "default.discounted_orders_rate",
"semantic_entity": "default.discounted_orders_rate.default_DOT_discounted_orders_rate",
"semantic_type": "metric",
},
{
"name": "default_DOT_num_repair_orders",
"type": "bigint",
"column": "default_DOT_num_repair_orders",
"node": "default.num_repair_orders",
"semantic_entity": "default.num_repair_orders.default_DOT_num_repair_orders",
"semantic_type": "metric",
},
{
"name": "default_DOT_avg_repair_price",
"type": "double",
"column": "default_DOT_avg_repair_price",
"node": "default.avg_repair_price",
"semantic_entity": "default.avg_repair_price.default_DOT_avg_repair_price",
"semantic_type": "metric",
response = await client_with_repairs_cube.get(
"/sql/",
params={
"metrics": ["default.avg_repair_price", "default.num_repair_orders"],
"dimensions": ["default.hard_hat.state", "default.dispatcher.company_name"],
"limit": 100,
},
]
assert str(parse(data["sql"])) == str(
parse(
"""
WITH default_DOT_repair_orders_fact AS (
SELECT
repair_orders.repair_order_id,
repair_orders.municipality_id,
repair_orders.hard_hat_id,
repair_orders.dispatcher_id,
repair_orders.order_date,
repair_orders.dispatched_date,
repair_orders.required_date,
repair_order_details.discount,
repair_order_details.price,
repair_order_details.quantity,
repair_order_details.repair_type_id,
repair_order_details.price * repair_order_details.quantity AS total_repair_cost,
repair_orders.dispatched_date - repair_orders.order_date AS time_to_dispatch,
repair_orders.dispatched_date - repair_orders.required_date AS dispatch_delay
FROM roads.repair_orders AS repair_orders JOIN roads.repair_order_details AS repair_order_details ON repair_orders.repair_order_id = repair_order_details.repair_order_id
),
default_DOT_hard_hat AS (
SELECT default_DOT_hard_hats.hard_hat_id,
default_DOT_hard_hats.last_name,
default_DOT_hard_hats.first_name,
default_DOT_hard_hats.title,
default_DOT_hard_hats.birth_date,
default_DOT_hard_hats.hire_date,
default_DOT_hard_hats.address,
default_DOT_hard_hats.city,
default_DOT_hard_hats.state,
default_DOT_hard_hats.postal_code,
default_DOT_hard_hats.country,
default_DOT_hard_hats.manager,
default_DOT_hard_hats.contractor_id
FROM roads.hard_hats AS default_DOT_hard_hats
),
default_DOT_repair_orders_fact_metrics AS (
SELECT default_DOT_hard_hat.country default_DOT_hard_hat_DOT_country,
default_DOT_hard_hat.postal_code default_DOT_hard_hat_DOT_postal_code,
default_DOT_hard_hat.hire_date default_DOT_hard_hat_DOT_hire_date,
CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) AS default_DOT_discounted_orders_rate,
count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders,
avg(default_DOT_repair_orders_fact.price) default_DOT_avg_repair_price
FROM default_DOT_repair_orders_fact INNER JOIN default_DOT_hard_hat ON default_DOT_repair_orders_fact.hard_hat_id = default_DOT_hard_hat.hard_hat_id
GROUP BY default_DOT_hard_hat.country, default_DOT_hard_hat.postal_code, default_DOT_hard_hat.hire_date
)
SELECT
default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_country,
default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_postal_code,
default_DOT_repair_orders_fact_metrics.default_DOT_hard_hat_DOT_hire_date,
default_DOT_repair_orders_fact_metrics.default_DOT_discounted_orders_rate,
default_DOT_repair_orders_fact_metrics.default_DOT_num_repair_orders,
default_DOT_repair_orders_fact_metrics.default_DOT_avg_repair_price
FROM default_DOT_repair_orders_fact_metrics
""",
),
) # noqa: W191,E101
)
results = response.json()
expected_sql = """
SELECT
SUM(default_DOT_avg_repair_price),
SUM(default_DOT_num_repair_orders),
default_DOT_hard_hat_DOT_state,
default_DOT_dispatcher_DOT_company_name
FROM mini_repairs_cube
GROUP BY default_DOT_hard_hat_DOT_state, default_DOT_dispatcher_DOT_company_name
LIMIT 100"""
assert str(parse(results["sql"])) == str(parse(expected_sql))


@pytest.mark.asyncio
Expand Down
7 changes: 3 additions & 4 deletions datajunction-server/tests/api/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,13 @@ async def test_build_sql_for_multiple_metrics(
Test building SQL for multiple metrics
"""
mock_build_materialized_cube_node.return_value = MagicMock()
mock_get_catalog_by_name.return_value = MagicMock(
engines=[MagicMock(), MagicMock()],
)
mock_engines = [MagicMock(name="eng1"), MagicMock(name="eng2")]
mock_get_catalog_by_name.return_value = MagicMock(engines=mock_engines)
mock_find_existing_cube.return_value = MagicMock(
availability=MagicMock(catalog="cata-foo"),
)
mock_get_by_name.return_value = MagicMock(
current=MagicMock(catalog=MagicMock(engines=["eng1", "eng2"])),
current=MagicMock(catalog=MagicMock(name="cata-foo", engines=mock_engines)),
)
mock_metric_columns = [MagicMock(name="col1"), MagicMock(name="col2")]
mock_metric_nodes = ["mnode1", "mnode2"]
Expand Down
Loading

0 comments on commit 65bb2fe

Please sign in to comment.