Skip to content

Commit

Permalink
Updated stratifier behavior (#129)
Browse files: browse the repository at this point in the history
* Updated stratifier behavior

* PR feedback
  • Loading branch information
dogversioning authored Oct 17, 2024
1 parent a9cfe5d commit bd98255
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 196 deletions.
50 changes: 33 additions & 17 deletions src/handlers/dashboard/get_chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
logger.setLevel(log_level)


def _get_table_cols(dp_id: str, version: str | None = None) -> list:
def _get_table_cols(dp_id: str) -> list:
"""Returns the columns associated with a table.
Since running an athena query takes a decent amount of time due to queueing
Expand All @@ -26,11 +26,11 @@ def _get_table_cols(dp_id: str, version: str | None = None) -> list:
"""

s3_bucket_name = os.environ.get("BUCKET_NAME")
dp_name = dp_id.rsplit("__", 1)[0]
prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{dp_id.split('__')[0]}/{dp_name}"
study, name, version = dp_id.split("__")
prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{study}/{study}__{name}"
if version is None:
version = functions.get_latest_data_package_version(s3_bucket_name, prefix)
s3_key = f"{prefix}/{version}/{dp_name}__aggregate.csv"
s3_key = f"{prefix}/{version}/{study}__{name}__aggregate.csv"
s3_client = boto3.client("s3")
try:
s3_iter = s3_client.get_object(
Expand All @@ -48,10 +48,11 @@ def _build_query(query_params: dict, filters: list, path_params: dict) -> str:
columns = _get_table_cols(dp_id)
filter_str = filter_config.get_filter_string(filters)
if filter_str != "":
filter_str = f"AND {filter_str}"
filter_str = f"AND {filter_str} "
count_col = next(c for c in columns if c.startswith("cnt"))
columns.remove(count_col)
select_str = f"{query_params['column']}, sum({count_col}) as {count_col}"
strat_str = ""
group_str = f"{query_params['column']}"
# the 'if in' check is meant to handle the case where the selected column is also
# present in the filter logic and has already been removed
Expand All @@ -61,29 +62,33 @@ def _build_query(query_params: dict, filters: list, path_params: dict) -> str:
select_str = f"{query_params['stratifier']}, {select_str}"
group_str = f"{query_params['stratifier']}, {group_str}"
columns.remove(query_params["stratifier"])
strat_str = f'AND {query_params["stratifier"]} IS NOT NULL '
if len(columns) > 0:
coalesce_str = (
f"WHERE COALESCE (cast({' AS VARCHAR), cast('.join(columns)} AS VARCHAR)) "
"IS NOT NULL AND"
"IS NOT NULL AND "
)
else:
coalesce_str = "WHERE"
coalesce_str = "WHERE "
query_str = (
f"SELECT {select_str} " # nosec # noqa: S608
f"FROM \"{os.environ.get('GLUE_DB_NAME')}\".\"{dp_id}\" "
f"{coalesce_str} "
f"{query_params['column']} IS NOT NULL {filter_str} "
f"{coalesce_str}"
f"{query_params['column']} IS NOT NULL "
f"{filter_str}"
f"{strat_str}"
f"GROUP BY {group_str} "
)
if "stratifier" in query_params.keys():
query_str += f"ORDER BY {query_params['stratifier']}, {query_params['column']}"
else:
query_str += f"ORDER BY {query_params['column']}"
logging.debug(query_str)
return query_str
return query_str, count_col


def _format_payload(df: pandas.DataFrame, query_params: dict, filters: list) -> dict:
def _format_payload(
df: pandas.DataFrame, query_params: dict, filters: list, count_col: str
) -> dict:
"""Coerces query results into the return format defined by the dashboard"""
payload = {}
payload["column"] = query_params["column"]
Expand All @@ -92,13 +97,22 @@ def _format_payload(df: pandas.DataFrame, query_params: dict, filters: list) ->
payload["totalCount"] = int(df["cnt"].sum())
if "stratifier" in query_params.keys():
payload["stratifier"] = query_params["stratifier"]
counts = {}
for unique_val in df[query_params["column"]].unique():
column_mask = df[query_params["column"]] == unique_val
df_slice = df[column_mask]
df_slice = df_slice.drop(columns=[query_params["stratifier"], query_params["column"]])
counts[unique_val] = int(df_slice[count_col].sum())
payload["counts"] = counts
data = []
for unique_val in df[query_params["stratifier"]]:
df_slice = df[df[query_params["stratifier"]] == unique_val]
for unique_strat in df[query_params["stratifier"]].unique():
strat_mask = df[query_params["stratifier"]] == unique_strat
df_slice = df[strat_mask]
df_slice = df_slice.drop(columns=[query_params["stratifier"]])
rows = df_slice.values.tolist()
data.append({"stratifier": unique_val, "rows": rows})
data.append({"stratifier": unique_strat, "rows": rows})
payload["data"] = data

else:
rows = df.values.tolist()
payload["data"] = [{"rows": rows}]
Expand All @@ -112,17 +126,19 @@ def chart_data_handler(event, context):
del context
query_params = event["queryStringParameters"]
filters = event["multiValueQueryStringParameters"].get("filter", [])
if "filter" in query_params and filters == []:
filters = [query_params["filter"]]
path_params = event["pathParameters"]
boto3.setup_default_session(region_name="us-east-1")
try:
query = _build_query(query_params, filters, path_params)
query, count_col = _build_query(query_params, filters, path_params)
df = awswrangler.athena.read_sql_query(
query,
database=os.environ.get("GLUE_DB_NAME"),
s3_output=f"s3://{os.environ.get('BUCKET_NAME')}/awswrangler",
workgroup=os.environ.get("WORKGROUP_NAME"),
)
res = _format_payload(df, query_params, filters)
res = _format_payload(df, query_params, filters, count_col)
res = functions.http_response(200, res)
except errors.AggregatorS3Error:
# while the API is publicly accessible, we've been asked to not pass
Expand Down
2 changes: 2 additions & 0 deletions template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ Resources:
Required: true
- method.request.querystring.filters:
Required: false
- method.request.querystring.stratifier:
Required: false
Policies:
- S3CrudPolicy:
BucketName: !Ref AggregatorBucket
Expand Down
23 changes: 14 additions & 9 deletions tests/dashboard/test_get_chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ def mock_data_frame(filter_param):
[],
{"data_package_id": "test_study"},
f'SELECT gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE COALESCE (cast(race AS VARCHAR)) IS NOT NULL AND gender IS NOT NULL "
"WHERE COALESCE (cast(race AS VARCHAR)) IS NOT NULL AND gender IS NOT NULL "
"GROUP BY gender ORDER BY gender",
),
(
{"column": "gender", "stratifier": "race"},
[],
{"data_package_id": "test_study"},
f'SELECT race, gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE gender IS NOT NULL "
"WHERE gender IS NOT NULL "
"AND race IS NOT NULL "
"GROUP BY race, gender ORDER BY race, gender",
),
(
Expand All @@ -63,12 +64,13 @@ def mock_data_frame(filter_param):
f'SELECT race, gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE gender IS NOT NULL "
"AND gender LIKE 'female' "
"AND race IS NOT NULL "
"GROUP BY race, gender ORDER BY race, gender",
),
],
)
def test_build_query(query_params, filters, path_params, query_str):
query = get_chart_data._build_query(query_params, filters, path_params)
query, _ = get_chart_data._build_query(query_params, filters, path_params)
assert query == query_str


Expand Down Expand Up @@ -99,7 +101,7 @@ def test_build_query(query_params, filters, path_params, query_str):
)
def test_format_payload(query_params, filters, expected_payload):
df = mock_data_frame(filters)
payload = get_chart_data._format_payload(df, query_params, filters)
payload = get_chart_data._format_payload(df, query_params, filters, "cnt")
assert payload == expected_payload


Expand All @@ -113,11 +115,14 @@ def test_get_data_cols(mock_bucket):
@mock.patch(
"src.handlers.dashboard.get_chart_data._build_query",
lambda query_params, filters, path_params: (
"SELECT gender, sum(cnt) as cnt"
f'FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE COALESCE (race) IS NOT Null AND gender IS NOT Null "
"AND gender LIKE 'female' "
"GROUP BY gender",
(
"SELECT gender, sum(cnt) as cnt"
f'FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE COALESCE (race) IS NOT NULL AND gender IS NOT NULL "
"AND gender LIKE 'female' "
"GROUP BY gender",
"cnt",
)
),
)
@mock.patch(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_data/cube_response_filtered_stratified.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
"rowCount": 5,
"totalCount": 33839,
"stratifier": "race",
"counts": {
"female": 33839
},
"data": [
{
"stratifier": "",
Expand Down
Loading

0 comments on commit bd98255

Please sign in to comment.