Skip to content

Commit

Permalink
Merge pull request #1264 from shangyian/measures-auto-preagg
Browse files Browse the repository at this point in the history
Only generate preaggregated measures SQL where possible
  • Loading branch information
shangyian authored Jan 14, 2025
2 parents 14b427e + b7ca851 commit 51bf1d2
Showing 4 changed files with 190 additions and 3 deletions.
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/api/sql.py
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ async def get_measures_sql_for_cube_v2(
include_all_columns=include_all_columns,
sql_transpilation_library=settings.sql_transpilation_library,
use_materialized=use_materialized,
preaggregate=preaggregate,
preagg_requested=preaggregate,
)
return measures_query

25 changes: 23 additions & 2 deletions datajunction-server/datajunction_server/construction/build_v2.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@
from datajunction_server.models.node_type import NodeType
from datajunction_server.models.sql import GeneratedSQL
from datajunction_server.naming import amenable_name, from_amenable_name
from datajunction_server.sql.decompose import Measure
from datajunction_server.sql.decompose import Aggregability, Measure
from datajunction_server.sql.parsing.ast import CompileContext
from datajunction_server.sql.parsing.backends.antlr4 import ast, cached_parse, parse
from datajunction_server.utils import SEPARATOR, refresh_if_needed
@@ -103,7 +103,7 @@ async def get_measures_query( # pylint: disable=too-many-locals
include_all_columns: bool = False,
sql_transpilation_library: Optional[str] = None,
use_materialized: bool = True,
preaggregate: bool = False,
preagg_requested: bool = False,
) -> List[GeneratedSQL]:
"""
Builds the measures SQL for a set of metrics with dimensions and filters.
@@ -166,6 +166,18 @@ async def get_measures_query( # pylint: disable=too-many-locals
measures_queries = []
for parent_node, children in common_parents.items(): # type: ignore
children = sorted(children, key=lambda x: metrics_sorting_order.get(x.name, 0))

# Decide whether to pre-aggregate to the requested dimensions, which makes subsequent
# queries more efficient. This is only done when pre-aggregation was requested and every
# metric in this group has at least one measure, all of which are fully aggregable.
preaggregate = preagg_requested and all(
len(metrics2measures[metric.name][0]) > 0
and all(
measure.rule.type == Aggregability.FULL
for measure in metrics2measures[metric.name][0]
)
for metric in children
)

measure_columns, dimensional_columns = [], []
query_builder = await QueryBuilder.create(
session,
@@ -245,6 +257,15 @@ async def get_measures_query( # pylint: disable=too-many-locals
for dep in dependencies
if dep.type == NodeType.SOURCE
],
grain=(
[
col.name
for col in columns_metadata
if col.semantic_type == SemanticType.DIMENSION
]
if preaggregate
else [pk_col.name for pk_col in parent_node.current.primary_key()]
),
errors=query_builder.errors,
),
)
1 change: 1 addition & 0 deletions datajunction-server/datajunction_server/models/sql.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ class GeneratedSQL(BaseModel):
sql: str
sql_transpilation_library: Optional[str] = None
columns: Optional[List[ColumnMetadata]] = None # pragma: no-cover
grain: list[str] | None = None
dialect: Optional[Dialect] = None
upstream_tables: Optional[List[str]] = None
errors: Optional[List[DJQueryBuildError]] = None
165 changes: 165 additions & 0 deletions datajunction-server/tests/api/sql_v2_test.py
Original file line number Diff line number Diff line change
@@ -1176,6 +1176,171 @@ async def test_measures_sql_errors(
]


@pytest.mark.asyncio
async def test_measures_sql_preagg_incompatible(  # pylint: disable=too-many-arguments
    module__client_with_roads: AsyncClient,
    duckdb_conn: duckdb.DuckDBPyConnection,  # pylint: disable=c-extension-no-member
):
    """
    Verify that ``GET /sql/measures/v2`` only pre-aggregates when it can.

    With ``preaggregate`` requested, a metric set that includes a
    ``COUNT(DISTINCT ...)`` metric is expected to yield non-aggregated measures SQL
    with an empty grain, while a set of compatible metrics is expected to yield SQL
    grouped by the requested dimension, with that dimension reported as the grain.
    """
    client = module__client_with_roads

    async def first_measures_query(metrics):
        """Request measures SQL with ``preaggregate`` on and return the first query."""
        measures_response = await client.get(
            "/sql/measures/v2",
            params={
                "metrics": metrics,
                "dimensions": [
                    "default.dispatcher.company_name",
                ],
                "filters": [],
                "preaggregate": True,
            },
        )
        return measures_response.json()[0]

    await fix_dimension_links(client)

    # Register a metric whose measure is a distinct count.
    distinct_count_metric = {
        "description": "A preagg incompatible metric",
        "query": "SELECT COUNT(DISTINCT hard_hat_id) FROM default.repair_orders_fact",
        "mode": "published",
        "name": "default.number_of_hard_hats",
    }
    await client.post("/nodes/metric", json=distinct_count_metric)

    # Case 1: one of the metrics is the distinct count, so the expected SQL has no
    # GROUP BY and the reported grain is empty.
    raw_query = await first_measures_query(
        ["default.avg_repair_price", "default.number_of_hard_hats"],
    )
    assert raw_query["grain"] == []
    raw_expected_sql = """
WITH default_DOT_repair_orders_fact AS (
SELECT
repair_orders.repair_order_id,
repair_orders.municipality_id,
repair_orders.hard_hat_id,
repair_orders.dispatcher_id,
repair_orders.order_date,
repair_orders.dispatched_date,
repair_orders.required_date,
repair_order_details.discount,
repair_order_details.price,
repair_order_details.quantity,
repair_order_details.repair_type_id,
repair_order_details.price * repair_order_details.quantity AS total_repair_cost,
repair_orders.dispatched_date - repair_orders.order_date AS time_to_dispatch,
repair_orders.dispatched_date - repair_orders.required_date AS dispatch_delay
FROM roads.repair_orders AS repair_orders
JOIN roads.repair_order_details AS repair_order_details
ON repair_orders.repair_order_id = repair_order_details.repair_order_id
),
default_DOT_dispatcher AS (
SELECT
default_DOT_dispatchers.dispatcher_id,
default_DOT_dispatchers.company_name,
default_DOT_dispatchers.phone
FROM roads.dispatchers AS default_DOT_dispatchers
)
SELECT
default_DOT_repair_orders_fact.hard_hat_id default_DOT_repair_orders_fact_DOT_hard_hat_id,
default_DOT_repair_orders_fact.price default_DOT_repair_orders_fact_DOT_price,
default_DOT_dispatcher.company_name default_DOT_dispatcher_DOT_company_name
FROM default_DOT_repair_orders_fact
LEFT JOIN default_DOT_dispatcher
ON default_DOT_repair_orders_fact.dispatcher_id = default_DOT_dispatcher.dispatcher_id
"""
    # Compare via the parser so formatting differences do not matter.
    normalized_expected = str(parse(str(raw_expected_sql)))
    normalized_actual = str(parse(str(raw_query["sql"])))
    assert normalized_expected == normalized_actual

    raw_expected_rows = {
        (3, 67253.0, "Pothole Pete"),
        (6, 65114.0, "Asphalts R Us"),
        (3, 87858.0, "Asphalts R Us"),
        (1, 92366.0, "Pothole Pete"),
        (1, 63708.0, "Federal Roads Group"),
        (4, 73600.0, "Pothole Pete"),
        (2, 48919.0, "Federal Roads Group"),
        (5, 21083.0, "Federal Roads Group"),
        (5, 47857.0, "Asphalts R Us"),
        (4, 63918.0, "Asphalts R Us"),
        (3, 74555.0, "Asphalts R Us"),
        (2, 29684.0, "Federal Roads Group"),
        (4, 51594.0, "Pothole Pete"),
        (5, 87289.0, "Pothole Pete"),
        (7, 53374.0, "Federal Roads Group"),
        (8, 76463.0, "Asphalts R Us"),
        (8, 54901.0, "Federal Roads Group"),
        (6, 68745.0, "Asphalts R Us"),
        (5, 66808.0, "Asphalts R Us"),
        (4, 27222.0, "Federal Roads Group"),
        (6, 62928.0, "Pothole Pete"),
        (1, 18497.0, "Pothole Pete"),
        (1, 44120.0, "Pothole Pete"),
        (5, 97916.0, "Federal Roads Group"),
        (9, 70418.0, "Federal Roads Group"),
    }
    raw_rows = duckdb_conn.sql(raw_query["sql"])
    assert set(raw_rows.fetchall()) == raw_expected_rows

    # Case 2: both metrics are compatible, so the expected SQL is grouped by the
    # requested dimension, which is also the reported grain.
    preagg_query = await first_measures_query(
        ["default.avg_repair_price", "default.num_repair_orders"],
    )
    assert preagg_query["grain"] == ["default_DOT_dispatcher_DOT_company_name"]
    preagg_expected_sql = """
WITH default_DOT_repair_orders_fact AS (
SELECT
repair_orders.repair_order_id,
repair_orders.municipality_id,
repair_orders.hard_hat_id,
repair_orders.dispatcher_id,
repair_orders.order_date,
repair_orders.dispatched_date,
repair_orders.required_date,
repair_order_details.discount,
repair_order_details.price,
repair_order_details.quantity,
repair_order_details.repair_type_id,
repair_order_details.price * repair_order_details.quantity AS total_repair_cost,
repair_orders.dispatched_date - repair_orders.order_date AS time_to_dispatch,
repair_orders.dispatched_date - repair_orders.required_date AS dispatch_delay
FROM roads.repair_orders AS repair_orders
JOIN roads.repair_order_details AS repair_order_details
ON repair_orders.repair_order_id = repair_order_details.repair_order_id
),
default_DOT_dispatcher AS (
SELECT
default_DOT_dispatchers.dispatcher_id,
default_DOT_dispatchers.company_name,
default_DOT_dispatchers.phone
FROM roads.dispatchers AS default_DOT_dispatchers
),
default_DOT_repair_orders_fact_built AS (
SELECT
default_DOT_repair_orders_fact.repair_order_id,
default_DOT_repair_orders_fact.price,
default_DOT_dispatcher.company_name default_DOT_dispatcher_DOT_company_name
FROM default_DOT_repair_orders_fact LEFT JOIN default_DOT_dispatcher ON default_DOT_repair_orders_fact.dispatcher_id = default_DOT_dispatcher.dispatcher_id
)
SELECT
default_DOT_repair_orders_fact_built.default_DOT_dispatcher_DOT_company_name,
COUNT(1) AS count,
SUM(price) AS price_sum_78a5eb43,
COUNT(repair_order_id) AS repair_order_id_count_0b7dfba0
FROM default_DOT_repair_orders_fact_built
GROUP BY default_DOT_repair_orders_fact_built.default_DOT_dispatcher_DOT_company_name
"""
    normalized_expected = str(parse(str(preagg_expected_sql)))
    normalized_actual = str(parse(str(preagg_query["sql"])))
    assert normalized_expected == normalized_actual

    preagg_expected_rows = {
        ("Pothole Pete", 8, 497647.0, 8),
        ("Asphalts R Us", 8, 551318.0, 8),
        ("Federal Roads Group", 9, 467225.0, 9),
    }
    preagg_rows = duckdb_conn.sql(preagg_query["sql"])
    assert set(preagg_rows.fetchall()) == preagg_expected_rows


@pytest.mark.asyncio
async def test_metrics_sql_different_parents(
module__client_with_roads: AsyncClient,

0 comments on commit 51bf1d2

Please sign in to comment.