Skip to content

Commit

Permalink
Merge pull request #1263 from shangyian/add-common-dims-gql
Browse files Browse the repository at this point in the history
GQL common dimensions query + dim attr output shape
  • Loading branch information
shangyian authored Jan 10, 2025
2 parents 3e6c91f + 23bdaf0 commit 14b427e
Show file tree
Hide file tree
Showing 19 changed files with 501 additions and 332 deletions.
3 changes: 1 addition & 2 deletions datajunction-clients/python/tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,13 +883,12 @@ def test_get_dimensions(self, client):
"type": "string",
"node_name": "foo.bar.dispatcher",
"node_display_name": "Dispatcher",
"is_primary_key": False,
"properties": [],
"path": [
"foo.bar.repair_order_details",
"foo.bar.repair_order",
],
"filter_only": False,
"is_hidden": False,
} in result

def test_create_namespace(self, client):
Expand Down
34 changes: 32 additions & 2 deletions datajunction-server/datajunction_server/api/graphql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from datajunction_server.api.graphql.engines import EngineInfo, list_engines
from datajunction_server.api.graphql.resolvers.nodes import find_nodes_by
from datajunction_server.api.graphql.scalars import Connection
from datajunction_server.api.graphql.scalars.node import Node
from datajunction_server.api.graphql.scalars.node import DimensionAttribute, Node
from datajunction_server.models.node import NodeCursor, NodeType
from datajunction_server.utils import get_session, get_settings
from datajunction_server.sql.dag import get_common_dimensions
from datajunction_server.utils import SEPARATOR, get_session, get_settings


async def get_context(
Expand Down Expand Up @@ -151,6 +152,35 @@ async def find_nodes_paginated(
),
)

@strawberry.field(
    description="Get common dimensions for one or more nodes",
)
async def common_dimensions(
    self,
    nodes: Annotated[
        Optional[List[str]],
        strawberry.argument(
            description="A list of nodes to find common dimensions for",
        ),
    ] = None,
    *,
    info: Info,
) -> list[DimensionAttribute]:
    """
    Return a list of common dimensions for a set of nodes.

    The node names are resolved with ``find_nodes_by``, the dimensions shared
    by the resolved nodes are computed with ``get_common_dimensions``, and each
    one is wrapped in a ``DimensionAttribute``. The ``attribute`` field is the
    last ``SEPARATOR``-delimited segment of the dimension's name.

    An empty list is returned when no nodes could be resolved.
    """
    # Keep the resolved nodes in their own name instead of rebinding the
    # ``nodes`` argument, which holds the requested node names.
    resolved_nodes = await find_nodes_by(info, nodes)  # type: ignore
    if not resolved_nodes:
        # get_common_dimensions indexes the first node unconditionally, so an
        # empty lookup would surface as an IndexError instead of "no results".
        return []

    dimensions = await get_common_dimensions(
        info.context["session"],  # type: ignore
        resolved_nodes,
    )
    return [
        DimensionAttribute(  # type: ignore
            name=dim.name,
            attribute=dim.name.split(SEPARATOR)[-1],
            properties=dim.properties,
            type=dim.type,
        )
        for dim in dimensions
    ]


schema = strawberry.Schema(query=Query)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy.orm import joinedload, selectinload
from strawberry.types import Info

from datajunction_server.api.graphql.scalars.node import NodeName
from datajunction_server.api.graphql.utils import extract_fields
from datajunction_server.database.dimensionlink import DimensionLink
from datajunction_server.database.node import Column, ColumnAttribute
Expand Down Expand Up @@ -54,6 +55,33 @@ async def find_nodes_by(
)


async def get_node_by_name(
    info: Info,
    name: str,
) -> DBNode | NodeName | None:
    """
    Look up a single node by its name.

    When the query selects nothing but ``name``, a ``NodeName`` built from the
    given name is returned immediately and no database lookup is made.
    Otherwise the requested fields determine which load options are handed to
    ``DBNode.get_by_name``.
    """
    session = info.context["session"]  # type: ignore
    requested = extract_fields(info)
    if "name" in requested and len(requested) == 1:
        return NodeName(name=name)  # type: ignore

    # The selection may be wrapped under ``nodes`` or under ``edges`` ->
    # ``node``; unwrap it to reach the node-level fields.
    if "nodes" in requested:
        node_fields = requested["nodes"]
    elif "edges" in requested:
        node_fields = requested["edges"]["node"]
    else:
        node_fields = requested

    load_options = load_node_options(node_fields)
    return await DBNode.get_by_name(
        session,
        name=name,
        options=load_options,
    )


def load_node_options(fields):
"""
Based on the GraphQL query input fields, builds a list of node load options.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import strawberry
from strawberry.scalars import JSON
from strawberry.types import Info

from datajunction_server.api.graphql.scalars import BigInt
from datajunction_server.api.graphql.scalars.availabilitystate import AvailabilityState
Expand Down Expand Up @@ -68,9 +69,26 @@ class DimensionAttribute: # pylint: disable=too-few-public-methods
"""

name: str
attribute: str
role: str
dimension_node: "NodeRevision"
attribute: str | None
role: str | None = None
properties: list[str]
type: str

_dimension_node: Optional["Node"] = None

@strawberry.field(description="The dimension node this attribute belongs to")
async def dimension_node(self, info: Info) -> "Node":
    """
    Resolve the dimension node for this attribute on demand.

    A node already stored on ``_dimension_node`` is returned as-is. Otherwise
    the node name is taken to be everything before the last "." in ``name``
    and the node is fetched by that name.
    """
    preloaded = self._dimension_node
    if not preloaded:
        # Imported here rather than at module level -- presumably to avoid a
        # circular import with the resolvers module; confirm before moving.
        # pylint: disable=import-outside-toplevel
        from datajunction_server.api.graphql.resolvers.nodes import get_node_by_name

        owner_name, *_ = self.name.rsplit(".", 1)
        return await get_node_by_name(info=info, name=owner_name)  # type: ignore
    return preloaded


@strawberry.type
Expand Down Expand Up @@ -174,7 +192,9 @@ def cube_dimensions(self, root: "DBNodeRevision") -> List[DimensionAttribute]:
),
attribute=element.name,
role=dimension_to_roles.get(element.name, ""),
dimension_node=node_revision,
_dimension_node=node_revision,
type=element.type,
properties=element.attribute_names(),
)
for element, node_revision in root.cube_elements_with_nodes()
if node_revision and node_revision.type != NodeType.METRIC
Expand Down
6 changes: 6 additions & 0 deletions datajunction-server/datajunction_server/database/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ def has_attribute(self, attribute_name: str) -> bool:
for attr in self.attributes # pylint: disable=not-an-iterable
)

def attribute_names(self) -> list[str]:
    """
    Collect the attribute type name of every attribute on this column,
    in the order the attributes are stored.
    """
    names: list[str] = []
    for column_attribute in self.attributes:
        names.append(column_attribute.attribute_type.name)
    return names

def has_attributes_besides(self, attribute_name: str) -> bool:
"""
Whether the column has any attribute besides the one specified.
Expand Down
3 changes: 1 addition & 2 deletions datajunction-server/datajunction_server/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,7 @@ class DimensionAttributeOutput(BaseModel):
name: str
node_name: Optional[str]
node_display_name: Optional[str]
is_primary_key: bool
is_hidden: bool
properties: list[str] | None
type: str
path: List[str]
filter_only: bool = False
Expand Down
53 changes: 30 additions & 23 deletions datajunction-server/datajunction_server/sql/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,7 @@ def _extract_roles_from_path(join_path) -> str:
name=f"{node_name}.{column_name}{_extract_roles_from_path(join_path)}",
node_name=node_name,
node_display_name=node_display_name,
is_primary_key=(
attribute_types is not None
and ColumnAttributes.PRIMARY_KEY.value in attribute_types
),
is_hidden=(
attribute_types is not None
and ColumnAttributes.HIDDEN.value in attribute_types
),
properties=attribute_types.split(",") if attribute_types else [],
type=str(column_type),
path=[
(path.replace("[", "").replace("]", "")[:-1])
Expand Down Expand Up @@ -580,20 +573,12 @@ async def get_filter_only_dimensions(
name=dim,
node_name=link.dimension.name,
node_display_name=link.dimension.current.display_name,
is_primary_key=(
dim.split(SEPARATOR)[-1]
in {
col.name for col in link.dimension.current.primary_key()
}
),
type=str(column_mapping[dim.split(SEPARATOR)[-1]].type),
path=[upstream.name],
filter_only=True,
is_hidden=(
column_mapping[dim.split(SEPARATOR)[-1]].has_attribute(
ColumnAttributes.HIDDEN.value,
)
),
properties=column_mapping[
dim.split(SEPARATOR)[-1]
].attribute_names(),
)
for dim in link.foreign_keys.values()
],
Expand Down Expand Up @@ -622,7 +607,18 @@ async def get_shared_dimensions(
metric_nodes: List[Node],
) -> List[DimensionAttributeOutput]:
"""
Return a list of dimensions that are common between the nodes.
Return a list of dimensions that are common between the metric nodes.
"""
parents = await get_metric_parents(session, metric_nodes)
return await get_common_dimensions(session, parents)


async def get_metric_parents(
session: AsyncSession,
metric_nodes: list[Node],
) -> list[Node]:
"""
Return a list of parent nodes of the metrics
"""
find_latest_node_revisions = [
and_(
Expand All @@ -645,9 +641,20 @@ async def get_shared_dimensions(
),
)
)
parents = list(set((await session.execute(statement)).scalars().all()))
common = await group_dimensions_by_name(session, parents[0])
for node in parents[1:]:
return list(set((await session.execute(statement)).scalars().all()))


async def get_common_dimensions(session: AsyncSession, nodes: list[Node]):
"""
Return a list of dimensions that are common between the nodes.
"""
metric_nodes = [node for node in nodes if node.type == NodeType.METRIC]
other_nodes = [node for node in nodes if node.type != NodeType.METRIC]
if metric_nodes:
nodes = list(set(other_nodes + await get_metric_parents(session, metric_nodes)))

common = await group_dimensions_by_name(session, nodes[0])
for node in nodes[1:]:
node_dimensions = await group_dimensions_by_name(session, node)

# Merge each set of dimensions based on the name and path
Expand Down
3 changes: 1 addition & 2 deletions datajunction-server/tests/api/dimension_links_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,15 +859,14 @@ async def test_measures_sql_with_reference_dimension_links(
assert dimensions_data == [
{
"filter_only": False,
"is_primary_key": False,
"name": "default.users.registration_country",
"node_display_name": "Users",
"node_name": "default.users",
"path": [
"default.events.user_registration_country",
],
"type": "string",
"is_hidden": False,
"properties": [],
},
]

Expand Down
Loading

0 comments on commit 14b427e

Please sign in to comment.