Skip to content

Commit

Permalink
Merge pull request #1241 from shangyian/gql-measures
Browse files Browse the repository at this point in the history
  • Loading branch information
shangyian authored Dec 7, 2024
2 parents 6a81fc7 + b29b502 commit d7ce96f
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import strawberry

from datajunction_server.models.node import MetricDirection as MetricDirection_
from datajunction_server.sql.decompose import Aggregability as Aggregability_
from datajunction_server.sql.decompose import AggregationRule as AggregationRule_
from datajunction_server.sql.decompose import Measure as Measure_

MetricDirection = strawberry.enum(MetricDirection_)
Aggregability = strawberry.enum(Aggregability_)


@strawberry.type
Expand All @@ -21,6 +25,27 @@ class Unit: # pylint: disable=too-few-public-methods
abbreviation: Optional[str]


@strawberry.experimental.pydantic.type(model=AggregationRule_, all_fields=True)
class AggregationRule: # pylint: disable=missing-class-docstring,too-few-public-methods
...


@strawberry.experimental.pydantic.type(model=Measure_, all_fields=True)
class Measure: # pylint: disable=missing-class-docstring,too-few-public-methods
...


@strawberry.type
class ExtractedMeasures: # pylint: disable=too-few-public-methods
"""
extracted measures from metric
"""

measures: list[Measure]
derived_query: str
derived_expression: str


@strawberry.type
class MetricMetadata: # pylint: disable=too-few-public-methods
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from datajunction_server.api.graphql.scalars.materialization import (
MaterializationConfig,
)
from datajunction_server.api.graphql.scalars.metricmetadata import MetricMetadata
from datajunction_server.api.graphql.scalars.metricmetadata import (
ExtractedMeasures,
MetricMetadata,
)
from datajunction_server.api.graphql.scalars.user import User
from datajunction_server.database.dimensionlink import (
JoinCardinality as JoinCardinality_,
Expand All @@ -23,6 +26,7 @@
from datajunction_server.models.node import NodeMode as NodeMode_
from datajunction_server.models.node import NodeStatus as NodeStatus_
from datajunction_server.models.node import NodeType as NodeType_
from datajunction_server.sql.decompose import extractor

NodeType = strawberry.enum(NodeType_)
NodeStatus = strawberry.enum(NodeStatus_)
Expand Down Expand Up @@ -109,14 +113,28 @@ class NodeRevision:
availability: Optional[AvailabilityState] = None
materializations: Optional[List[MaterializationConfig]] = None

# Only source nodes will have this
# Only source nodes will have these fields
schema_: Optional[str]
table: Optional[str]

# Only metrics will have this field
# Only metrics will have these fields
metric_metadata: Optional[MetricMetadata] = None
required_dimensions: Optional[List[Column]] = None

@strawberry.field
def extracted_measures(self, root: "DBNodeRevision") -> ExtractedMeasures | None:
"""
A list of measures for a metric node
"""
if root.type != NodeType.METRIC:
return None
measures, derived_ast = extractor.extract_measures(root.query)
return ExtractedMeasures( # type: ignore
measures=measures,
derived_query=str(derived_ast),
derived_expression=str(derived_ast.select.projection[0]),
)

# Only cubes will have these fields
@strawberry.field
def cube_metrics(self, root: "DBNodeRevision") -> List["NodeRevision"]:
Expand Down
6 changes: 4 additions & 2 deletions datajunction-server/datajunction_server/models/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class Metric(BaseModel):
incompatible_druid_functions: List[str]

measures: List[Measure]
derived_sql: str
derived_query: str
derived_expression: str

@classmethod
def parse_node(cls, node: Node, dims: List[DimensionAttributeOutput]) -> "Metric":
Expand Down Expand Up @@ -76,7 +77,8 @@ def parse_node(cls, node: Node, dims: List[DimensionAttributeOutput]) -> "Metric
required_dimensions=[dim.name for dim in node.current.required_dimensions],
incompatible_druid_functions=incompatible_druid_functions,
measures=measures,
derived_sql=str(derived_sql).strip(),
derived_query=str(derived_sql).strip(),
derived_expression=str(derived_sql.select.projection[0]).strip(),
)


Expand Down
65 changes: 65 additions & 0 deletions datajunction-server/tests/api/graphql/find_nodes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,11 @@ async def test_find_transform(
cubeDimensions {
name
}
extractedMeasures {
measures {
name
}
}
}
}
}
Expand All @@ -560,6 +565,7 @@ async def test_find_transform(
"name": "default.repair_order_details",
},
],
"extractedMeasures": None,
},
"name": "default.repair_orders_fact",
"type": "TRANSFORM",
Expand Down Expand Up @@ -593,6 +599,18 @@ async def test_find_metric(
requiredDimensions {
name
}
extractedMeasures {
measures {
name
expression
aggregation
rule {
type
}
}
derivedQuery
derivedExpression
}
}
}
}
Expand All @@ -614,6 +632,53 @@ async def test_find_metric(
},
],
"requiredDimensions": [],
"extractedMeasures": {
"measures": [
{
"aggregation": "SUM",
"expression": "rm.completed_repairs",
"name": "rm.completed_repairs_sum_0",
"rule": {
"type": "FULL",
},
},
{
"aggregation": "SUM",
"expression": "rm.total_repairs_dispatched",
"name": "rm.total_repairs_dispatched_sum_1",
"rule": {
"type": "FULL",
},
},
{
"aggregation": "SUM",
"expression": "rm.total_amount_in_region",
"name": "rm.total_amount_in_region_sum_2",
"rule": {
"type": "FULL",
},
},
{
"aggregation": "SUM",
"expression": "na.total_amount_nationwide",
"name": "na.total_amount_nationwide_sum_3",
"rule": {
"type": "FULL",
},
},
],
"derivedQuery": "SELECT (SUM(rm.completed_repairs_sum_0) * 1.0 / "
"SUM(rm.total_repairs_dispatched_sum_1)) * "
"(SUM(rm.total_amount_in_region_sum_2) * 1.0 / "
"SUM(na.total_amount_nationwide_sum_3)) * 100 \n"
" FROM default.regional_level_agg rm CROSS JOIN "
"default.national_level_agg na\n"
"\n",
"derivedExpression": "(SUM(rm.completed_repairs_sum_0) * 1.0 / "
"SUM(rm.total_repairs_dispatched_sum_1)) * "
"(SUM(rm.total_amount_in_region_sum_2) * 1.0 / "
"SUM(na.total_amount_nationwide_sum_3)) * 100",
},
},
"name": "default.regional_repair_efficiency",
"type": "METRIC",
Expand Down
27 changes: 27 additions & 0 deletions datajunction-server/tests/api/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,33 @@ async def test_read_metrics(module__client_with_roads: AsyncClient) -> None:
)
data = response.json()
assert data["incompatible_druid_functions"] == ["IF"]
assert data["measures"] == [
{
"aggregation": "SUM",
"expression": "if(discount > 0.0, 1, 0)",
"name": "discount_sum_0",
"rule": {
"level": None,
"type": "full",
},
},
{
"aggregation": "COUNT",
"expression": "*",
"name": "count_1",
"rule": {
"level": None,
"type": "full",
},
},
]
assert data["derived_query"] == (
"SELECT CAST(sum(discount_sum_0) AS DOUBLE) / SUM(count_1) AS "
"default_DOT_discounted_orders_rate \n FROM default.repair_orders_fact"
)
assert data["derived_expression"] == (
"CAST(sum(discount_sum_0) AS DOUBLE) / SUM(count_1) AS default_DOT_discounted_orders_rate"
)


@pytest.mark.asyncio
Expand Down

0 comments on commit d7ce96f

Please sign in to comment.