Skip to content

Commit

Permalink
Merge pull request #1227 from DataJunction/issue/1226
Browse files Browse the repository at this point in the history
Better input validation on get_sql_for_metrics().
  • Loading branch information
agorajek authored Nov 25, 2024
2 parents e564894 + f520042 commit 047dd3d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
16 changes: 16 additions & 0 deletions datajunction-server/datajunction_server/api/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import logging
from collections import OrderedDict
from http import HTTPStatus
from typing import List, Optional, Tuple, cast

from fastapi import BackgroundTasks, Depends, Query
Expand All @@ -18,6 +19,7 @@
from datajunction_server.database import Engine, Node
from datajunction_server.database.queryrequest import QueryBuildType, QueryRequest
from datajunction_server.database.user import User
from datajunction_server.errors import DJInvalidInputException
from datajunction_server.internal.access.authentication.http import SecureAPIRouter
from datajunction_server.internal.access.authorization import validate_access
from datajunction_server.internal.engines import get_engine
Expand Down Expand Up @@ -378,6 +380,20 @@ async def get_sql_for_metrics( # pylint: disable=too-many-locals
base_verb=access.ResourceRequestVerb.READ,
)

# make sure all metrics exist and have correct node type
nodes = [
await Node.get_by_name(session, node, raise_if_not_exists=True)
for node in metrics
]
non_metric_nodes = [node for node in nodes if node and node.type != NodeType.METRIC]

if non_metric_nodes:
raise DJInvalidInputException(
message="All nodes must be of metric type, but some are not: "
f"{', '.join([f'{n.name} ({n.type})' for n in non_metric_nodes])} .",
http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
)

if query_request := await QueryRequest.get_query_request(
session,
nodes=metrics,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def to_versioned_query_request( # pylint: disable=too-many-locals
) -> Dict[str, List[str]]:
"""
Prepare for searching in saved query requests by appending version numbers to all nodes
being worked with, from the nodes we're retrieving the queries of to the
being worked with.
"""
nodes_objs = [
await Node.get_by_name(
Expand Down
16 changes: 16 additions & 0 deletions datajunction-server/tests/api/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2897,6 +2897,22 @@ async def test_get_sql_for_metrics_failures(module__client_with_examples: AsyncC
)
assert response.status_code == 200

# Getting sql for metric with non-metric node
response = await module__client_with_examples.get(
"/sql/",
params={
"metrics": ["default.repair_orders"],
"dimensions": [],
"filters": [],
},
)
assert response.status_code == 422
assert response.json() == {
"message": "All nodes must be of metric type, but some are not: default.repair_orders (source) .",
"errors": [],
"warnings": [],
}


@pytest.mark.asyncio
async def test_get_sql_for_metrics_no_access(module__client_with_examples: AsyncClient):
Expand Down

0 comments on commit 047dd3d

Please sign in to comment.