diff --git a/src/handlers/dashboard/get_chart_data.py b/src/handlers/dashboard/get_chart_data.py index c98ea07..a68ba2b 100644 --- a/src/handlers/dashboard/get_chart_data.py +++ b/src/handlers/dashboard/get_chart_data.py @@ -17,7 +17,7 @@ logger.setLevel(log_level) -def _get_table_cols(dp_id: str, version: str | None = None) -> list: +def _get_table_cols(dp_id: str) -> list: """Returns the columns associated with a table. Since running an athena query takes a decent amount of time due to queueing @@ -26,11 +26,11 @@ def _get_table_cols(dp_id: str, version: str | None = None) -> list: """ s3_bucket_name = os.environ.get("BUCKET_NAME") - dp_name = dp_id.rsplit("__", 1)[0] - prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{dp_id.split('__')[0]}/{dp_name}" + study, name, version = dp_id.split("__") + prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{study}/{study}__{name}" if version is None: version = functions.get_latest_data_package_version(s3_bucket_name, prefix) - s3_key = f"{prefix}/{version}/{dp_name}__aggregate.csv" + s3_key = f"{prefix}/{version}/{study}__{name}__aggregate.csv" s3_client = boto3.client("s3") try: s3_iter = s3_client.get_object( @@ -98,14 +98,16 @@ def _format_payload( if "stratifier" in query_params.keys(): payload["stratifier"] = query_params["stratifier"] counts = {} - for unique_val in df[query_params["column"]]: - df_slice = df[df[query_params["column"]] == unique_val] + for unique_val in df[query_params["column"]].unique(): + column_mask = df[query_params["column"]] == unique_val + df_slice = df[column_mask] df_slice = df_slice.drop(columns=[query_params["stratifier"], query_params["column"]]) counts[unique_val] = int(df_slice[count_col].sum()) payload["counts"] = counts data = [] for unique_strat in df[query_params["stratifier"]].unique(): - df_slice = df[df[query_params["stratifier"]] == unique_strat] + strat_mask = df[query_params["stratifier"]] == unique_strat + df_slice = df[strat_mask] df_slice = df_slice.drop(columns=[query_params["stratifier"]]) rows = df_slice.values.tolist() data.append({"stratifier": unique_strat, "rows": rows}) diff --git a/tests/dashboard/test_get_chart_data.py b/tests/dashboard/test_get_chart_data.py index c69314c..0f6c9c5 100644 --- a/tests/dashboard/test_get_chart_data.py +++ b/tests/dashboard/test_get_chart_data.py @@ -102,8 +102,6 @@ def test_build_query(query_params, filters, path_params, query_str): def test_format_payload(query_params, filters, expected_payload): df = mock_data_frame(filters) payload = get_chart_data._format_payload(df, query_params, filters, "cnt") - print(query_params) - print(payload) assert payload == expected_payload @@ -120,7 +118,7 @@ def test_get_data_cols(mock_bucket): ( "SELECT gender, sum(cnt) as cnt" f'FROM "{TEST_GLUE_DB}"."test_study" ' - "WHERE COALESCE (race) IS NOT Null AND gender IS NOT Null " + "WHERE COALESCE (race) IS NOT NULL AND gender IS NOT NULL " "AND gender LIKE 'female' " "GROUP BY gender", "cnt",