From eb2fff4589932fa876391b88fd36c75bc625d750 Mon Sep 17 00:00:00 2001 From: matt garber Date: Tue, 27 Feb 2024 15:29:33 -0500 Subject: [PATCH] CSV file download API (#114) * CSV file download API * PR feedback, removed unused pylint excepts * remove parquet, looping for >1000 results * PR feedback * addded column-descriptions header --- pyproject.toml | 3 +- .../migrations/migration.002.column_types.py | 86 ++++++++++ src/handlers/dashboard/get_chart_data.py | 8 +- src/handlers/dashboard/get_csv.py | 122 ++++++++++++++ src/handlers/dashboard/get_data_packages.py | 6 +- src/handlers/dashboard/get_metadata.py | 6 +- src/handlers/dashboard/get_study_periods.py | 6 +- src/handlers/shared/awswrangler_functions.py | 2 +- src/handlers/shared/decorators.py | 2 +- src/handlers/shared/enums.py | 19 ++- src/handlers/shared/functions.py | 144 ++++++++++++----- .../site_upload/api_gateway_authorizer.py | 4 +- src/handlers/site_upload/cache_api.py | 6 +- src/handlers/site_upload/fetch_upload_url.py | 6 +- src/handlers/site_upload/powerset_merge.py | 145 +++++++++++------ src/handlers/site_upload/process_upload.py | 86 +++++----- src/handlers/site_upload/study_period.py | 55 ++++--- template.yaml | 77 +++++++++ tests/conftest.py | 80 ++++++---- tests/dashboard/test_get_chart_data.py | 2 +- tests/dashboard/test_get_csv.py | 150 ++++++++++++++++++ tests/dashboard/test_get_metadata.py | 47 +++--- tests/dashboard/test_get_study_periods.py | 3 +- tests/dashboard/test_get_subscriptions.py | 2 +- tests/{utils.py => mock_utils.py} | 22 ++- .../test_api_gateway_authorizer.py | 4 +- tests/site_upload/test_cache_api.py | 2 +- tests/site_upload/test_fetch_upload_url.py | 2 +- tests/site_upload/test_powerset_merge.py | 88 +++++----- tests/site_upload/test_process_upload.py | 3 +- tests/site_upload/test_study_period.py | 2 +- 31 files changed, 905 insertions(+), 285 deletions(-) create mode 100644 scripts/migrations/migration.002.column_types.py create mode 100644 src/handlers/dashboard/get_csv.py create mode 100644 tests/dashboard/test_get_csv.py rename tests/{utils.py => mock_utils.py} (88%) diff --git a/pyproject.toml b/pyproject.toml index 7a67175..ae49140 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,8 @@ dependencies= [ "arrow >=1.2.3", "awswrangler >=3.5, <4", "boto3", - "pandas >=2, <3" + "pandas >=2, <3", + "rich", ] authors = [ { name="Matt Garber", email="matthew.garber@childrens.harvard.edu" }, diff --git a/scripts/migrations/migration.002.column_types.py b/scripts/migrations/migration.002.column_types.py new file mode 100644 index 0000000..6651213 --- /dev/null +++ b/scripts/migrations/migration.002.column_types.py @@ -0,0 +1,86 @@ +""" Adds a new metadata type, column_types """ + +import argparse +import io +import json + +import boto3 +import pandas +from rich import progress + + +def get_csv_column_datatypes(dtypes): + """helper for generating column type for dashboard API""" + column_dict = {} + for column in dtypes.index: + if column.endswith("year"): + column_dict[column] = "year" + elif column.endswith("month"): + column_dict[column] = "month" + elif column.endswith("week"): + column_dict[column] = "week" + elif column.endswith("day") or str(dtypes[column]) == "datetime64": + column_dict[column] = "day" + elif "cnt" in column or str(dtypes[column]) in ( + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + ): + column_dict[column] = "integer" + elif str(dtypes[column]) in ("Float32", "Float64"): + column_dict[column] = "float" + elif str(dtypes[column]) == "boolean": + column_dict[column] = "float" + else: + column_dict[column] = "string" + return column_dict + + +def _put_s3_data(key: str, bucket_name: str, client, data: dict) -> None: + """Convenience class for writing a dict to S3""" + b_data = io.BytesIO(json.dumps(data).encode()) + client.upload_fileobj(Bucket=bucket_name, Key=key, Fileobj=b_data) + + +def create_column_type_metadata(bucket: str): + """creates a new metadata dict for column types. + + By design, this will replaces an existing column type dict if one already exists. + """ + client = boto3.client("s3") + res = client.list_objects_v2(Bucket=bucket, Prefix="aggregates/") + contents = res["Contents"] + output = {} + for resource in progress.track(contents): + dirs = resource["Key"].split("/") + study = dirs[1] + subscription = dirs[2].split("__")[1] + version = dirs[3] + bytes_buffer = io.BytesIO() + client.download_fileobj( + Bucket=bucket, Key=resource["Key"], Fileobj=bytes_buffer + ) + df = pandas.read_parquet(bytes_buffer) + type_dict = get_csv_column_datatypes(df.dtypes) + filename = f"{resource['Key'].split('/')[-1].split('.')[0]}.csv" + output.setdefault(study, {}) + output[study].setdefault(subscription, {}) + output[study][subscription].setdefault(version, {}) + output[study][subscription][version]["columns"] = type_dict + output[study][subscription][version]["filename"] = filename + # print(json.dumps(output, indent=2)) + _put_s3_data("metadata/column_types.json", bucket, client, output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Creates column types for existing aggregates. """ + ) + parser.add_argument("-b", "--bucket", help="bucket name") + args = parser.parse_args() + create_column_type_metadata(args.bucket) diff --git a/src/handlers/dashboard/get_chart_data.py b/src/handlers/dashboard/get_chart_data.py index ba2bd47..db89b1c 100644 --- a/src/handlers/dashboard/get_chart_data.py +++ b/src/handlers/dashboard/get_chart_data.py @@ -7,10 +7,10 @@ import boto3 import pandas -from ..dashboard.filter_config import get_filter_string -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath -from ..shared.functions import get_latest_data_package_version, http_response +from src.handlers.dashboard.filter_config import get_filter_string +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath +from src.handlers.shared.functions import get_latest_data_package_version, http_response def _get_table_cols(table_name: str, version: str | None = None) -> list: diff --git a/src/handlers/dashboard/get_csv.py b/src/handlers/dashboard/get_csv.py new file mode 100644 index 0000000..479335c --- /dev/null +++ b/src/handlers/dashboard/get_csv.py @@ -0,0 +1,122 @@ +import os + +import boto3 +import botocore + +from src.handlers.shared import decorators, enums, functions + + +def _format_and_validate_key( + s3_client, + s3_bucket_name: str, + study: str, + subscription: str, + version: str, + filename: str, + site: str | None = None, +): + """Creates S3 key from url params""" + if site is not None: + key = f"last_valid/{study}/{study}__{subscription}/{site}/{version}/{filename}" + else: + key = f"csv_aggregates/{study}/{study}__{subscription}/{version}/{filename}" + try: + s3_client.head_object(Bucket=s3_bucket_name, Key=key) + return key + except botocore.exceptions.ClientError as e: + raise OSError(f"No object found at key {key}") from e + + +def _get_column_types( + s3_client, + s3_bucket_name: str, + study: str, + subscription: str, + version: str, + **kwargs, +) -> dict: + """Gets column types from the metadata store for a given subscription""" + types_metadata = functions.read_metadata( + s3_client, + s3_bucket_name, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + try: + return types_metadata[study][subscription][version][ + enums.ColumnTypesKeys.COLUMNS.value + ] + except KeyError: + return {} + + +@decorators.generic_error_handler(msg="Error retrieving chart data") +def get_csv_handler(event, context): + """manages event from dashboard api call and creates a temporary URL""" + del context + s3_bucket_name = os.environ.get("BUCKET_NAME") + s3_client = boto3.client("s3") + key = _format_and_validate_key(s3_client, s3_bucket_name, **event["pathParameters"]) + types = _get_column_types(s3_client, s3_bucket_name, **event["pathParameters"]) + presign_url = s3_client.generate_presigned_url( + "get_object", + Params={ + "Bucket": s3_bucket_name, + "Key": key, + "ResponseContentType": "text/csv", + }, + ExpiresIn=600, + ) + extra_headers = { + "Location": presign_url, + "x-column-names": ",".join(key for key in types.keys()), + "x-column-types": ",".join(key for key in types.values()), + # TODO: add data to x-column-descriptions once a source for column descriptions + # has been established + "x-column-descriptions": "", + } + res = functions.http_response(302, "", extra_headers=extra_headers) + return res + + +@decorators.generic_error_handler(msg="Error retrieving csv data") +def get_csv_list_handler(event, context): + """manages event from dashboard api call and creates a temporary URL""" + del context + s3_bucket_name = os.environ.get("BUCKET_NAME") + s3_client = boto3.client("s3") + if event["path"].startswith("/last_valid"): + key_prefix = "last_valid" + url_prefix = "last_valid" + elif event["path"].startswith("/aggregates"): + key_prefix = "csv_aggregates" + url_prefix = "aggregates" + else: + raise Exception("Unexpected url encountered") + + urls = [] + s3_objs = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=key_prefix) + if s3_objs["KeyCount"] == 0: + return functions.http_response(200, urls) + while True: + for obj in s3_objs["Contents"]: + if not obj["Key"].endswith(".csv"): + continue + key_parts = obj["Key"].split("/") + study = key_parts[1] + subscription = key_parts[2].split("__")[1] + version = key_parts[-2] + filename = key_parts[-1] + site = key_parts[3] if url_prefix == "last_valid" else None + url_parts = [url_prefix, study, subscription, version, filename] + if url_prefix == "last_valid": + url_parts.insert(3, site) + urls.append("/".join(url_parts)) + if not s3_objs["IsTruncated"]: + break + s3_objs = s3_client.list_objects_v2( + Bucket=s3_bucket_name, + Prefix=key_prefix, + ContinuationToken=s3_objs["NextContinuationToken"], + ) + res = functions.http_response(200, urls) + return res diff --git a/src/handlers/dashboard/get_data_packages.py b/src/handlers/dashboard/get_data_packages.py index c3c349b..423425b 100644 --- a/src/handlers/dashboard/get_data_packages.py +++ b/src/handlers/dashboard/get_data_packages.py @@ -3,9 +3,9 @@ import os -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, JsonFilename -from ..shared.functions import get_s3_json_as_dict, http_response +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath, JsonFilename +from src.handlers.shared.functions import get_s3_json_as_dict, http_response @generic_error_handler(msg="Error retrieving data packages") diff --git a/src/handlers/dashboard/get_metadata.py b/src/handlers/dashboard/get_metadata.py index fa6c2d0..52600ee 100644 --- a/src/handlers/dashboard/get_metadata.py +++ b/src/handlers/dashboard/get_metadata.py @@ -4,8 +4,8 @@ import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.functions import http_response, read_metadata +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.functions import http_response, read_metadata @generic_error_handler(msg="Error retrieving metadata") @@ -22,5 +22,7 @@ def metadata_handler(event, context): metadata = metadata[params["study"]] if "data_package" in params: metadata = metadata[params["data_package"]] + if "version" in params: + metadata = metadata[params["version"]] res = http_response(200, metadata) return res diff --git a/src/handlers/dashboard/get_study_periods.py b/src/handlers/dashboard/get_study_periods.py index c275891..89b2100 100644 --- a/src/handlers/dashboard/get_study_periods.py +++ b/src/handlers/dashboard/get_study_periods.py @@ -4,9 +4,9 @@ import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.enums import JsonFilename -from ..shared.functions import http_response, read_metadata +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import JsonFilename +from src.handlers.shared.functions import http_response, read_metadata @generic_error_handler(msg="Error retrieving study period") diff --git a/src/handlers/shared/awswrangler_functions.py b/src/handlers/shared/awswrangler_functions.py index e64ca54..52ec1f2 100644 --- a/src/handlers/shared/awswrangler_functions.py +++ b/src/handlers/shared/awswrangler_functions.py @@ -1,7 +1,7 @@ """ functions specifically requiring AWSWranger, which requires a lambda layer""" import awswrangler -from .enums import BucketPath +from src.handlers.shared.enums import BucketPath def get_s3_data_package_list( diff --git a/src/handlers/shared/decorators.py b/src/handlers/shared/decorators.py index 4d2ba67..36c1e1d 100644 --- a/src/handlers/shared/decorators.py +++ b/src/handlers/shared/decorators.py @@ -3,7 +3,7 @@ import functools import logging -from .functions import http_response +from src.handlers.shared.functions import http_response def generic_error_handler(msg="Internal server error"): diff --git a/src/handlers/shared/enums.py b/src/handlers/shared/enums.py index 1084d51..65dae2d 100644 --- a/src/handlers/shared/enums.py +++ b/src/handlers/shared/enums.py @@ -1,8 +1,8 @@ """Enums shared across lambda functions""" -from enum import Enum +import enum -class BucketPath(Enum): +class BucketPath(enum.Enum): """stores root level buckets for managing data processing state""" ADMIN = "admin" @@ -18,15 +18,24 @@ class BucketPath(Enum): UPLOAD = "site_upload" -class JsonFilename(Enum): +class ColumnTypesKeys(enum.Enum): + """stores names of expected keys in the study period metadata dictionary""" + + COLUMN_TYPES_FORMAT_VERSION = "column_types_format_version" + COLUMNS = "columns" + LAST_DATA_UPDATE = "last_data_update" + + +class JsonFilename(enum.Enum): """stores names of expected kinds of persisted S3 JSON files""" + COLUMN_TYPES = "column_types" TRANSACTIONS = "transactions" DATA_PACKAGES = "data_packages" STUDY_PERIODS = "study_periods" -class TransactionKeys(Enum): +class TransactionKeys(enum.Enum): """stores names of expected keys in the transaction dictionary""" TRANSACTION_FORMAT_VERSION = "transaction_format_version" @@ -37,7 +46,7 @@ class TransactionKeys(Enum): DELETED = "deleted" -class StudyPeriodMetadataKeys(Enum): +class StudyPeriodMetadataKeys(enum.Enum): """stores names of expected keys in the study period metadata dictionary""" STUDY_PERIOD_FORMAT_VERSION = "study_period_format_version" diff --git a/src/handlers/shared/functions.py b/src/handlers/shared/functions.py index 5a5e5a8..11c8c66 100644 --- a/src/handlers/shared/functions.py +++ b/src/handlers/shared/functions.py @@ -6,26 +6,34 @@ import boto3 -from .enums import BucketPath, JsonFilename, StudyPeriodMetadataKeys, TransactionKeys +from src.handlers.shared import enums TRANSACTION_METADATA_TEMPLATE = { - TransactionKeys.TRANSACTION_FORMAT_VERSION.value: "2", - TransactionKeys.LAST_UPLOAD.value: None, - TransactionKeys.LAST_DATA_UPDATE.value: None, - TransactionKeys.LAST_AGGREGATION.value: None, - TransactionKeys.LAST_ERROR.value: None, - TransactionKeys.DELETED.value: None, + enums.TransactionKeys.TRANSACTION_FORMAT_VERSION.value: "2", + enums.TransactionKeys.LAST_UPLOAD.value: None, + enums.TransactionKeys.LAST_DATA_UPDATE.value: None, + enums.TransactionKeys.LAST_AGGREGATION.value: None, + enums.TransactionKeys.LAST_ERROR.value: None, + enums.TransactionKeys.DELETED.value: None, } STUDY_PERIOD_METADATA_TEMPLATE = { - StudyPeriodMetadataKeys.STUDY_PERIOD_FORMAT_VERSION.value: "2", - StudyPeriodMetadataKeys.EARLIEST_DATE.value: None, - StudyPeriodMetadataKeys.LATEST_DATE.value: None, - StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value: None, + enums.StudyPeriodMetadataKeys.STUDY_PERIOD_FORMAT_VERSION.value: "2", + enums.StudyPeriodMetadataKeys.EARLIEST_DATE.value: None, + enums.StudyPeriodMetadataKeys.LATEST_DATE.value: None, + enums.StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value: None, +} + +COLUMN_TYPES_METADATA_TEMPLATE = { + enums.ColumnTypesKeys.COLUMN_TYPES_FORMAT_VERSION.value: "1", + enums.ColumnTypesKeys.COLUMNS.value: None, + enums.ColumnTypesKeys.LAST_DATA_UPDATE.value: None, } -def http_response(status: int, body: str, allow_cors: bool = False) -> dict: +def http_response( + status: int, body: str, allow_cors: bool = False, extra_headers: dict | None = None +) -> dict: """Generates the payload AWS lambda expects as a return value""" headers = {"Content-Type": "application/json"} if allow_cors: @@ -36,6 +44,8 @@ def http_response(status: int, body: str, allow_cors: bool = False) -> dict: "Access-Control-Allow-Methods": "GET", } ) + if extra_headers: + headers.update(extra_headers) return { "isBase64Encoded": False, "statusCode": status, @@ -49,17 +59,20 @@ def http_response(status: int, body: str, allow_cors: bool = False) -> dict: def check_meta_type(meta_type: str) -> None: """helper for ensuring specified metadata types""" - types = [item.value for item in JsonFilename] + types = [item.value for item in enums.JsonFilename] if meta_type not in types: raise ValueError("invalid metadata type specified") def read_metadata( - s3_client, s3_bucket_name: str, meta_type: str = JsonFilename.TRANSACTIONS.value + s3_client, + s3_bucket_name: str, + *, + meta_type: str = enums.JsonFilename.TRANSACTIONS.value, ) -> dict: """Reads transaction information from an s3 bucket as a dictionary""" check_meta_type(meta_type) - s3_path = f"{BucketPath.META.value}/{meta_type}.json" + s3_path = f"{enums.BucketPath.META.value}/{meta_type}.json" res = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=s3_path) if "Contents" in res: res = s3_client.get_object(Bucket=s3_bucket_name, Key=s3_path) @@ -70,53 +83,76 @@ def read_metadata( def update_metadata( + *, metadata: dict, - site: str, study: str, data_package: str, version: str, target: str, + site: str | None = None, dt: datetime | None = None, - meta_type: str = JsonFilename.TRANSACTIONS.value, + value: str | list | None = None, + meta_type: str | None = enums.JsonFilename.TRANSACTIONS.value, ): """Safely updates items in metadata dictionary - It's assumed that, other than the version field itself, every item in one - of these metadata dicts is a datetime corresponding to an S3 event timestamp + It's assumed that, other than the version/column/type fields, every item in one + of these metadata dicts is a ISO date string corresponding to an S3 event timestamp. + + TODO: if we have other cases of non-datetime metadata, consider breaking this + function into two, one for updating datetimes and one for updating values """ check_meta_type(meta_type) - if meta_type == JsonFilename.TRANSACTIONS.value: - site_metadata = metadata.setdefault(site, {}) - study_metadata = site_metadata.setdefault(study, {}) - data_package_metadata = study_metadata.setdefault(data_package, {}) - data_version_metadata = data_package_metadata.setdefault( - version, TRANSACTION_METADATA_TEMPLATE - ) - dt = dt or datetime.now(timezone.utc) - data_version_metadata[target] = dt.isoformat() - elif meta_type == JsonFilename.STUDY_PERIODS.value: - site_metadata = metadata.setdefault(site, {}) - study_period_metadata = site_metadata.setdefault(study, {}) - data_version_metadata = study_period_metadata.setdefault( - version, STUDY_PERIOD_METADATA_TEMPLATE - ) - dt = dt or datetime.now(timezone.utc) - data_version_metadata[target] = dt.isoformat() + match meta_type: + case enums.JsonFilename.TRANSACTIONS.value: + site_metadata = metadata.setdefault(site, {}) + study_metadata = site_metadata.setdefault(study, {}) + data_package_metadata = study_metadata.setdefault(data_package, {}) + data_version_metadata = data_package_metadata.setdefault( + version, TRANSACTION_METADATA_TEMPLATE + ) + dt = dt or datetime.now(timezone.utc) + data_version_metadata[target] = dt.isoformat() + case enums.JsonFilename.STUDY_PERIODS.value: + site_metadata = metadata.setdefault(site, {}) + study_period_metadata = site_metadata.setdefault(study, {}) + data_version_metadata = study_period_metadata.setdefault( + version, STUDY_PERIOD_METADATA_TEMPLATE + ) + dt = dt or datetime.now(timezone.utc) + data_version_metadata[target] = dt.isoformat() + case enums.JsonFilename.COLUMN_TYPES.value: + study_metadata = metadata.setdefault(study, {}) + data_package_metadata = study_metadata.setdefault(data_package, {}) + data_version_metadata = data_package_metadata.setdefault( + version, COLUMN_TYPES_METADATA_TEMPLATE + ) + if target == enums.ColumnTypesKeys.COLUMNS.value: + data_version_metadata[target] = value + else: + dt = dt or datetime.now(timezone.utc) + data_version_metadata[target] = dt.isoformat() + # Should only be hit if you add a new JSON dict and forget to add it + # to this function + case _: + raise OSError(f"{meta_type} does not have a handler for updates.") return metadata def write_metadata( + *, s3_client, s3_bucket_name: str, metadata: dict, - meta_type: str = JsonFilename.TRANSACTIONS.value, + meta_type: str = enums.JsonFilename.TRANSACTIONS.value, ) -> None: """Writes transaction info from ∏a dictionary to an s3 bucket metadata location""" check_meta_type(meta_type) + s3_client.put_object( Bucket=s3_bucket_name, - Key=f"{BucketPath.META.value}/{meta_type}.json", + Key=f"{enums.BucketPath.META.value}/{meta_type}.json", Body=json.dumps(metadata), ) @@ -181,3 +217,35 @@ def get_latest_data_package_version(bucket, prefix): if highest_ver is None: logging.error("No data package versions found for %s", prefix) return highest_ver + + +def get_csv_column_datatypes(dtypes): + """helper for generating column type for dashboard API""" + column_dict = {} + for column in dtypes.index: + if column.endswith("year"): + column_dict[column] = "year" + elif column.endswith("month"): + column_dict[column] = "month" + elif column.endswith("week"): + column_dict[column] = "week" + elif column.endswith("day") or str(dtypes[column]) == "datetime64": + column_dict[column] = "day" + elif column.startswith("cnt") or str(dtypes[column]) in ( + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + ): + column_dict[column] = "integer" + elif str(dtypes[column]) in ("Float32", "Float64"): + column_dict[column] = "float" + elif str(dtypes[column]) == "boolean": + column_dict[column] = "float" + else: + column_dict[column] = "string" + return column_dict diff --git a/src/handlers/site_upload/api_gateway_authorizer.py b/src/handlers/site_upload/api_gateway_authorizer.py index c1b090e..76840f4 100644 --- a/src/handlers/site_upload/api_gateway_authorizer.py +++ b/src/handlers/site_upload/api_gateway_authorizer.py @@ -7,8 +7,8 @@ import os import re -from ..shared.enums import BucketPath -from ..shared.functions import get_s3_json_as_dict +from src.handlers.shared.enums import BucketPath +from src.handlers.shared.functions import get_s3_json_as_dict class AuthError(Exception): diff --git a/src/handlers/site_upload/cache_api.py b/src/handlers/site_upload/cache_api.py index f42a27f..362c4fe 100644 --- a/src/handlers/site_upload/cache_api.py +++ b/src/handlers/site_upload/cache_api.py @@ -6,9 +6,9 @@ import awswrangler import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, JsonFilename -from ..shared.functions import http_response +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath, JsonFilename +from src.handlers.shared.functions import http_response def cache_api_data(s3_client, s3_bucket_name: str, db: str, target: str) -> None: diff --git a/src/handlers/site_upload/fetch_upload_url.py b/src/handlers/site_upload/fetch_upload_url.py index d532cde..a1d95d6 100644 --- a/src/handlers/site_upload/fetch_upload_url.py +++ b/src/handlers/site_upload/fetch_upload_url.py @@ -6,9 +6,9 @@ import boto3 import botocore.exceptions -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath -from ..shared.functions import get_s3_json_as_dict, http_response +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath +from src.handlers.shared.functions import get_s3_json_as_dict, http_response def create_presigned_post( diff --git a/src/handlers/site_upload/powerset_merge.py b/src/handlers/site_upload/powerset_merge.py index 71f4b2f..4fcc283 100644 --- a/src/handlers/site_upload/powerset_merge.py +++ b/src/handlers/site_upload/powerset_merge.py @@ -1,27 +1,17 @@ """ Lambda for performing joins of site count data """ import csv +import datetime import logging import os import traceback -from datetime import datetime, timezone import awswrangler import boto3 +import numpy import pandas -from numpy import nan from pandas.core.indexes.range import RangeIndex -from ..shared.awswrangler_functions import get_s3_data_package_list -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, TransactionKeys -from ..shared.functions import ( - get_s3_site_filename_suffix, - http_response, - move_s3_file, - read_metadata, - update_metadata, - write_metadata, -) +from src.handlers.shared import awswrangler_functions, decorators, enums, functions class MergeError(ValueError): @@ -46,18 +36,23 @@ def __init__(self, event): self.data_package = s3_key_array[2].split("__")[1] self.site = s3_key_array[3] self.version = s3_key_array[4] - self.metadata = read_metadata(self.s3_client, self.s3_bucket_name) + self.metadata = functions.read_metadata(self.s3_client, self.s3_bucket_name) + self.types_metadata = functions.read_metadata( + self.s3_client, + self.s3_bucket_name, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) # S3 Filesystem operations def get_data_package_list(self, path) -> list: """convenience wrapper for get_s3_data_package_list""" - return get_s3_data_package_list( + return awswrangler_functions.get_s3_data_package_list( path, self.s3_bucket_name, self.study, self.data_package ) def move_file(self, from_path: str, to_path: str) -> None: """convenience wrapper for move_s3_file""" - move_s3_file(self.s3_client, self.s3_bucket_name, from_path, to_path) + functions.move_s3_file(self.s3_client, self.s3_bucket_name, from_path, to_path) def copy_file(self, from_path: str, to_path: str) -> None: """convenience wrapper for copy_s3_file""" @@ -75,7 +70,7 @@ def copy_file(self, from_path: str, to_path: str) -> None: def write_parquet(self, df: pandas.DataFrame, is_new_data_package: bool) -> None: """writes dataframe as parquet to s3 and sends an SNS notification if new""" parquet_aggregate_path = ( - f"s3://{self.s3_bucket_name}/{BucketPath.AGGREGATE.value}/" + f"s3://{self.s3_bucket_name}/{enums.BucketPath.AGGREGATE.value}/" f"{self.study}/{self.study}__{self.data_package}/{self.version}/" f"{self.study}__{self.data_package}__aggregate.parquet" ) @@ -89,12 +84,12 @@ def write_parquet(self, df: pandas.DataFrame, is_new_data_package: bool) -> None def write_csv(self, df: pandas.DataFrame) -> None: """writes dataframe as csv to s3""" csv_aggregate_path = ( - f"s3://{self.s3_bucket_name}/{BucketPath.CSVAGGREGATE.value}/" + f"s3://{self.s3_bucket_name}/{enums.BucketPath.CSVAGGREGATE.value}/" f"{self.study}/{self.study}__{self.data_package}/{self.version}/" f"{self.study}__{self.data_package}__aggregate.csv" ) df = df.apply(lambda x: x.strip() if isinstance(x, str) else x).replace( - '""', nan + '""', numpy.nan ) df = df.replace(to_replace=r",", value="", regex=True) awswrangler.s3.to_csv( @@ -102,17 +97,46 @@ def write_csv(self, df: pandas.DataFrame) -> None: ) # metadata - def update_local_metadata(self, key, site=None): + def update_local_metadata( + self, + key, + *, + site=None, + value=None, + metadata: dict | None = None, + meta_type: str | None = enums.JsonFilename.TRANSACTIONS.value, + ): """convenience wrapper for update_metadata""" - if site is None: + # We are excluding COLUMN_TYPES explicitly from this first check because, + # by design, it should never have a site field in it - the column types + # are tied to the study version, not a specific site's data + if site is None and meta_type != enums.JsonFilename.COLUMN_TYPES.value: site = self.site - self.metadata = update_metadata( - self.metadata, site, self.study, self.data_package, self.version, key + if metadata is None: + metadata = self.metadata + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=self.study, + data_package=self.data_package, + version=self.version, + target=key, + value=value, + meta_type=meta_type, ) - def write_local_metadata(self): + def write_local_metadata( + self, metadata: dict | None = None, meta_type: str | None = None + ): """convenience wrapper for write_metadata""" - write_metadata(self.s3_client, self.s3_bucket_name, self.metadata) + metadata = metadata or self.metadata + meta_type = meta_type or enums.JsonFilename.TRANSACTIONS.value + functions.write_metadata( + s3_client=self.s3_client, + s3_bucket_name=self.s3_bucket_name, + metadata=metadata, + meta_type=meta_type, + ) def merge_error_handler( self, @@ -128,9 +152,9 @@ def merge_error_handler( logger.error(traceback.print_exc()) self.move_file( s3_path.replace(f"s3://{self.s3_bucket_name}/", ""), - f"{BucketPath.ERROR.value}/{subbucket_path}", + f"{enums.BucketPath.ERROR.value}/{subbucket_path}", ) - self.update_local_metadata(TransactionKeys.LAST_ERROR.value) + self.update_local_metadata(enums.TransactionKeys.LAST_ERROR.value) def get_static_string_series(static_str: str, index: RangeIndex) -> pandas.Series: @@ -203,7 +227,7 @@ def generate_csv_from_parquet(bucket_name: str, bucket_root: str, subbucket_path ) last_valid_df = last_valid_df.apply( lambda x: x.strip() if isinstance(x, str) else x - ).replace('""', nan) + ).replace('""', numpy.nan) # Here we are removing internal commas from fields so we get a valid unquoted CSV last_valid_df = last_valid_df.replace(to_replace=",", value="", regex=True) awswrangler.s3.to_csv( @@ -226,12 +250,14 @@ def merge_powersets(manager: S3Manager) -> None: # initializing this early in case an empty file causes us to never set it is_new_data_package = False df = pandas.DataFrame() - latest_file_list = manager.get_data_package_list(BucketPath.LATEST.value) - last_valid_file_list = manager.get_data_package_list(BucketPath.LAST_VALID.value) + latest_file_list = manager.get_data_package_list(enums.BucketPath.LATEST.value) + last_valid_file_list = manager.get_data_package_list( + enums.BucketPath.LAST_VALID.value + ) for last_valid_path in last_valid_file_list: if manager.version not in last_valid_path: continue - site_specific_name = get_s3_site_filename_suffix(last_valid_path) + site_specific_name = functions.get_s3_site_filename_suffix(last_valid_path) subbucket_path = f"{manager.study}/{manager.data_package}/{site_specific_name}" last_valid_site = site_specific_name.split("/", maxsplit=1)[0] # If the latest uploads don't include this site, we'll use the last-valid @@ -240,7 +266,7 @@ def merge_powersets(manager: S3Manager) -> None: if not any(x.endswith(site_specific_name) for x in latest_file_list): df = expand_and_concat_sets(df, last_valid_path, last_valid_site) manager.update_local_metadata( - TransactionKeys.LAST_AGGREGATION.value, site=last_valid_site + enums.TransactionKeys.LAST_AGGREGATION.value, site=last_valid_site ) except MergeError as e: # This is expected to trigger if there's an issue in expand_and_concat_sets; @@ -253,12 +279,12 @@ def merge_powersets(manager: S3Manager) -> None: for latest_path in latest_file_list: if manager.version not in latest_path: continue - site_specific_name = get_s3_site_filename_suffix(latest_path) + site_specific_name = functions.get_s3_site_filename_suffix(latest_path) subbucket_path = ( f"{manager.study}/{manager.study}__{manager.data_package}" f"/{site_specific_name}" ) - date_str = datetime.now(timezone.utc).isoformat() + date_str = datetime.datetime.now(datetime.timezone.utc).isoformat() timestamped_name = f".{date_str}.".join(site_specific_name.split(".")) timestamped_path = ( f"{manager.study}/{manager.study}__{manager.data_package}" @@ -269,8 +295,8 @@ def merge_powersets(manager: S3Manager) -> None: # if we're going to replace a file in last_valid, archive the old data if any(x.endswith(site_specific_name) for x in last_valid_file_list): manager.copy_file( - f"{BucketPath.LAST_VALID.value}/{subbucket_path}", - f"{BucketPath.ARCHIVE.value}/{timestamped_path}", + f"{enums.BucketPath.LAST_VALID.value}/{subbucket_path}", + f"{enums.BucketPath.ARCHIVE.value}/{timestamped_path}", ) # otherwise, this is the first instance - after it's in the database, # we'll generate a new list of valid tables for the dashboard @@ -278,8 +304,8 @@ def merge_powersets(manager: S3Manager) -> None: is_new_data_package = True df = expand_and_concat_sets(df, latest_path, manager.site) manager.move_file( - f"{BucketPath.LATEST.value}/{subbucket_path}", - f"{BucketPath.LAST_VALID.value}/{subbucket_path}", + f"{enums.BucketPath.LATEST.value}/{subbucket_path}", + f"{enums.BucketPath.LAST_VALID.value}/{subbucket_path}", ) #################### @@ -288,16 +314,18 @@ def merge_powersets(manager: S3Manager) -> None: # TODO: remove as soon as we support either parquet upload or # the API is supported by the dashboard generate_csv_from_parquet( - manager.s3_bucket_name, BucketPath.LAST_VALID.value, subbucket_path + manager.s3_bucket_name, + enums.BucketPath.LAST_VALID.value, + subbucket_path, ) #################### latest_site = site_specific_name.split("/", maxsplit=1)[0] manager.update_local_metadata( - TransactionKeys.LAST_DATA_UPDATE.value, site=latest_site + enums.TransactionKeys.LAST_DATA_UPDATE.value, site=latest_site ) manager.update_local_metadata( - TransactionKeys.LAST_AGGREGATION.value, site=latest_site + enums.TransactionKeys.LAST_AGGREGATION.value, site=latest_site ) except Exception as e: manager.merge_error_handler( @@ -310,21 +338,38 @@ def merge_powersets(manager: S3Manager) -> None: if any(x.endswith(site_specific_name) for x in last_valid_file_list): df = expand_and_concat_sets( df, - f"s3://{manager.s3_bucket_name}/{BucketPath.LAST_VALID.value}" + f"s3://{manager.s3_bucket_name}/{enums.BucketPath.LAST_VALID.value}" f"/{subbucket_path}", manager.site, ) - manager.update_local_metadata(TransactionKeys.LAST_AGGREGATION.value) + manager.update_local_metadata( + enums.TransactionKeys.LAST_AGGREGATION.value + ) if df.empty: - manager.merge_error_handler( - latest_path, - subbucket_path, - OSError("File not found"), - ) + raise OSError("File not found") manager.write_local_metadata() + # Updating the typing dict for the CSV API + column_dict = functions.get_csv_column_datatypes(df.dtypes) + manager.update_local_metadata( + enums.ColumnTypesKeys.COLUMNS.value, + value=column_dict, + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + manager.update_local_metadata( + enums.ColumnTypesKeys.LAST_DATA_UPDATE.value, + value=column_dict, + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + manager.write_local_metadata( + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + # In this section, we are trying to accomplish two things: # - Prepare a csv that can be loaded manually into the dashboard (requiring no # quotes, which means removing commas from strings) @@ -334,11 +379,11 @@ def merge_powersets(manager: S3Manager) -> None: manager.write_csv(df) -@generic_error_handler(msg="Error merging powersets") +@decorators.generic_error_handler(msg="Error merging powersets") def powerset_merge_handler(event, context): """manages event from SNS, triggers file processing and merge""" del context manager = S3Manager(event) merge_powersets(manager) - res = http_response(200, "Merge successful") + res = functions.http_response(200, "Merge successful") return res diff --git a/src/handlers/site_upload/process_upload.py b/src/handlers/site_upload/process_upload.py index 7784eb4..677a76f 100644 --- a/src/handlers/site_upload/process_upload.py +++ b/src/handlers/site_upload/process_upload.py @@ -3,15 +3,7 @@ import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, TransactionKeys -from ..shared.functions import ( - http_response, - move_s3_file, - read_metadata, - update_metadata, - write_metadata, -) +from src.handlers.shared import decorators, enums, functions class UnexpectedFileTypeError(Exception): @@ -23,7 +15,7 @@ def process_upload(s3_client, sns_client, s3_bucket_name: str, s3_key: str) -> N last_uploaded_date = s3_client.head_object(Bucket=s3_bucket_name, Key=s3_key)[ "LastModified" ] - metadata = read_metadata(s3_client, s3_bucket_name) + metadata = functions.read_metadata(s3_client, s3_bucket_name) path_params = s3_key.split("/") study = path_params[1] data_package = path_params[2] @@ -33,55 +25,59 @@ def process_upload(s3_client, sns_client, s3_bucket_name: str, s3_key: str) -> N # to archive - we don't care about metadata for this, but can look there to # verify transmission if it's a connectivity test if study == "template": - new_key = f"{BucketPath.ARCHIVE.value}/{s3_key.split('/', 1)[-1]}" - move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) + new_key = f"{enums.BucketPath.ARCHIVE.value}/{s3_key.split('/', 1)[-1]}" + functions.move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) elif s3_key.endswith(".parquet"): if "__meta_" in s3_key or "/discovery__" in s3_key: - new_key = f"{BucketPath.STUDY_META.value}/{s3_key.split('/', 1)[-1]}" + new_key = f"{enums.BucketPath.STUDY_META.value}/{s3_key.split('/', 1)[-1]}" topic_sns_arn = os.environ.get("TOPIC_PROCESS_STUDY_META_ARN") sns_subject = "Process study metadata upload event" else: - new_key = f"{BucketPath.LATEST.value}/{s3_key.split('/', 1)[-1]}" + new_key = f"{enums.BucketPath.LATEST.value}/{s3_key.split('/', 1)[-1]}" topic_sns_arn = os.environ.get("TOPIC_PROCESS_COUNTS_ARN") sns_subject = "Process counts upload event" - move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) - metadata = update_metadata( - metadata, - site, - study, - data_package, - version, - TransactionKeys.LAST_UPLOAD.value, - last_uploaded_date, + functions.move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=study, + data_package=data_package, + version=version, + target=enums.TransactionKeys.LAST_UPLOAD.value, + dt=last_uploaded_date, ) sns_client.publish(TopicArn=topic_sns_arn, Message=new_key, Subject=sns_subject) - write_metadata(s3_client, s3_bucket_name, metadata) + functions.write_metadata( + s3_client=s3_client, s3_bucket_name=s3_bucket_name, metadata=metadata + ) else: - new_key = f"{BucketPath.ERROR.value}/{s3_key.split('/', 1)[-1]}" - move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) - metadata = update_metadata( - metadata, - site, - study, - data_package, - version, - TransactionKeys.LAST_UPLOAD.value, - last_uploaded_date, + new_key = f"{enums.BucketPath.ERROR.value}/{s3_key.split('/', 1)[-1]}" + functions.move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=study, + data_package=data_package, + version=version, + target=enums.TransactionKeys.LAST_UPLOAD.value, + dt=last_uploaded_date, + ) + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=study, + data_package=data_package, + version=version, + target=enums.TransactionKeys.LAST_ERROR.value, + dt=last_uploaded_date, ) - metadata = update_metadata( - metadata, - site, - study, - data_package, - version, - TransactionKeys.LAST_ERROR.value, - last_uploaded_date, + functions.write_metadata( + s3_client=s3_client, s3_bucket_name=s3_bucket_name, metadata=metadata ) - write_metadata(s3_client, s3_bucket_name, metadata) raise UnexpectedFileTypeError -@generic_error_handler(msg="Error processing file upload") +@decorators.generic_error_handler(msg="Error processing file upload") def process_upload_handler(event, context): """manages event from S3, triggers file processing and merge""" del context @@ -90,5 +86,5 @@ def process_upload_handler(event, context): sns_client = boto3.client("sns", region_name=event["Records"][0]["awsRegion"]) s3_key = event["Records"][0]["s3"]["object"]["key"] process_upload(s3_client, sns_client, s3_bucket, s3_key) - res = http_response(200, "Upload processing successful") + res = functions.http_response(200, "Upload processing successful") return res diff --git a/src/handlers/site_upload/study_period.py b/src/handlers/site_upload/study_period.py index 93e5edc..8c87d21 100644 --- a/src/handlers/site_upload/study_period.py +++ b/src/handlers/site_upload/study_period.py @@ -6,10 +6,10 @@ import awswrangler import boto3 -from ..shared.awswrangler_functions import get_s3_study_meta_list -from ..shared.decorators import generic_error_handler -from ..shared.enums import JsonFilename, StudyPeriodMetadataKeys -from ..shared.functions import ( +from src.handlers.shared.awswrangler_functions import get_s3_study_meta_list +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import JsonFilename, StudyPeriodMetadataKeys +from src.handlers.shared.functions import ( http_response, read_metadata, update_metadata, @@ -28,37 +28,40 @@ def update_study_period(s3_client, s3_bucket, site, study, data_package, version ) study_meta = update_metadata( - study_meta, - site, - study, - data_package, - version, - StudyPeriodMetadataKeys.EARLIEST_DATE.value, - df["min_date"][0], + metadata=study_meta, + site=site, + study=study, + data_package=data_package, + version=version, + target=StudyPeriodMetadataKeys.EARLIEST_DATE.value, + dt=df["min_date"][0], meta_type=JsonFilename.STUDY_PERIODS.value, ) study_meta = update_metadata( - study_meta, - site, - study, - data_package, - version, - StudyPeriodMetadataKeys.LATEST_DATE.value, - df["max_date"][0], + metadata=study_meta, + site=site, + study=study, + data_package=data_package, + version=version, + target=StudyPeriodMetadataKeys.LATEST_DATE.value, + dt=df["max_date"][0], meta_type=JsonFilename.STUDY_PERIODS.value, ) study_meta = update_metadata( - study_meta, - site, - study, - data_package, - version, - StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value, - datetime.now(timezone.utc), + metadata=study_meta, + site=site, + study=study, + data_package=data_package, + version=version, + target=StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value, + dt=datetime.now(timezone.utc), meta_type=JsonFilename.STUDY_PERIODS.value, ) write_metadata( - s3_client, s3_bucket, study_meta, meta_type=JsonFilename.STUDY_PERIODS.value + s3_client=s3_client, + s3_bucket_name=s3_bucket, + metadata=study_meta, + meta_type=JsonFilename.STUDY_PERIODS.value, ) diff --git a/template.yaml b/template.yaml index a6868ac..9f0e786 100644 --- a/template.yaml +++ b/template.yaml @@ -243,6 +243,77 @@ Resources: # Dashboard API + DashboardGetCsvFunction: + Type: AWS::Serverless::Function + Properties: + FunctionName: !Sub 'CumulusAggDashboardGetCsv-${DeployStage}' + Handler: src/handlers/dashboard/get_csv.get_csv_handler + Runtime: python3.10 + MemorySize: 128 + Timeout: 100 + Description: Redirect to presigned URL for download of aggregate CSVs + Environment: + Variables: + BUCKET_NAME: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + Events: + GetAggregateAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /aggregates/{study}/{subscription}/{version}/{filename} + Method: GET + GetLastValidAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /last_valid/{study}/{subscription}/{site}/{version}/{filename} + Method: GET + Policies: + - S3ReadPolicy: + BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + + DashboardGetCsvLogGroup: + Type: AWS::Logs::LogGroup + Properties: + LogGroupName: !Sub "/aws/lambda/${DashboardGetCsvFunction}" + RetentionInDays: !Ref RetentionTime + + + DashboardGetCsvListFunction: + Type: AWS::Serverless::Function + Properties: + FunctionName: !Sub 'CumulusAggDashboardGetCsvList-${DeployStage}' + Handler: src/handlers/dashboard/get_csv.get_csv_list_handler + Runtime: python3.10 + MemorySize: 128 + Timeout: 100 + Description: List all available csvs from the aggregator + Environment: + Variables: + BUCKET_NAME: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + Events: + GetAggregateAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /aggregates/ + Method: GET + GetLastValidAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /last_valid/ + Method: GET + Policies: + - S3ReadPolicy: + BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + + DashboardGetCsvListLogGroup: + Type: AWS::Logs::LogGroup + Properties: + LogGroupName: !Sub "/aws/lambda/${DashboardGetCsvListFunction}" + RetentionInDays: !Ref RetentionTime + DashboardGetChartDataFunction: Type: AWS::Serverless::Function Properties: @@ -328,6 +399,12 @@ Resources: RestApiId: !Ref DashboardApiGateway Path: /metadata/{site}/{study}/{subscription} Method: GET + GetMetadataSiteStudySubscriptionVersionAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /metadata/{site}/{study}/{subscription}/{version} + Method: GET Policies: - S3ReadPolicy: BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' diff --git a/tests/conftest.py b/tests/conftest.py index b011150..bdf4510 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,19 +28,9 @@ import pytest from moto import mock_athena, mock_s3, mock_sns -from scripts.credential_management import create_auth, create_meta -from src.handlers.shared.enums import BucketPath, JsonFilename -from src.handlers.shared.functions import write_metadata -from tests.utils import ( - EXISTING_DATA_P, - EXISTING_STUDY, - EXISTING_VERSION, - ITEM_COUNT, - MOCK_ENV, - OTHER_STUDY, - get_mock_metadata, - get_mock_study_metadata, -) +from scripts import credential_management +from src.handlers.shared import enums, functions +from tests import mock_utils def _init_mock_data(s3_client, bucket, study, data_package, version): @@ -57,25 +47,25 @@ def _init_mock_data(s3_client, bucket, study, data_package, version): s3_client.upload_file( "./tests/test_data/count_synthea_patient_agg.parquet", bucket, - f"{BucketPath.AGGREGATE.value}/{study}/" + f"{enums.BucketPath.AGGREGATE.value}/{study}/" f"{study}__{data_package}/{version}/{study}__{data_package}__aggregate.parquet", ) s3_client.upload_file( "./tests/test_data/count_synthea_patient_agg.csv", bucket, - f"{BucketPath.CSVAGGREGATE.value}/{study}/" + f"{enums.BucketPath.CSVAGGREGATE.value}/{study}/" f"{study}__{data_package}/{version}/{study}__{data_package}__aggregate.csv", ) s3_client.upload_file( "./tests/test_data/data_packages_cache.json", bucket, - f"{BucketPath.CACHE.value}/{JsonFilename.DATA_PACKAGES.value}.json", + f"{enums.BucketPath.CACHE.value}/{enums.JsonFilename.DATA_PACKAGES.value}.json", ) @pytest.fixture(autouse=True) def mock_env(): - with mock.patch.dict(os.environ, MOCK_ENV): + with mock.patch.dict(os.environ, mock_utils.MOCK_ENV): yield @@ -89,23 +79,49 @@ def mock_bucket(): bucket = os.environ["BUCKET_NAME"] s3_client.create_bucket(Bucket=bucket) aggregate_params = [ - [EXISTING_STUDY, EXISTING_DATA_P, EXISTING_VERSION], - [OTHER_STUDY, EXISTING_DATA_P, EXISTING_VERSION], + [ + mock_utils.EXISTING_STUDY, + mock_utils.EXISTING_DATA_P, + mock_utils.EXISTING_VERSION, + ], + [ + mock_utils.OTHER_STUDY, + mock_utils.EXISTING_DATA_P, + mock_utils.EXISTING_VERSION, + ], ] for param_list in aggregate_params: _init_mock_data(s3_client, bucket, *param_list) - create_auth(s3_client, bucket, "ppth_1", "test_1", "ppth") - create_meta(s3_client, bucket, "ppth", "princeton_plainsboro_teaching_hospital") - create_auth(s3_client, bucket, "elsewhere_2", "test_2", "elsewhere") - create_meta(s3_client, bucket, "elsewhere", "st_elsewhere") - create_auth(s3_client, bucket, "hope_3", "test_3", "hope") - create_meta(s3_client, bucket, "hope", "chicago_hope") - - metadata = get_mock_metadata() - write_metadata(s3_client, bucket, metadata) - study_metadata = get_mock_study_metadata() - write_metadata(s3_client, bucket, study_metadata, meta_type="study_periods") + credential_management.create_auth(s3_client, bucket, "ppth_1", "test_1", "ppth") + credential_management.create_meta( + s3_client, bucket, "ppth", "princeton_plainsboro_teaching_hospital" + ) + credential_management.create_auth( + s3_client, bucket, "elsewhere_2", "test_2", "elsewhere" + ) + credential_management.create_meta(s3_client, bucket, "elsewhere", "st_elsewhere") + credential_management.create_auth(s3_client, bucket, "hope_3", "test_3", "hope") + credential_management.create_meta(s3_client, bucket, "hope", "chicago_hope") + + metadata = mock_utils.get_mock_metadata() + functions.write_metadata( + s3_client=s3_client, s3_bucket_name=bucket, metadata=metadata + ) + study_metadata = mock_utils.get_mock_study_metadata() + functions.write_metadata( + s3_client=s3_client, + s3_bucket_name=bucket, + metadata=study_metadata, + meta_type=enums.JsonFilename.STUDY_PERIODS.value, + ) + column_types_metadata = mock_utils.get_mock_column_types_metadata() + functions.write_metadata( + s3_client=s3_client, + s3_bucket_name=bucket, + metadata=column_types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) yield s3.stop() @@ -114,7 +130,7 @@ def mock_bucket(): def mock_notification(): """Mocks for SNS topics. - Make sure the topic name matches the end of the ARN defined in utils.py""" + Make sure the topic name matches the end of the ARN defined in mock_utils.py""" sns = mock_sns() sns.start() sns_client = boto3.client("sns", region_name="us-east-1") @@ -152,4 +168,4 @@ def mock_db(): def test_mock_bucket(): s3_client = boto3.client("s3", region_name="us-east-1") item = s3_client.list_objects_v2(Bucket=os.environ["TEST_BUCKET"]) - assert (len(item["Contents"])) == ITEM_COUNT + assert (len(item["Contents"])) == mock_utils.ITEM_COUNT diff --git a/tests/dashboard/test_get_chart_data.py b/tests/dashboard/test_get_chart_data.py index 8decd12..ec11a9a 100644 --- a/tests/dashboard/test_get_chart_data.py +++ b/tests/dashboard/test_get_chart_data.py @@ -6,7 +6,7 @@ import pytest from src.handlers.dashboard import get_chart_data -from tests.utils import ( +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_STUDY, MOCK_ENV, diff --git a/tests/dashboard/test_get_csv.py b/tests/dashboard/test_get_csv.py new file mode 100644 index 0000000..1e34f7a --- /dev/null +++ b/tests/dashboard/test_get_csv.py @@ -0,0 +1,150 @@ +import json +import os +from contextlib import nullcontext as does_not_raise +from unittest import mock + +import boto3 +import pytest + +from src.handlers.dashboard import get_csv +from src.handlers.shared import enums +from tests import mock_utils + +# data matching these params is created via conftest +site = mock_utils.EXISTING_SITE +study = mock_utils.EXISTING_STUDY +subscription = mock_utils.EXISTING_DATA_P +version = mock_utils.EXISTING_VERSION +filename = filename = f"{study}__{subscription}__aggregate.csv" + + +def _mock_last_valid(): + bucket = os.environ["BUCKET_NAME"] + s3_client = boto3.client("s3", region_name="us-east-1") + s3_client.upload_file( + "./tests/test_data/count_synthea_patient_agg.csv", + bucket, + f"{enums.BucketPath.LAST_VALID.value}/{study}/" + f"{study}__{subscription}/{site}/{version}/{filename}", + ) + + +@pytest.mark.parametrize( + "params,status,expected", + [ + ( + { + "site": None, + "study": study, + "subscription": subscription, + "version": version, + "filename": filename, + }, + 302, + None, + ), + ( + { + "site": site, + "study": study, + "subscription": subscription, + "version": version, + "filename": filename, + }, + 302, + None, + ), + ( + { + "study": study, + "subscription": subscription, + "version": version, + "filename": filename, + }, + 302, + None, + ), + ( + { + "site": site, + "study": study, + "subscription": subscription, + "version": version, + "filename": "foo", + }, + 500, + None, + ), + ( + { + "site": None, + "study": None, + "subscription": None, + "version": None, + "filename": None, + }, + 500, + None, + ), + ], +) +@mock.patch.dict(os.environ, mock_utils.MOCK_ENV) +def test_get_csv(mock_bucket, params, status, expected): + event = {"pathParameters": params} + if "site" in params and params["site"] is not None: + _mock_last_valid() + res = get_csv.get_csv_handler(event, {}) + assert res["statusCode"] == status + if status == 302: + if "site" not in params or params["site"] is None: + url = ( + "https://cumulus-aggregator-site-counts-test.s3.amazonaws.com/csv_aggregates/" + f"{study}/{study}__{subscription}/{version}/{filename}" + ) + else: + url = ( + "https://cumulus-aggregator-site-counts-test.s3.amazonaws.com/last_valid/" + f"{study}/{study}__{subscription}/{site}/{version}/{filename}" + ) + assert ( + res["headers"]["x-column-types"] == "integer,string,integer,string,string" + ) + assert res["headers"]["x-column-names"] == "cnt,gender,age,race_display,site" + assert res["headers"]["x-column-descriptions"] == "" + assert res["headers"]["Location"].startswith(url) + + +@pytest.mark.parametrize( + "path,status,expected,raises", + [ + ( + "/aggregates", + 200, + [ + "aggregates/other_study/encounter/099/other_study__encounter__aggregate.csv", + "aggregates/study/encounter/099/study__encounter__aggregate.csv", + ], + does_not_raise(), + ), + ( + "/last_valid", + 200, + [ + "last_valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv" + ], + does_not_raise(), + ), + ("/some_other_endpoint", 500, [], does_not_raise()), + ], +) +@mock.patch.dict(os.environ, mock_utils.MOCK_ENV) +def test_get_csv_list(mock_bucket, path, status, expected, raises): + with raises: + if path.startswith("/last_valid"): + _mock_last_valid() + event = {"path": path} + res = get_csv.get_csv_list_handler(event, {}) + keys = json.loads(res["body"]) + assert res["statusCode"] == status + if status == 200: + assert keys == expected diff --git a/tests/dashboard/test_get_metadata.py b/tests/dashboard/test_get_metadata.py index 7166a9b..959a95f 100644 --- a/tests/dashboard/test_get_metadata.py +++ b/tests/dashboard/test_get_metadata.py @@ -3,41 +3,50 @@ import pytest from src.handlers.dashboard.get_metadata import metadata_handler -from tests.utils import ( - EXISTING_DATA_P, - EXISTING_SITE, - EXISTING_STUDY, - NEW_SITE, - NEW_STUDY, - get_mock_metadata, -) +from tests import mock_utils @pytest.mark.parametrize( "params,status,expected", [ - (None, 200, get_mock_metadata()), + (None, 200, mock_utils.get_mock_metadata()), + ( + {"site": mock_utils.EXISTING_SITE}, + 200, + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE], + ), ( - {"site": EXISTING_SITE}, + {"site": mock_utils.EXISTING_SITE, "study": mock_utils.EXISTING_STUDY}, 200, - get_mock_metadata()[EXISTING_SITE], + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE][ + mock_utils.EXISTING_STUDY + ], ), ( - {"site": EXISTING_SITE, "study": EXISTING_STUDY}, + { + "site": mock_utils.EXISTING_SITE, + "study": mock_utils.EXISTING_STUDY, + "data_package": mock_utils.EXISTING_DATA_P, + }, 200, - get_mock_metadata()[EXISTING_SITE][EXISTING_STUDY], + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE][ + mock_utils.EXISTING_STUDY + ][mock_utils.EXISTING_DATA_P], ), ( { - "site": EXISTING_SITE, - "study": EXISTING_STUDY, - "data_package": EXISTING_DATA_P, + "site": mock_utils.EXISTING_SITE, + "study": mock_utils.EXISTING_STUDY, + "data_package": mock_utils.EXISTING_DATA_P, + "version": mock_utils.EXISTING_VERSION, }, 200, - get_mock_metadata()[EXISTING_SITE][EXISTING_STUDY][EXISTING_DATA_P], + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE][ + mock_utils.EXISTING_STUDY + ][mock_utils.EXISTING_DATA_P][mock_utils.EXISTING_VERSION], ), - ({"site": NEW_SITE, "study": EXISTING_STUDY}, 500, None), - ({"site": EXISTING_SITE, "study": NEW_STUDY}, 500, None), + ({"site": mock_utils.NEW_SITE, "study": mock_utils.EXISTING_STUDY}, 500, None), + ({"site": mock_utils.EXISTING_SITE, "study": mock_utils.NEW_STUDY}, 500, None), ], ) def test_get_metadata(mock_bucket, params, status, expected): diff --git a/tests/dashboard/test_get_study_periods.py b/tests/dashboard/test_get_study_periods.py index 4e9c0a2..d9d2ba0 100644 --- a/tests/dashboard/test_get_study_periods.py +++ b/tests/dashboard/test_get_study_periods.py @@ -3,7 +3,7 @@ import pytest from src.handlers.dashboard.get_study_periods import study_periods_handler -from tests.utils import ( +from tests.mock_utils import ( EXISTING_SITE, EXISTING_STUDY, NEW_SITE, @@ -32,7 +32,6 @@ ) def test_get_study_periods(mock_bucket, params, status, expected): event = {"pathParameters": params} - res = study_periods_handler(event, {}) assert res["statusCode"] == status if status == 200: diff --git a/tests/dashboard/test_get_subscriptions.py b/tests/dashboard/test_get_subscriptions.py index 5fcfff3..7a4ec13 100644 --- a/tests/dashboard/test_get_subscriptions.py +++ b/tests/dashboard/test_get_subscriptions.py @@ -2,7 +2,7 @@ from unittest import mock from src.handlers.dashboard.get_data_packages import data_packages_handler -from tests.utils import DATA_PACKAGE_COUNT, MOCK_ENV +from tests.mock_utils import DATA_PACKAGE_COUNT, MOCK_ENV @mock.patch.dict(os.environ, MOCK_ENV) diff --git a/tests/utils.py b/tests/mock_utils.py similarity index 88% rename from tests/utils.py rename to tests/mock_utils.py index 6259e67..301d35f 100644 --- a/tests/utils.py +++ b/tests/mock_utils.py @@ -6,7 +6,7 @@ TEST_PROCESS_COUNTS_ARN = "arn:aws:sns:us-east-1:123456789012:test-counts" TEST_PROCESS_STUDY_META_ARN = "arn:aws:sns:us-east-1:123456789012:test-meta" TEST_CACHE_API_ARN = "arn:aws:sns:us-east-1:123456789012:test-cache" -ITEM_COUNT = 9 +ITEM_COUNT = 10 DATA_PACKAGE_COUNT = 2 EXISTING_SITE = "princeton_plainsboro_teaching_hospital" @@ -110,6 +110,26 @@ def get_mock_study_metadata(): } +def get_mock_column_types_metadata(): + return { + EXISTING_STUDY: { + EXISTING_DATA_P: { + EXISTING_VERSION: { + "column_types_format_version": "1", + "columns": { + "cnt": "integer", + "gender": "string", + "age": "integer", + "race_display": "string", + "site": "string", + }, + "last_data_update": "2023-02-24T15:08:07.771080+00:00", + } + } + } + } + + def get_mock_auth(): return { # u/a: ppth_1 test_1 diff --git a/tests/site_upload/test_api_gateway_authorizer.py b/tests/site_upload/test_api_gateway_authorizer.py index 19c58cb..8d7a7ec 100644 --- a/tests/site_upload/test_api_gateway_authorizer.py +++ b/tests/site_upload/test_api_gateway_authorizer.py @@ -3,13 +3,13 @@ import pytest from src.handlers.site_upload import api_gateway_authorizer -from tests import utils +from tests import mock_utils @pytest.mark.parametrize( "auth,expects", [ - (f"Basic {next(iter(utils.get_mock_auth().keys()))}", does_not_raise()), + (f"Basic {next(iter(mock_utils.get_mock_auth().keys()))}", does_not_raise()), ("Basic other_auth", pytest.raises(api_gateway_authorizer.AuthError)), (None, pytest.raises(AttributeError)), ], diff --git a/tests/site_upload/test_cache_api.py b/tests/site_upload/test_cache_api.py index 3953355..f50dc24 100644 --- a/tests/site_upload/test_cache_api.py +++ b/tests/site_upload/test_cache_api.py @@ -5,7 +5,7 @@ import pytest from src.handlers.site_upload.cache_api import cache_api_handler -from tests.utils import MOCK_ENV, get_mock_data_packages_cache +from tests.mock_utils import MOCK_ENV, get_mock_data_packages_cache def mock_data_packages(*args, **kwargs): diff --git a/tests/site_upload/test_fetch_upload_url.py b/tests/site_upload/test_fetch_upload_url.py index b96417b..08d997e 100644 --- a/tests/site_upload/test_fetch_upload_url.py +++ b/tests/site_upload/test_fetch_upload_url.py @@ -4,7 +4,7 @@ from src.handlers.shared.enums import BucketPath from src.handlers.site_upload.fetch_upload_url import upload_url_handler -from tests.utils import ( +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY, diff --git a/tests/site_upload/test_powerset_merge.py b/tests/site_upload/test_powerset_merge.py index 422571d..083a8e6 100644 --- a/tests/site_upload/test_powerset_merge.py +++ b/tests/site_upload/test_powerset_merge.py @@ -10,15 +10,9 @@ from freezegun import freeze_time from pandas import DataFrame, read_parquet -from src.handlers.shared.enums import BucketPath -from src.handlers.shared.functions import read_metadata -from src.handlers.site_upload.powerset_merge import ( - MergeError, - expand_and_concat_sets, - generate_csv_from_parquet, - powerset_merge_handler, -) -from tests.utils import ( +from src.handlers.shared import enums, functions +from src.handlers.site_upload import powerset_merge +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY, @@ -29,6 +23,7 @@ NEW_STUDY, NEW_VERSION, TEST_BUCKET, + get_mock_column_types_metadata, get_mock_metadata, ) @@ -145,26 +140,26 @@ def test_powerset_merge_single_upload( s3_client.upload_file( upload_file, TEST_BUCKET, - f"{BucketPath.LATEST.value}{upload_path}", + f"{enums.BucketPath.LATEST.value}{upload_path}", ) elif upload_path is not None: with io.BytesIO(DataFrame().to_parquet()) as upload_fileobj: s3_client.upload_fileobj( upload_fileobj, TEST_BUCKET, - f"{BucketPath.LATEST.value}{upload_path}", + f"{enums.BucketPath.LATEST.value}{upload_path}", ) if archives: s3_client.upload_file( upload_file, TEST_BUCKET, - f"{BucketPath.LAST_VALID.value}{upload_path}", + f"{enums.BucketPath.LAST_VALID.value}{upload_path}", ) event = { "Records": [ { - "Sns": {"Message": f"{BucketPath.LATEST.value}{event_key}"}, + "Sns": {"Message": f"{enums.BucketPath.LATEST.value}{event_key}"}, } ] } @@ -175,13 +170,13 @@ def test_powerset_merge_single_upload( data_package = event_list[2] site = event_list[3] version = event_list[4] - res = powerset_merge_handler(event, {}) + res = powerset_merge.powerset_merge_handler(event, {}) assert res["statusCode"] == status s3_res = s3_client.list_objects_v2(Bucket=TEST_BUCKET) assert len(s3_res["Contents"]) == expected_contents for item in s3_res["Contents"]: if item["Key"].endswith("aggregate.parquet"): - assert item["Key"].startswith(BucketPath.AGGREGATE.value) + assert item["Key"].startswith(enums.BucketPath.AGGREGATE.value) # This finds the aggregate that was created/updated - ie it skips mocks if study in item["Key"] and status == 200: agg_df = awswrangler.s3.read_parquet( @@ -189,10 +184,10 @@ def test_powerset_merge_single_upload( ) assert (agg_df["site"].eq(site)).any() elif item["Key"].endswith("aggregate.csv"): - assert item["Key"].startswith(BucketPath.CSVAGGREGATE.value) + assert item["Key"].startswith(enums.BucketPath.CSVAGGREGATE.value) elif item["Key"].endswith("transactions.json"): - assert item["Key"].startswith(BucketPath.META.value) - metadata = read_metadata(s3_client, TEST_BUCKET) + assert item["Key"].startswith(enums.BucketPath.META.value) + metadata = functions.read_metadata(s3_client, TEST_BUCKET) if res["statusCode"] == 200: assert ( metadata[site][study][data_package.split("__")[1]][version][ @@ -218,20 +213,39 @@ def test_powerset_merge_single_upload( ] != datetime.now(timezone.utc).isoformat() ) + elif item["Key"].endswith("column_types.json"): + assert item["Key"].startswith(enums.BucketPath.META.value) + metadata = functions.read_metadata( + s3_client, TEST_BUCKET, meta_type=enums.JsonFilename.COLUMN_TYPES.value + ) + if res["statusCode"] == 200: + last_update = metadata[study][data_package.split("__")[1]][version][ + "last_data_update" + ] + assert last_update == datetime.now(timezone.utc).isoformat() - elif item["Key"].startswith(BucketPath.LAST_VALID.value): + else: + assert ( + metadata["study"]["encounter"]["099"]["last_data_update"] + == get_mock_column_types_metadata()["study"]["encounter"]["099"][ + "last_data_update" + ] + ) + elif item["Key"].startswith(enums.BucketPath.LAST_VALID.value): if item["Key"].endswith(".parquet"): - assert item["Key"] == (f"{BucketPath.LAST_VALID.value}{upload_path}") + assert item["Key"] == ( + f"{enums.BucketPath.LAST_VALID.value}{upload_path}" + ) elif item["Key"].endswith(".csv"): assert f"{upload_path.replace('.parquet','.csv')}" in item["Key"] else: raise Exception("Invalid csv found at " f"{item['Key']}") else: assert ( - item["Key"].startswith(BucketPath.ARCHIVE.value) - or item["Key"].startswith(BucketPath.ERROR.value) - or item["Key"].startswith(BucketPath.ADMIN.value) - or item["Key"].startswith(BucketPath.CACHE.value) + item["Key"].startswith(enums.BucketPath.ARCHIVE.value) + or item["Key"].startswith(enums.BucketPath.ERROR.value) + or item["Key"].startswith(enums.BucketPath.ADMIN.value) + or item["Key"].startswith(enums.BucketPath.CACHE.value) or item["Key"].endswith("study_periods.json") ) if archives: @@ -240,7 +254,7 @@ def test_powerset_merge_single_upload( keys.append(resource["Key"]) date_str = datetime.now(timezone.utc).isoformat() archive_path = f".{date_str}.".join(upload_path.split(".")) - assert f"{BucketPath.ARCHIVE.value}{archive_path}" in keys + assert f"{enums.BucketPath.ARCHIVE.value}{archive_path}" in keys @freeze_time("2020-01-01") @@ -264,7 +278,7 @@ def test_powerset_merge_join_study_data( s3_client.upload_file( upload_file, TEST_BUCKET, - f"{BucketPath.LATEST.value}/{EXISTING_STUDY}/" + f"{enums.BucketPath.LATEST.value}/{EXISTING_STUDY}/" f"{EXISTING_STUDY}__{EXISTING_DATA_P}/{NEW_SITE}/" f"{EXISTING_VERSION}/encounter.parquet", ) @@ -272,7 +286,7 @@ def test_powerset_merge_join_study_data( s3_client.upload_file( "./tests/test_data/count_synthea_patient.parquet", TEST_BUCKET, - f"{BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" + f"{enums.BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" f"{EXISTING_STUDY}__{EXISTING_DATA_P}/{EXISTING_SITE}/" f"{EXISTING_VERSION}/encounter.parquet", ) @@ -281,7 +295,7 @@ def test_powerset_merge_join_study_data( s3_client.upload_file( "./tests/test_data/count_synthea_patient.parquet", TEST_BUCKET, - f"{BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" + f"{enums.BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" f"{EXISTING_STUDY}__{EXISTING_DATA_P}/{NEW_SITE}/" f"{EXISTING_VERSION}/encounter.parquet", ) @@ -290,21 +304,21 @@ def test_powerset_merge_join_study_data( "Records": [ { "Sns": { - "Message": f"{BucketPath.LATEST.value}/{EXISTING_STUDY}" + "Message": f"{enums.BucketPath.LATEST.value}/{EXISTING_STUDY}" f"/{EXISTING_STUDY}__{EXISTING_DATA_P}/{NEW_SITE}" f"/{EXISTING_VERSION}/encounter.parquet" }, } ] } - res = powerset_merge_handler(event, {}) + res = powerset_merge.powerset_merge_handler(event, {}) assert res["statusCode"] == 200 errors = 0 s3_res = s3_client.list_objects_v2(Bucket=TEST_BUCKET) for item in s3_res["Contents"]: - if item["Key"].startswith(BucketPath.ERROR.value): + if item["Key"].startswith(enums.BucketPath.ERROR.value): errors += 1 - elif item["Key"].startswith(f"{BucketPath.AGGREGATE.value}/study"): + elif item["Key"].startswith(f"{enums.BucketPath.AGGREGATE.value}/study"): agg_df = awswrangler.s3.read_parquet(f"s3://{TEST_BUCKET}/{item['Key']}") # if a file cant be merged and there's no fallback, we expect # [, site_name], otherwise, [, site_name, uploading_site_name] @@ -324,12 +338,12 @@ def test_powerset_merge_join_study_data( ( "./tests/test_data/cube_simple_example.parquet", False, - pytest.raises(MergeError), + pytest.raises(powerset_merge.MergeError), ), ( "./tests/test_data/count_synthea_empty.parquet", True, - pytest.raises(MergeError), + pytest.raises(powerset_merge.MergeError), ), ], ) @@ -343,7 +357,9 @@ def test_expand_and_concat(mock_bucket, upload_file, load_empty, raises): TEST_BUCKET, s3_path, ) - expand_and_concat_sets(df, f"s3://{TEST_BUCKET}/{s3_path}", EXISTING_STUDY) + powerset_merge.expand_and_concat_sets( + df, f"s3://{TEST_BUCKET}/{s3_path}", EXISTING_STUDY + ) def test_parquet_to_csv(mock_bucket): @@ -355,7 +371,7 @@ def test_parquet_to_csv(mock_bucket): TEST_BUCKET, f"{bucket_root}/{subbucket_path}", ) - generate_csv_from_parquet(TEST_BUCKET, bucket_root, subbucket_path) + powerset_merge.generate_csv_from_parquet(TEST_BUCKET, bucket_root, subbucket_path) df = awswrangler.s3.read_csv( f"s3://{TEST_BUCKET}/{bucket_root}/{subbucket_path.replace('.parquet','.csv')}" ) diff --git a/tests/site_upload/test_process_upload.py b/tests/site_upload/test_process_upload.py index 2c51b85..0b4f3e9 100644 --- a/tests/site_upload/test_process_upload.py +++ b/tests/site_upload/test_process_upload.py @@ -7,7 +7,7 @@ from src.handlers.shared.enums import BucketPath from src.handlers.shared.functions import read_metadata from src.handlers.site_upload.process_upload import process_upload_handler -from tests.utils import ( +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY, @@ -175,6 +175,7 @@ def test_process_upload( or item["Key"].startswith(BucketPath.ADMIN.value) or item["Key"].startswith(BucketPath.CACHE.value) or item["Key"].endswith("study_periods.json") + or item["Key"].endswith("column_types.json") ) if found_archive: assert "template" in upload_path diff --git a/tests/site_upload/test_study_period.py b/tests/site_upload/test_study_period.py index 61ce6fe..036dcda 100644 --- a/tests/site_upload/test_study_period.py +++ b/tests/site_upload/test_study_period.py @@ -8,7 +8,7 @@ from src.handlers.shared.enums import BucketPath from src.handlers.shared.functions import read_metadata from src.handlers.site_upload.study_period import study_period_handler -from tests.utils import ( +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY,