From d06c1af59cfab46d3921237b35e1b3ca415eb945 Mon Sep 17 00:00:00 2001
From: matt garber
Date: Tue, 13 Feb 2024 15:30:16 -0500
Subject: [PATCH] Ruff, dependency upgrades (#112)

* Ruff

* PR feedback, removed unused pylint excepts
---
 .github/workflows/ci.yaml | 24 ++-
 .pre-commit-config.yaml | 23 +--
 pyproject.toml | 46 +++---
 scripts/cumulus_upload_data.py | 2 +-
 .../migrations/migration.002.column_types.py | 86 ++++++++++
 src/handlers/dashboard/filter_config.py | 5 +-
 src/handlers/dashboard/get_chart_data.py | 15 +-
 src/handlers/dashboard/get_csv.py | 111 +++++++++++++
 src/handlers/dashboard/get_data_packages.py | 6 +-
 src/handlers/dashboard/get_metadata.py | 6 +-
 src/handlers/dashboard/get_study_periods.py | 6 +-
 src/handlers/shared/awswrangler_functions.py | 2 +-
 src/handlers/shared/decorators.py | 4 +-
 src/handlers/shared/enums.py | 19 ++-
 src/handlers/shared/functions.py | 145 ++++++++++++-----
 .../site_upload/api_gateway_authorizer.py | 17 +-
 src/handlers/site_upload/cache_api.py | 6 +-
 src/handlers/site_upload/fetch_upload_url.py | 6 +-
 src/handlers/site_upload/powerset_merge.py | 152 ++++++++++++------
 src/handlers/site_upload/process_upload.py | 86 +++++-----
 src/handlers/site_upload/study_period.py | 55 ++++---
 template.yaml | 97 +++++++++--
 tests/conftest.py | 82 ++++++----
 tests/dashboard/test_filter_config.py | 6 +-
 tests/dashboard/test_get_chart_data.py | 12 +-
 tests/dashboard/test_get_csv.py | 148 +++++++++++++++++
 tests/dashboard/test_get_metadata.py | 57 ++++---
 tests/dashboard/test_get_study_periods.py | 15 +-
 tests/dashboard/test_get_subscriptions.py | 6 +-
 tests/{utils.py => mock_utils.py} | 22 ++-
 .../test_api_gateway_authorizer.py | 17 +-
 tests/site_upload/test_cache_api.py | 4 +-
 tests/site_upload/test_fetch_upload_url.py | 7 +-
 tests/site_upload/test_powerset_merge.py | 98 ++++++-----
 tests/site_upload/test_process_upload.py | 9 +-
 tests/site_upload/test_study_period.py | 11 +-
 36 files changed, 981 insertions(+), 432 deletions(-)
 create mode 100644 scripts/migrations/migration.002.column_types.py
 create mode 100644 src/handlers/dashboard/get_csv.py
 create mode 100644 tests/dashboard/test_get_csv.py
 rename tests/{utils.py => mock_utils.py} (88%)

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 849f5e6..66a5606 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -10,7 +10,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: 3.9
+          python-version: '3.10'
       - name: Install dependencies
         run: |
@@ -23,23 +23,17 @@ jobs:
   lint:
     runs-on: ubuntu-22.04
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
       - name: Install linters
         run: |
           python -m pip install --upgrade pip
           pip install ".[dev]"
-      - name: Run pycodestyle
-        run: |
-          pycodestyle scripts src tests --max-line-length=88
-      - name: Run pylint
-        if: success() || failure() # still run pylint if above checks fail
-        run: |
-          pylint scripts src tests
-      - name: Run bandit
-        if: success() || failure() # still run bandit if above checks fail
-        run: |
-          bandit -r scripts src
-      - name: Run black
+      - name: Run ruff
        if: success() || failure() # still run black if above checks fails
        run: |
-          black --check --verbose .
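# Note on the lint consolidation above: the ruff rule sets selected later in
# pyproject.toml cover the removed pycodestyle (E), pyflakes (F), isort (I),
# and bugbear (B) checks, and `ruff format` takes over for black. Bandit's
# security checks have no direct replacement here unless the flake8-bandit
# ("S") rules are added to the select list.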
+ ruff check + ruff format --check diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 144f529..6fae921 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,10 @@ +default_install_hook_types: [pre-commit, pre-push] repos: - - repo: https://github.com/psf/black - #this version is synced with the black mentioned in .github/workflows/ci.yml - rev: 22.12.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.1 hooks: - - id: black - entry: bash -c 'black "$@"; git add -u' -- - # It is recommended to specify the latest version of Python - # supported by your project here, or alternatively use - # pre-commit's default_language_version, see - # https://pre-commit.com/#top_level-default_language_version - language_version: python3.9 - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", "--filter-files"] + - name: Ruff formatting + id: ruff-format + - name: Ruff linting + id: ruff + stages: [pre-push] diff --git a/pyproject.toml b/pyproject.toml index be210b8..c0b703e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,16 @@ [project] name = "aggregator" -requires-python = ">= 3.9" -version = "0.1.3" +requires-python = ">= 3.10" +version = "0.3.0" # This project is designed to run on the AWS serverless application framework (SAM). # The project dependencies are handled via AWS layers. These are only required for # local development. dependencies= [ "arrow >=1.2.3", - "awswrangler >=2.19.0, <3", + "awswrangler >=3.5, <4", "boto3", - "pandas >=1.5.0, <2" + "pandas >=2, <3", + "rich" ] authors = [ { name="Matt Garber", email="matthew.garber@childrens.harvard.edu" }, @@ -45,23 +46,26 @@ test = [ "pytest-mock" ] dev = [ - "bandit", - "black==22.12.0", - "isort==5.12.0", + "ruff == 0.2.1", "pre-commit", - "pylint", - "pycodestyle" ] +[tool.ruff] +target-version = "py310" -[tool.coverage.run] -command_line="-m pytest" -source=["./src/"] - -[tool.coverage.report] -show_missing=true - -[tool.isort] -profile = "black" -src_paths = ["src", "tests"] -skip_glob = [".aws_sam"] - +[tool.ruff.lint] +select = [ + "A", # prevent using keywords that clobber python builtins + "B", # bugbear: security warnings + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "ISC", # implicit string concatenation + "PLE", # pylint errors + "RUF", # the ruff developer's own rules + "UP", # alert you when better syntax is available in your python version +] +ignore = [ +# Recommended ingore from `ruff format` due to in-project conflicts with check. +# It's expected that this will be fixed in the coming months. 
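# Note on the pre-commit config above: because default_install_hook_types
# includes pre-push, the ruff-format hook runs at commit time while the ruff
# lint hook (stages: [pre-push]) only runs when pushing, mirroring the
# format/check split used in the CI lint job.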
+ "ISC001" +] diff --git a/scripts/cumulus_upload_data.py b/scripts/cumulus_upload_data.py index 6975c80..66ad8f6 100755 --- a/scripts/cumulus_upload_data.py +++ b/scripts/cumulus_upload_data.py @@ -107,7 +107,7 @@ def upload_file(cli_args): if args["test"]: args_dict["user"] = os.environ.get("CUMULUS_TEST_UPLOAD_USER", "general") args_dict["file"] = ( - f"{str(Path(__file__).resolve().parents[1])}" + f"{Path(__file__).resolve().parents[1]!s}" f"/tests/test_data/count_synthea_patient.parquet" ) args_dict["auth"] = os.environ.get("CUMULUS_TEST_UPLOAD_AUTH", "secretval") diff --git a/scripts/migrations/migration.002.column_types.py b/scripts/migrations/migration.002.column_types.py new file mode 100644 index 0000000..6651213 --- /dev/null +++ b/scripts/migrations/migration.002.column_types.py @@ -0,0 +1,86 @@ +""" Adds a new metadata type, column_types """ + +import argparse +import io +import json + +import boto3 +import pandas +from rich import progress + + +def get_csv_column_datatypes(dtypes): + """helper for generating column type for dashboard API""" + column_dict = {} + for column in dtypes.index: + if column.endswith("year"): + column_dict[column] = "year" + elif column.endswith("month"): + column_dict[column] = "month" + elif column.endswith("week"): + column_dict[column] = "week" + elif column.endswith("day") or str(dtypes[column]) == "datetime64": + column_dict[column] = "day" + elif "cnt" in column or str(dtypes[column]) in ( + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + ): + column_dict[column] = "integer" + elif str(dtypes[column]) in ("Float32", "Float64"): + column_dict[column] = "float" + elif str(dtypes[column]) == "boolean": + column_dict[column] = "float" + else: + column_dict[column] = "string" + return column_dict + + +def _put_s3_data(key: str, bucket_name: str, client, data: dict) -> None: + """Convenience class for writing a dict to S3""" + b_data = io.BytesIO(json.dumps(data).encode()) + client.upload_fileobj(Bucket=bucket_name, Key=key, Fileobj=b_data) + + +def create_column_type_metadata(bucket: str): + """creates a new metadata dict for column types. + + By design, this will replaces an existing column type dict if one already exists. + """ + client = boto3.client("s3") + res = client.list_objects_v2(Bucket=bucket, Prefix="aggregates/") + contents = res["Contents"] + output = {} + for resource in progress.track(contents): + dirs = resource["Key"].split("/") + study = dirs[1] + subscription = dirs[2].split("__")[1] + version = dirs[3] + bytes_buffer = io.BytesIO() + client.download_fileobj( + Bucket=bucket, Key=resource["Key"], Fileobj=bytes_buffer + ) + df = pandas.read_parquet(bytes_buffer) + type_dict = get_csv_column_datatypes(df.dtypes) + filename = f"{resource['Key'].split('/')[-1].split('.')[0]}.csv" + output.setdefault(study, {}) + output[study].setdefault(subscription, {}) + output[study][subscription].setdefault(version, {}) + output[study][subscription][version]["columns"] = type_dict + output[study][subscription][version]["filename"] = filename + # print(json.dumps(output, indent=2)) + _put_s3_data("metadata/column_types.json", bucket, client, output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Creates column types for existing aggregates. 
""" + ) + parser.add_argument("-b", "--bucket", help="bucket name") + args = parser.parse_args() + create_column_type_metadata(args.bucket) diff --git a/src/handlers/dashboard/filter_config.py b/src/handlers/dashboard/filter_config.py index beb8a26..27572b0 100644 --- a/src/handlers/dashboard/filter_config.py +++ b/src/handlers/dashboard/filter_config.py @@ -73,10 +73,7 @@ def _parse_filter_req(filter_req): if "," in filter_req: return " AND ".join(_parse_filter_req(x) for x in filter_req.split(",")) filter_req_split = filter_req.split(":") - if ( - filter_req_split[1] - in _FILTER_MAP_ONE_PARAM.keys() # pylint: disable=consider-iterating-dictionary - ): + if filter_req_split[1] in _FILTER_MAP_ONE_PARAM: return _FILTER_MAP_ONE_PARAM[filter_req_split[1]] % filter_req_split[0] return _FILTER_MAP_TWO_PARAM[filter_req_split[1]] % ( filter_req_split[0], diff --git a/src/handlers/dashboard/get_chart_data.py b/src/handlers/dashboard/get_chart_data.py index b1ae477..db89b1c 100644 --- a/src/handlers/dashboard/get_chart_data.py +++ b/src/handlers/dashboard/get_chart_data.py @@ -7,13 +7,13 @@ import boto3 import pandas -from ..dashboard.filter_config import get_filter_string -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath -from ..shared.functions import get_latest_data_package_version, http_response +from src.handlers.dashboard.filter_config import get_filter_string +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath +from src.handlers.shared.functions import get_latest_data_package_version, http_response -def _get_table_cols(table_name: str, version: str = None) -> list: +def _get_table_cols(table_name: str, version: str | None = None) -> list: """Returns the columns associated with a table. 
Since running an athena query takes a decent amount of time due to queueing @@ -29,7 +29,8 @@ def _get_table_cols(table_name: str, version: str = None) -> list: s3_key = f"{prefix}/{version}/{table_name}__aggregate.csv" s3_client = boto3.client("s3") s3_iter = s3_client.get_object( - Bucket=s3_bucket_name, Key=s3_key # type: ignore[arg-type] + Bucket=s3_bucket_name, + Key=s3_key, )["Body"].iter_lines() return next(s3_iter).decode().split(",") @@ -41,7 +42,7 @@ def _build_query(query_params: dict, filters: list, path_params: dict) -> str: filter_str = get_filter_string(filters) if filter_str != "": filter_str = f"AND {filter_str}" - count_col = [c for c in columns if c.startswith("cnt")][0] + count_col = next(c for c in columns if c.startswith("cnt")) columns.remove(count_col) select_str = f"{query_params['column']}, sum({count_col}) as {count_col}" group_str = f"{query_params['column']}" diff --git a/src/handlers/dashboard/get_csv.py b/src/handlers/dashboard/get_csv.py new file mode 100644 index 0000000..4eb52ea --- /dev/null +++ b/src/handlers/dashboard/get_csv.py @@ -0,0 +1,111 @@ +import os + +import boto3 +import botocore + +from src.handlers.shared import decorators, enums, functions + + +def _format_key( + s3_client, + s3_bucket_name: str, + study: str, + subscription: str, + version: str, + filename: str, + site: str | None = None, +): + """Creates S3 key from url params""" + if site is not None: + key = f"last_valid/{study}/{study}__{subscription}/{site}/{version}/{filename}" + else: + key = f"csv_aggregates/{study}/{study}__{subscription}/{version}/{filename}" + s3_client.list_objects_v2(Bucket=s3_bucket_name) + try: + s3_client.head_object(Bucket=s3_bucket_name, Key=key) + return key + except botocore.exceptions.ClientError as e: + raise OSError(f"No object found at key {key}") from e + + +def _get_column_types( + s3_client, + s3_bucket_name: str, + study: str, + subscription: str, + version: str, + **kwargs, +) -> dict: + """Gets column types from the metadata store for a given subscription""" + types_metadata = functions.read_metadata( + s3_client, + s3_bucket_name, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + try: + return types_metadata[study][subscription][version][ + enums.ColumnTypesKeys.COLUMNS.value + ] + except KeyError: + return {} + + +@decorators.generic_error_handler(msg="Error retrieving chart data") +def get_csv_handler(event, context): + """manages event from dashboard api call and creates a temporary URL""" + del context + s3_bucket_name = os.environ.get("BUCKET_NAME") + s3_client = boto3.client("s3") + key = _format_key(s3_client, s3_bucket_name, **event["pathParameters"]) + types = _get_column_types(s3_client, s3_bucket_name, **event["pathParameters"]) + presign_url = s3_client.generate_presigned_url( + "get_object", + Params={ + "Bucket": s3_bucket_name, + "Key": key, + "ResponseContentType": "text/csv", + }, + ExpiresIn=600, + ) + extra_headers = { + "Location": presign_url, + "x-column-names": ",".join(key for key in types.keys()), + "x-column-types": ",".join(key for key in types.values()), + # TODO: add x-column-descriptions once a source for column descriptions + # has been established + } + res = functions.http_response(302, "", extra_headers=extra_headers) + return res + + +@decorators.generic_error_handler(msg="Error retrieving csv data") +def get_csv_list_handler(event, context): + """manages event from dashboard api call and creates a temporary URL""" + del context + s3_bucket_name = os.environ.get("BUCKET_NAME") + s3_client = 
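# Illustrative sketch (editorial, not part of the patch): with BUCKET_NAME set
# and an aggregate CSV present at the expected key, get_csv_handler above
# answers with a 302 redirect to a presigned S3 URL plus column metadata
# headers. The event shape mirrors the pathParameters consumed by _format_key
# and _get_column_types; the concrete study/subscription/version values here
# are hypothetical.
example_event = {
    "pathParameters": {
        "study": "core",
        "subscription": "patient",
        "version": "099",
        "filename": "core__patient__aggregate.csv",
    }
}
response = get_csv_handler(example_event, context=None)
assert response["statusCode"] == 302
presigned_url = response["headers"]["Location"]
column_names = response["headers"]["x-column-names"].split(",")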
boto3.client("s3") + if event["path"].startswith("/last_valid"): + key_prefix = "last_valid" + url_prefix = "last_valid" + elif event["path"].startswith("/aggregates"): + key_prefix = "csv_aggregates" + url_prefix = "aggregates" + else: + raise Exception("Unexpected url encountered") + s3_objs = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=key_prefix) + urls = [] + if s3_objs["KeyCount"] == 0: + return functions.http_response(200, urls) + for obj in s3_objs["Contents"]: + key_parts = obj["Key"].split("/") + study = key_parts[1] + subscription = key_parts[2].split("__")[1] + version = key_parts[-2] + filename = key_parts[-1] + site = key_parts[3] if url_prefix == "last_valid" else None + url_parts = [url_prefix, study, subscription, version, filename] + if url_prefix == "last_valid": + url_parts.insert(3, site) + urls.append("/".join(url_parts)) + res = functions.http_response(200, urls) + return res diff --git a/src/handlers/dashboard/get_data_packages.py b/src/handlers/dashboard/get_data_packages.py index c3c349b..423425b 100644 --- a/src/handlers/dashboard/get_data_packages.py +++ b/src/handlers/dashboard/get_data_packages.py @@ -3,9 +3,9 @@ import os -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, JsonFilename -from ..shared.functions import get_s3_json_as_dict, http_response +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath, JsonFilename +from src.handlers.shared.functions import get_s3_json_as_dict, http_response @generic_error_handler(msg="Error retrieving data packages") diff --git a/src/handlers/dashboard/get_metadata.py b/src/handlers/dashboard/get_metadata.py index fa6c2d0..52600ee 100644 --- a/src/handlers/dashboard/get_metadata.py +++ b/src/handlers/dashboard/get_metadata.py @@ -4,8 +4,8 @@ import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.functions import http_response, read_metadata +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.functions import http_response, read_metadata @generic_error_handler(msg="Error retrieving metadata") @@ -22,5 +22,7 @@ def metadata_handler(event, context): metadata = metadata[params["study"]] if "data_package" in params: metadata = metadata[params["data_package"]] + if "version" in params: + metadata = metadata[params["version"]] res = http_response(200, metadata) return res diff --git a/src/handlers/dashboard/get_study_periods.py b/src/handlers/dashboard/get_study_periods.py index c275891..89b2100 100644 --- a/src/handlers/dashboard/get_study_periods.py +++ b/src/handlers/dashboard/get_study_periods.py @@ -4,9 +4,9 @@ import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.enums import JsonFilename -from ..shared.functions import http_response, read_metadata +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import JsonFilename +from src.handlers.shared.functions import http_response, read_metadata @generic_error_handler(msg="Error retrieving study period") diff --git a/src/handlers/shared/awswrangler_functions.py b/src/handlers/shared/awswrangler_functions.py index e64ca54..52ec1f2 100644 --- a/src/handlers/shared/awswrangler_functions.py +++ b/src/handlers/shared/awswrangler_functions.py @@ -1,7 +1,7 @@ """ functions specifically requiring AWSWranger, which requires a lambda layer""" import awswrangler -from .enums import BucketPath +from src.handlers.shared.enums import 
BucketPath def get_s3_data_package_list( diff --git a/src/handlers/shared/decorators.py b/src/handlers/shared/decorators.py index 5a88b41..36c1e1d 100644 --- a/src/handlers/shared/decorators.py +++ b/src/handlers/shared/decorators.py @@ -3,7 +3,7 @@ import functools import logging -from .functions import http_response +from src.handlers.shared.functions import http_response def generic_error_handler(msg="Internal server error"): @@ -14,7 +14,7 @@ def error_decorator(func): def wrapper(*args, **kwargs): try: res = func(*args, **kwargs) - except Exception as e: # pylint: disable=broad-except + except Exception as e: trace = [] tb = e.__traceback__ while tb is not None: diff --git a/src/handlers/shared/enums.py b/src/handlers/shared/enums.py index 1084d51..65dae2d 100644 --- a/src/handlers/shared/enums.py +++ b/src/handlers/shared/enums.py @@ -1,8 +1,8 @@ """Enums shared across lambda functions""" -from enum import Enum +import enum -class BucketPath(Enum): +class BucketPath(enum.Enum): """stores root level buckets for managing data processing state""" ADMIN = "admin" @@ -18,15 +18,24 @@ class BucketPath(Enum): UPLOAD = "site_upload" -class JsonFilename(Enum): +class ColumnTypesKeys(enum.Enum): + """stores names of expected keys in the study period metadata dictionary""" + + COLUMN_TYPES_FORMAT_VERSION = "column_types_format_version" + COLUMNS = "columns" + LAST_DATA_UPDATE = "last_data_update" + + +class JsonFilename(enum.Enum): """stores names of expected kinds of persisted S3 JSON files""" + COLUMN_TYPES = "column_types" TRANSACTIONS = "transactions" DATA_PACKAGES = "data_packages" STUDY_PERIODS = "study_periods" -class TransactionKeys(Enum): +class TransactionKeys(enum.Enum): """stores names of expected keys in the transaction dictionary""" TRANSACTION_FORMAT_VERSION = "transaction_format_version" @@ -37,7 +46,7 @@ class TransactionKeys(Enum): DELETED = "deleted" -class StudyPeriodMetadataKeys(Enum): +class StudyPeriodMetadataKeys(enum.Enum): """stores names of expected keys in the study period metadata dictionary""" STUDY_PERIOD_FORMAT_VERSION = "study_period_format_version" diff --git a/src/handlers/shared/functions.py b/src/handlers/shared/functions.py index 2309ef8..3423402 100644 --- a/src/handlers/shared/functions.py +++ b/src/handlers/shared/functions.py @@ -3,30 +3,37 @@ import json import logging from datetime import datetime, timezone -from typing import Optional import boto3 -from .enums import BucketPath, JsonFilename, StudyPeriodMetadataKeys, TransactionKeys +from src.handlers.shared import enums TRANSACTION_METADATA_TEMPLATE = { - TransactionKeys.TRANSACTION_FORMAT_VERSION.value: "2", - TransactionKeys.LAST_UPLOAD.value: None, - TransactionKeys.LAST_DATA_UPDATE.value: None, - TransactionKeys.LAST_AGGREGATION.value: None, - TransactionKeys.LAST_ERROR.value: None, - TransactionKeys.DELETED.value: None, + enums.TransactionKeys.TRANSACTION_FORMAT_VERSION.value: "2", + enums.TransactionKeys.LAST_UPLOAD.value: None, + enums.TransactionKeys.LAST_DATA_UPDATE.value: None, + enums.TransactionKeys.LAST_AGGREGATION.value: None, + enums.TransactionKeys.LAST_ERROR.value: None, + enums.TransactionKeys.DELETED.value: None, } STUDY_PERIOD_METADATA_TEMPLATE = { - StudyPeriodMetadataKeys.STUDY_PERIOD_FORMAT_VERSION.value: "2", - StudyPeriodMetadataKeys.EARLIEST_DATE.value: None, - StudyPeriodMetadataKeys.LATEST_DATE.value: None, - StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value: None, + enums.StudyPeriodMetadataKeys.STUDY_PERIOD_FORMAT_VERSION.value: "2", + 
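# Note on the metadata documents these templates seed: transactions entries
# are keyed site -> study -> data_package -> version, study_periods entries
# site -> study -> version, and the new column_types document is keyed
# study -> data_package -> version, holding the column name/type mapping plus
# a format version and last-update timestamp.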
enums.StudyPeriodMetadataKeys.EARLIEST_DATE.value: None, + enums.StudyPeriodMetadataKeys.LATEST_DATE.value: None, + enums.StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value: None, +} + +COLUMN_TYPES_METADATA_TEMPLATE = { + enums.ColumnTypesKeys.COLUMN_TYPES_FORMAT_VERSION.value: "1", + enums.ColumnTypesKeys.COLUMNS.value: None, + enums.ColumnTypesKeys.LAST_DATA_UPDATE.value: None, } -def http_response(status: int, body: str, allow_cors: bool = False) -> dict: +def http_response( + status: int, body: str, allow_cors: bool = False, extra_headers: dict | None = None +) -> dict: """Generates the payload AWS lambda expects as a return value""" headers = {"Content-Type": "application/json"} if allow_cors: @@ -37,6 +44,8 @@ def http_response(status: int, body: str, allow_cors: bool = False) -> dict: "Access-Control-Allow-Methods": "GET", } ) + if extra_headers: + headers.update(extra_headers) return { "isBase64Encoded": False, "statusCode": status, @@ -50,17 +59,21 @@ def http_response(status: int, body: str, allow_cors: bool = False) -> dict: def check_meta_type(meta_type: str) -> None: """helper for ensuring specified metadata types""" - types = [item.value for item in JsonFilename] + types = [item.value for item in enums.JsonFilename] if meta_type not in types: raise ValueError("invalid metadata type specified") def read_metadata( - s3_client, s3_bucket_name: str, meta_type: str = JsonFilename.TRANSACTIONS.value + s3_client, + s3_bucket_name: str, + *, + meta_type: str = enums.JsonFilename.TRANSACTIONS.value, ) -> dict: """Reads transaction information from an s3 bucket as a dictionary""" + print(s3_bucket_name) check_meta_type(meta_type) - s3_path = f"{BucketPath.META.value}/{meta_type}.json" + s3_path = f"{enums.BucketPath.META.value}/{meta_type}.json" res = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=s3_path) if "Contents" in res: res = s3_client.get_object(Bucket=s3_bucket_name, Key=s3_path) @@ -71,53 +84,73 @@ def read_metadata( def update_metadata( + *, metadata: dict, - site: str, study: str, data_package: str, version: str, target: str, - dt: Optional[datetime] = None, - meta_type: str = JsonFilename.TRANSACTIONS.value, + site: str | None = None, + dt: datetime | None = None, + value: str | list | None = None, + meta_type: str | None = enums.JsonFilename.TRANSACTIONS.value, ): """Safely updates items in metadata dictionary - It's assumed that, other than the version field itself, every item in one - of these metadata dicts is a datetime corresponding to an S3 event timestamp + It's assumed that, other than the version/column/type fields, every item in one + of these metadata dicts is a ISO date string corresponding to an S3 event timestamp """ check_meta_type(meta_type) - if meta_type == JsonFilename.TRANSACTIONS.value: - site_metadata = metadata.setdefault(site, {}) - study_metadata = site_metadata.setdefault(study, {}) - data_package_metadata = study_metadata.setdefault(data_package, {}) - data_version_metadata = data_package_metadata.setdefault( - version, TRANSACTION_METADATA_TEMPLATE - ) - dt = dt or datetime.now(timezone.utc) - data_version_metadata[target] = dt.isoformat() - elif meta_type == JsonFilename.STUDY_PERIODS.value: - site_metadata = metadata.setdefault(site, {}) - study_period_metadata = site_metadata.setdefault(study, {}) - data_version_metadata = study_period_metadata.setdefault( - version, STUDY_PERIOD_METADATA_TEMPLATE - ) - dt = dt or datetime.now(timezone.utc) - data_version_metadata[target] = dt.isoformat() + match meta_type: + case 
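# Illustrative call (editorial note): update_metadata is now keyword-only and
# dispatches on meta_type below. A column-types update, as issued from
# powerset_merge, looks roughly like this (study/package/version values are
# hypothetical; site can be omitted for this metadata type):
#   metadata = update_metadata(
#       metadata=types_metadata, study="core", data_package="patient",
#       version="099", target=enums.ColumnTypesKeys.COLUMNS.value,
#       value={"cnt": "integer"},
#       meta_type=enums.JsonFilename.COLUMN_TYPES.value,
#   )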
enums.JsonFilename.TRANSACTIONS.value: + site_metadata = metadata.setdefault(site, {}) + study_metadata = site_metadata.setdefault(study, {}) + data_package_metadata = study_metadata.setdefault(data_package, {}) + data_version_metadata = data_package_metadata.setdefault( + version, TRANSACTION_METADATA_TEMPLATE + ) + dt = dt or datetime.now(timezone.utc) + data_version_metadata[target] = dt.isoformat() + case enums.JsonFilename.STUDY_PERIODS.value: + site_metadata = metadata.setdefault(site, {}) + study_period_metadata = site_metadata.setdefault(study, {}) + data_version_metadata = study_period_metadata.setdefault( + version, STUDY_PERIOD_METADATA_TEMPLATE + ) + dt = dt or datetime.now(timezone.utc) + data_version_metadata[target] = dt.isoformat() + case enums.JsonFilename.COLUMN_TYPES.value: + study_metadata = metadata.setdefault(study, {}) + data_package_metadata = study_metadata.setdefault(data_package, {}) + data_version_metadata = data_package_metadata.setdefault( + version, COLUMN_TYPES_METADATA_TEMPLATE + ) + if target == enums.ColumnTypesKeys.COLUMNS.value: + data_version_metadata[target] = value + else: + dt = dt or datetime.now(timezone.utc) + data_version_metadata[target] = dt.isoformat() + # Should only be hit if you add a new JSON dict and forget to add it + # to this function + case _: + raise OSError(f"{meta_type} does not have a handler for updates.") return metadata def write_metadata( + *, s3_client, s3_bucket_name: str, metadata: dict, - meta_type: str = JsonFilename.TRANSACTIONS.value, + meta_type: str = enums.JsonFilename.TRANSACTIONS.value, ) -> None: """Writes transaction info from ∏a dictionary to an s3 bucket metadata location""" check_meta_type(meta_type) + s3_client.put_object( Bucket=s3_bucket_name, - Key=f"{BucketPath.META.value}/{meta_type}.json", + Key=f"{enums.BucketPath.META.value}/{meta_type}.json", Body=json.dumps(metadata), ) @@ -182,3 +215,35 @@ def get_latest_data_package_version(bucket, prefix): if highest_ver is None: logging.error("No data package versions found for %s", prefix) return highest_ver + + +def get_csv_column_datatypes(dtypes): + """helper for generating column type for dashboard API""" + column_dict = {} + for column in dtypes.index: + if column.endswith("year"): + column_dict[column] = "year" + elif column.endswith("month"): + column_dict[column] = "month" + elif column.endswith("week"): + column_dict[column] = "week" + elif column.endswith("day") or str(dtypes[column]) == "datetime64": + column_dict[column] = "day" + elif "cnt" in column or str(dtypes[column]) in ( + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + ): + column_dict[column] = "integer" + elif str(dtypes[column]) in ("Float32", "Float64"): + column_dict[column] = "float" + elif str(dtypes[column]) == "boolean": + column_dict[column] = "float" + else: + column_dict[column] = "string" + return column_dict diff --git a/src/handlers/site_upload/api_gateway_authorizer.py b/src/handlers/site_upload/api_gateway_authorizer.py index aee9a4c..76840f4 100644 --- a/src/handlers/site_upload/api_gateway_authorizer.py +++ b/src/handlers/site_upload/api_gateway_authorizer.py @@ -3,13 +3,12 @@ """ # pylint: disable=invalid-name,pointless-string-statement -from __future__ import print_function import os import re -from ..shared.enums import BucketPath -from ..shared.functions import get_s3_json_as_dict +from src.handlers.shared.enums import BucketPath +from src.handlers.shared.functions import get_s3_json_as_dict class AuthError(Exception): @@ 
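# Illustrative sketch (editorial, not part of the patch): assuming the
# get_csv_column_datatypes helper defined above is in scope, it maps pandas
# dtypes to the type strings the dashboard CSV API expects, keying off column
# name suffixes and the count-column naming convention. The hypothetical
# dataframe below exercises three of the branches.
import pandas

_example = pandas.DataFrame(
    {"cnt": [10, 5], "gender": ["female", "male"], "birth_year": ["2010", "2015"]}
)
# "cnt" matches the count-column branch -> "integer"; "birth_year" ends with
# "year" -> "year"; "gender" falls through to the default -> "string".
assert get_csv_column_datatypes(_example.dtypes) == {
    "cnt": "integer",
    "gender": "string",
    "birth_year": "year",
}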
-28,7 +27,7 @@ def lambda_handler(event, context): if auth_token not in user_db.keys() or auth_header[0] != "Basic": raise AuthError except (AuthError, KeyError): - raise AuthError(event) # pylint: disable=raise-missing-from + raise AuthError(event) # noqa: B904 principalId = user_db[auth_token]["site"] @@ -66,7 +65,7 @@ class HttpVerb: ALL = "*" -class AuthPolicy(object): # pylint: disable=missing-class-docstring; # pragma: no cover +class AuthPolicy: awsAccountId = "" """The AWS account id the policy will be generated for. This is used to create the method ARNs.""" @@ -81,8 +80,8 @@ class AuthPolicy(object): # pylint: disable=missing-class-docstring; # pragma: conditions statement. the build method processes these lists and generates the approriate statements for the final policy""" - allowMethods = [] - denyMethods = [] + allowMethods = [] # noqa: RUF012 + denyMethods = [] # noqa: RUF012 restApiId = "<>" """ Replace the placeholder value with a default API Gateway API id to be used in @@ -211,7 +210,7 @@ def allowMethodWithConditions(self, verb, resource, conditions): methods and includes a condition for the policy statement. More on AWS policy conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition - """ # noqa: E501 + """ self._addMethod("Allow", verb, resource, conditions) def denyMethodWithConditions(self, verb, resource, conditions): @@ -219,7 +218,7 @@ def denyMethodWithConditions(self, verb, resource, conditions): methods and includes a condition for the policy statement. More on AWS policy conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition - """ # noqa: E501 + """ self._addMethod("Deny", verb, resource, conditions) def build(self): diff --git a/src/handlers/site_upload/cache_api.py b/src/handlers/site_upload/cache_api.py index f42a27f..362c4fe 100644 --- a/src/handlers/site_upload/cache_api.py +++ b/src/handlers/site_upload/cache_api.py @@ -6,9 +6,9 @@ import awswrangler import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, JsonFilename -from ..shared.functions import http_response +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath, JsonFilename +from src.handlers.shared.functions import http_response def cache_api_data(s3_client, s3_bucket_name: str, db: str, target: str) -> None: diff --git a/src/handlers/site_upload/fetch_upload_url.py b/src/handlers/site_upload/fetch_upload_url.py index d532cde..a1d95d6 100644 --- a/src/handlers/site_upload/fetch_upload_url.py +++ b/src/handlers/site_upload/fetch_upload_url.py @@ -6,9 +6,9 @@ import boto3 import botocore.exceptions -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath -from ..shared.functions import get_s3_json_as_dict, http_response +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import BucketPath +from src.handlers.shared.functions import get_s3_json_as_dict, http_response def create_presigned_post( diff --git a/src/handlers/site_upload/powerset_merge.py b/src/handlers/site_upload/powerset_merge.py index e65a41f..3ce861f 100644 --- a/src/handlers/site_upload/powerset_merge.py +++ b/src/handlers/site_upload/powerset_merge.py @@ -1,27 +1,17 @@ """ Lambda for performing joins of site count data """ import csv +import datetime import logging import os import traceback -from datetime import datetime, timezone import 
awswrangler import boto3 +import numpy import pandas -from numpy import nan from pandas.core.indexes.range import RangeIndex -from ..shared.awswrangler_functions import get_s3_data_package_list -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, TransactionKeys -from ..shared.functions import ( - get_s3_site_filename_suffix, - http_response, - move_s3_file, - read_metadata, - update_metadata, - write_metadata, -) +from src.handlers.shared import awswrangler_functions, decorators, enums, functions class MergeError(ValueError): @@ -46,18 +36,23 @@ def __init__(self, event): self.data_package = s3_key_array[2].split("__")[1] self.site = s3_key_array[3] self.version = s3_key_array[4] - self.metadata = read_metadata(self.s3_client, self.s3_bucket_name) + self.metadata = functions.read_metadata(self.s3_client, self.s3_bucket_name) + self.types_metadata = functions.read_metadata( + self.s3_client, + self.s3_bucket_name, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) # S3 Filesystem operations def get_data_package_list(self, path) -> list: """convenience wrapper for get_s3_data_package_list""" - return get_s3_data_package_list( + return awswrangler_functions.get_s3_data_package_list( path, self.s3_bucket_name, self.study, self.data_package ) def move_file(self, from_path: str, to_path: str) -> None: """convenience wrapper for move_s3_file""" - move_s3_file(self.s3_client, self.s3_bucket_name, from_path, to_path) + functions.move_s3_file(self.s3_client, self.s3_bucket_name, from_path, to_path) def copy_file(self, from_path: str, to_path: str) -> None: """convenience wrapper for copy_s3_file""" @@ -75,7 +70,7 @@ def copy_file(self, from_path: str, to_path: str) -> None: def write_parquet(self, df: pandas.DataFrame, is_new_data_package: bool) -> None: """writes dataframe as parquet to s3 and sends an SNS notification if new""" parquet_aggregate_path = ( - f"s3://{self.s3_bucket_name}/{BucketPath.AGGREGATE.value}/" + f"s3://{self.s3_bucket_name}/{enums.BucketPath.AGGREGATE.value}/" f"{self.study}/{self.study}__{self.data_package}/{self.version}/" f"{self.study}__{self.data_package}__aggregate.parquet" ) @@ -89,12 +84,12 @@ def write_parquet(self, df: pandas.DataFrame, is_new_data_package: bool) -> None def write_csv(self, df: pandas.DataFrame) -> None: """writes dataframe as csv to s3""" csv_aggregate_path = ( - f"s3://{self.s3_bucket_name}/{BucketPath.CSVAGGREGATE.value}/" + f"s3://{self.s3_bucket_name}/{enums.BucketPath.CSVAGGREGATE.value}/" f"{self.study}/{self.study}__{self.data_package}/{self.version}/" f"{self.study}__{self.data_package}__aggregate.csv" ) df = df.apply(lambda x: x.strip() if isinstance(x, str) else x).replace( - '""', nan + '""', numpy.nan ) df = df.replace(to_replace=r",", value="", regex=True) awswrangler.s3.to_csv( @@ -102,17 +97,45 @@ def write_csv(self, df: pandas.DataFrame) -> None: ) # metadata - def update_local_metadata(self, key, site=None): + def update_local_metadata( + self, + key, + *, + site=None, + value=None, + metadata: dict | None = None, + meta_type: str | None = enums.JsonFilename.TRANSACTIONS.value, + ): """convenience wrapper for update_metadata""" - if site is None: + if site is None and meta_type != enums.JsonFilename.COLUMN_TYPES.value: site = self.site - self.metadata = update_metadata( - self.metadata, site, self.study, self.data_package, self.version, key + if metadata is None: + metadata = self.metadata + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=self.study, + 
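# Note on update_local_metadata above: it now routes writes to either the
# transaction metadata or the new column-types metadata via meta_type, and a
# site is only filled in automatically for the transaction-style documents,
# since column_types entries are keyed by study/data_package/version alone.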
data_package=self.data_package, + version=self.version, + target=key, + value=value, + meta_type=meta_type, ) - def write_local_metadata(self): + def write_local_metadata( + self, metadata: dict | None = None, meta_type: str | None = None + ): """convenience wrapper for write_metadata""" - write_metadata(self.s3_client, self.s3_bucket_name, self.metadata) + if metadata is None: + metadata = self.metadata + if meta_type is None: + meta_type = enums.JsonFilename.TRANSACTIONS.value + functions.write_metadata( + s3_client=self.s3_client, + s3_bucket_name=self.s3_bucket_name, + metadata=metadata, + meta_type=meta_type, + ) def merge_error_handler( self, @@ -128,9 +151,9 @@ def merge_error_handler( logger.error(traceback.print_exc()) self.move_file( s3_path.replace(f"s3://{self.s3_bucket_name}/", ""), - f"{BucketPath.ERROR.value}/{subbucket_path}", + f"{enums.BucketPath.ERROR.value}/{subbucket_path}", ) - self.update_local_metadata(TransactionKeys.LAST_ERROR.value) + self.update_local_metadata(enums.TransactionKeys.LAST_ERROR.value) def get_static_string_series(static_str: str, index: RangeIndex) -> pandas.Series: @@ -189,7 +212,7 @@ def expand_and_concat_sets( .reset_index() # this last line makes "cnt" the first column in the set, matching the # library style - .filter(["cnt"] + data_cols) + .filter(["cnt", *data_cols]) ) return agg_df @@ -203,14 +226,15 @@ def generate_csv_from_parquet(bucket_name: str, bucket_root: str, subbucket_path ) last_valid_df = last_valid_df.apply( lambda x: x.strip() if isinstance(x, str) else x - ).replace('""', nan) + ).replace('""', numpy.nan) # Here we are removing internal commas from fields so we get a valid unquoted CSV last_valid_df = last_valid_df.replace(to_replace=",", value="", regex=True) awswrangler.s3.to_csv( last_valid_df, ( - f"s3://{bucket_name}/{bucket_root}/" - f"{subbucket_path}".replace(".parquet", ".csv") + f"s3://{bucket_name}/{bucket_root}/{subbucket_path}".replace( + ".parquet", ".csv" + ) ), index=False, quoting=csv.QUOTE_NONE, @@ -225,12 +249,14 @@ def merge_powersets(manager: S3Manager) -> None: # initializing this early in case an empty file causes us to never set it is_new_data_package = False df = pandas.DataFrame() - latest_file_list = manager.get_data_package_list(BucketPath.LATEST.value) - last_valid_file_list = manager.get_data_package_list(BucketPath.LAST_VALID.value) + latest_file_list = manager.get_data_package_list(enums.BucketPath.LATEST.value) + last_valid_file_list = manager.get_data_package_list( + enums.BucketPath.LAST_VALID.value + ) for last_valid_path in last_valid_file_list: if manager.version not in last_valid_path: continue - site_specific_name = get_s3_site_filename_suffix(last_valid_path) + site_specific_name = functions.get_s3_site_filename_suffix(last_valid_path) subbucket_path = f"{manager.study}/{manager.data_package}/{site_specific_name}" last_valid_site = site_specific_name.split("/", maxsplit=1)[0] # If the latest uploads don't include this site, we'll use the last-valid @@ -239,7 +265,7 @@ def merge_powersets(manager: S3Manager) -> None: if not any(x.endswith(site_specific_name) for x in latest_file_list): df = expand_and_concat_sets(df, last_valid_path, last_valid_site) manager.update_local_metadata( - TransactionKeys.LAST_AGGREGATION.value, site=last_valid_site + enums.TransactionKeys.LAST_AGGREGATION.value, site=last_valid_site ) except MergeError as e: # This is expected to trigger if there's an issue in expand_and_concat_sets; @@ -250,15 +276,14 @@ def merge_powersets(manager: S3Manager) -> None: e, ) 
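# Note on merge_powersets: the loop above folds in last_valid data for any
# site that did not ship a new upload this round, and the loop that follows
# processes the fresh uploads: archiving the previous last_valid copy,
# promoting the new file, regenerating its CSV, and falling back to the
# existing last_valid data if the new file fails to merge.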
for latest_path in latest_file_list: - if manager.version not in latest_path: continue - site_specific_name = get_s3_site_filename_suffix(latest_path) + site_specific_name = functions.get_s3_site_filename_suffix(latest_path) subbucket_path = ( f"{manager.study}/{manager.study}__{manager.data_package}" f"/{site_specific_name}" ) - date_str = datetime.now(timezone.utc).isoformat() + date_str = datetime.datetime.now(datetime.timezone.utc).isoformat() timestamped_name = f".{date_str}.".join(site_specific_name.split(".")) timestamped_path = ( f"{manager.study}/{manager.study}__{manager.data_package}" @@ -269,8 +294,8 @@ def merge_powersets(manager: S3Manager) -> None: # if we're going to replace a file in last_valid, archive the old data if any(x.endswith(site_specific_name) for x in last_valid_file_list): manager.copy_file( - f"{BucketPath.LAST_VALID.value}/{subbucket_path}", - f"{BucketPath.ARCHIVE.value}/{timestamped_path}", + f"{enums.BucketPath.LAST_VALID.value}/{subbucket_path}", + f"{enums.BucketPath.ARCHIVE.value}/{timestamped_path}", ) # otherwise, this is the first instance - after it's in the database, # we'll generate a new list of valid tables for the dashboard @@ -278,8 +303,8 @@ def merge_powersets(manager: S3Manager) -> None: is_new_data_package = True df = expand_and_concat_sets(df, latest_path, manager.site) manager.move_file( - f"{BucketPath.LATEST.value}/{subbucket_path}", - f"{BucketPath.LAST_VALID.value}/{subbucket_path}", + f"{enums.BucketPath.LATEST.value}/{subbucket_path}", + f"{enums.BucketPath.LAST_VALID.value}/{subbucket_path}", ) #################### @@ -288,18 +313,20 @@ def merge_powersets(manager: S3Manager) -> None: # TODO: remove as soon as we support either parquet upload or # the API is supported by the dashboard generate_csv_from_parquet( - manager.s3_bucket_name, BucketPath.LAST_VALID.value, subbucket_path + manager.s3_bucket_name, + enums.BucketPath.LAST_VALID.value, + subbucket_path, ) #################### latest_site = site_specific_name.split("/", maxsplit=1)[0] manager.update_local_metadata( - TransactionKeys.LAST_DATA_UPDATE.value, site=latest_site + enums.TransactionKeys.LAST_DATA_UPDATE.value, site=latest_site ) manager.update_local_metadata( - TransactionKeys.LAST_AGGREGATION.value, site=latest_site + enums.TransactionKeys.LAST_AGGREGATION.value, site=latest_site ) - except Exception as e: # pylint: disable=broad-except + except Exception as e: manager.merge_error_handler( latest_path, subbucket_path, @@ -310,13 +337,38 @@ def merge_powersets(manager: S3Manager) -> None: if any(x.endswith(site_specific_name) for x in last_valid_file_list): df = expand_and_concat_sets( df, - f"s3://{manager.s3_bucket_name}/{BucketPath.LAST_VALID.value}" + f"s3://{manager.s3_bucket_name}/{enums.BucketPath.LAST_VALID.value}" f"/{subbucket_path}", manager.site, ) - manager.update_local_metadata(TransactionKeys.LAST_AGGREGATION.value) + manager.update_local_metadata( + enums.TransactionKeys.LAST_AGGREGATION.value + ) + + if df.empty: + raise OSError("File not found") + manager.write_local_metadata() + # Updating the typing dict for the CSV API + column_dict = functions.get_csv_column_datatypes(df.dtypes) + manager.update_local_metadata( + enums.ColumnTypesKeys.COLUMNS.value, + value=column_dict, + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + manager.update_local_metadata( + enums.ColumnTypesKeys.LAST_DATA_UPDATE.value, + value=column_dict, + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + 
) + manager.write_local_metadata( + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + # In this section, we are trying to accomplish two things: # - Prepare a csv that can be loaded manually into the dashboard (requiring no # quotes, which means removing commas from strings) @@ -326,11 +378,11 @@ def merge_powersets(manager: S3Manager) -> None: manager.write_csv(df) -@generic_error_handler(msg="Error merging powersets") +@decorators.generic_error_handler(msg="Error merging powersets") def powerset_merge_handler(event, context): """manages event from SNS, triggers file processing and merge""" del context manager = S3Manager(event) merge_powersets(manager) - res = http_response(200, "Merge successful") + res = functions.http_response(200, "Merge successful") return res diff --git a/src/handlers/site_upload/process_upload.py b/src/handlers/site_upload/process_upload.py index 7784eb4..677a76f 100644 --- a/src/handlers/site_upload/process_upload.py +++ b/src/handlers/site_upload/process_upload.py @@ -3,15 +3,7 @@ import boto3 -from ..shared.decorators import generic_error_handler -from ..shared.enums import BucketPath, TransactionKeys -from ..shared.functions import ( - http_response, - move_s3_file, - read_metadata, - update_metadata, - write_metadata, -) +from src.handlers.shared import decorators, enums, functions class UnexpectedFileTypeError(Exception): @@ -23,7 +15,7 @@ def process_upload(s3_client, sns_client, s3_bucket_name: str, s3_key: str) -> N last_uploaded_date = s3_client.head_object(Bucket=s3_bucket_name, Key=s3_key)[ "LastModified" ] - metadata = read_metadata(s3_client, s3_bucket_name) + metadata = functions.read_metadata(s3_client, s3_bucket_name) path_params = s3_key.split("/") study = path_params[1] data_package = path_params[2] @@ -33,55 +25,59 @@ def process_upload(s3_client, sns_client, s3_bucket_name: str, s3_key: str) -> N # to archive - we don't care about metadata for this, but can look there to # verify transmission if it's a connectivity test if study == "template": - new_key = f"{BucketPath.ARCHIVE.value}/{s3_key.split('/', 1)[-1]}" - move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) + new_key = f"{enums.BucketPath.ARCHIVE.value}/{s3_key.split('/', 1)[-1]}" + functions.move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) elif s3_key.endswith(".parquet"): if "__meta_" in s3_key or "/discovery__" in s3_key: - new_key = f"{BucketPath.STUDY_META.value}/{s3_key.split('/', 1)[-1]}" + new_key = f"{enums.BucketPath.STUDY_META.value}/{s3_key.split('/', 1)[-1]}" topic_sns_arn = os.environ.get("TOPIC_PROCESS_STUDY_META_ARN") sns_subject = "Process study metadata upload event" else: - new_key = f"{BucketPath.LATEST.value}/{s3_key.split('/', 1)[-1]}" + new_key = f"{enums.BucketPath.LATEST.value}/{s3_key.split('/', 1)[-1]}" topic_sns_arn = os.environ.get("TOPIC_PROCESS_COUNTS_ARN") sns_subject = "Process counts upload event" - move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) - metadata = update_metadata( - metadata, - site, - study, - data_package, - version, - TransactionKeys.LAST_UPLOAD.value, - last_uploaded_date, + functions.move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=study, + data_package=data_package, + version=version, + target=enums.TransactionKeys.LAST_UPLOAD.value, + dt=last_uploaded_date, ) sns_client.publish(TopicArn=topic_sns_arn, Message=new_key, Subject=sns_subject) - write_metadata(s3_client, s3_bucket_name, 
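# Note on the routing above: uploads arrive under
#   site_upload/<study>/<study>__<data_package>/<site>/<version>/<file>.
# A "template" study is archived as a connectivity check; parquet files whose
# keys contain "__meta_" or "/discovery__" move to the study-metadata prefix
# and notify the study-meta SNS topic; all other parquet moves to latest and
# notifies the counts topic; anything else lands in the error prefix and
# raises UnexpectedFileTypeError after the transaction metadata is updated.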
metadata) + functions.write_metadata( + s3_client=s3_client, s3_bucket_name=s3_bucket_name, metadata=metadata + ) else: - new_key = f"{BucketPath.ERROR.value}/{s3_key.split('/', 1)[-1]}" - move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) - metadata = update_metadata( - metadata, - site, - study, - data_package, - version, - TransactionKeys.LAST_UPLOAD.value, - last_uploaded_date, + new_key = f"{enums.BucketPath.ERROR.value}/{s3_key.split('/', 1)[-1]}" + functions.move_s3_file(s3_client, s3_bucket_name, s3_key, new_key) + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=study, + data_package=data_package, + version=version, + target=enums.TransactionKeys.LAST_UPLOAD.value, + dt=last_uploaded_date, + ) + metadata = functions.update_metadata( + metadata=metadata, + site=site, + study=study, + data_package=data_package, + version=version, + target=enums.TransactionKeys.LAST_ERROR.value, + dt=last_uploaded_date, ) - metadata = update_metadata( - metadata, - site, - study, - data_package, - version, - TransactionKeys.LAST_ERROR.value, - last_uploaded_date, + functions.write_metadata( + s3_client=s3_client, s3_bucket_name=s3_bucket_name, metadata=metadata ) - write_metadata(s3_client, s3_bucket_name, metadata) raise UnexpectedFileTypeError -@generic_error_handler(msg="Error processing file upload") +@decorators.generic_error_handler(msg="Error processing file upload") def process_upload_handler(event, context): """manages event from S3, triggers file processing and merge""" del context @@ -90,5 +86,5 @@ def process_upload_handler(event, context): sns_client = boto3.client("sns", region_name=event["Records"][0]["awsRegion"]) s3_key = event["Records"][0]["s3"]["object"]["key"] process_upload(s3_client, sns_client, s3_bucket, s3_key) - res = http_response(200, "Upload processing successful") + res = functions.http_response(200, "Upload processing successful") return res diff --git a/src/handlers/site_upload/study_period.py b/src/handlers/site_upload/study_period.py index 93e5edc..8c87d21 100644 --- a/src/handlers/site_upload/study_period.py +++ b/src/handlers/site_upload/study_period.py @@ -6,10 +6,10 @@ import awswrangler import boto3 -from ..shared.awswrangler_functions import get_s3_study_meta_list -from ..shared.decorators import generic_error_handler -from ..shared.enums import JsonFilename, StudyPeriodMetadataKeys -from ..shared.functions import ( +from src.handlers.shared.awswrangler_functions import get_s3_study_meta_list +from src.handlers.shared.decorators import generic_error_handler +from src.handlers.shared.enums import JsonFilename, StudyPeriodMetadataKeys +from src.handlers.shared.functions import ( http_response, read_metadata, update_metadata, @@ -28,37 +28,40 @@ def update_study_period(s3_client, s3_bucket, site, study, data_package, version ) study_meta = update_metadata( - study_meta, - site, - study, - data_package, - version, - StudyPeriodMetadataKeys.EARLIEST_DATE.value, - df["min_date"][0], + metadata=study_meta, + site=site, + study=study, + data_package=data_package, + version=version, + target=StudyPeriodMetadataKeys.EARLIEST_DATE.value, + dt=df["min_date"][0], meta_type=JsonFilename.STUDY_PERIODS.value, ) study_meta = update_metadata( - study_meta, - site, - study, - data_package, - version, - StudyPeriodMetadataKeys.LATEST_DATE.value, - df["max_date"][0], + metadata=study_meta, + site=site, + study=study, + data_package=data_package, + version=version, + target=StudyPeriodMetadataKeys.LATEST_DATE.value, + dt=df["max_date"][0], 
meta_type=JsonFilename.STUDY_PERIODS.value, ) study_meta = update_metadata( - study_meta, - site, - study, - data_package, - version, - StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value, - datetime.now(timezone.utc), + metadata=study_meta, + site=site, + study=study, + data_package=data_package, + version=version, + target=StudyPeriodMetadataKeys.LAST_DATA_UPDATE.value, + dt=datetime.now(timezone.utc), meta_type=JsonFilename.STUDY_PERIODS.value, ) write_metadata( - s3_client, s3_bucket, study_meta, meta_type=JsonFilename.STUDY_PERIODS.value + s3_client=s3_client, + s3_bucket_name=s3_bucket, + metadata=study_meta, + meta_type=JsonFilename.STUDY_PERIODS.value, ) diff --git a/template.yaml b/template.yaml index b053489..97c67f7 100644 --- a/template.yaml +++ b/template.yaml @@ -60,7 +60,7 @@ Resources: Properties: FunctionName: !Sub 'CumulusAggFetchAuthorizer-${DeployStage}' Handler: src/handlers/site_upload/api_gateway_authorizer.lambda_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 128 Timeout: 100 Description: Validates credentials before providing signed urls @@ -83,7 +83,7 @@ Resources: Properties: FunctionName: !Sub 'CumulusAggFetchUploadUrl-${DeployStage}' Handler: src/handlers/site_upload/fetch_upload_url.upload_url_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 128 Timeout: 100 Description: Generates a presigned URL for uploading files to S3 @@ -113,7 +113,7 @@ Resources: Properties: FunctionName: !Sub 'CumulusAggProcessUpload-${DeployStage}' Handler: src/handlers/site_upload/process_upload.process_upload_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 128 Timeout: 800 Description: Handles initial relocation of upload data @@ -142,7 +142,7 @@ Resources: FunctionName: !Sub 'CumulusAggPowersetMerge-${DeployStage}' Layers: [arn:aws:lambda:us-east-1:336392948345:layer:AWSSDKPandas-Python39:1] Handler: src/handlers/site_upload/powerset_merge.powerset_merge_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 8192 Timeout: 800 Description: Merges and aggregates powerset count data @@ -173,7 +173,7 @@ Resources: FunctionName: !Sub 'CumulusAggStudyPeriod-${DeployStage}' Layers: [arn:aws:lambda:us-east-1:336392948345:layer:AWSSDKPandas-Python39:1] Handler: src/handlers/site_upload/study_period.study_period_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 512 Timeout: 800 Description: Handles metadata outside of upload/processing for studies @@ -201,7 +201,7 @@ Resources: FunctionName: !Sub 'CumulusAggCacheAPI-${DeployStage}' Layers: [arn:aws:lambda:us-east-1:336392948345:layer:AWSSDKPandas-Python39:1] Handler: src/handlers/site_upload/cache_api.cache_api_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 512 Timeout: 800 Description: Caches selected database queries to S3 @@ -243,13 +243,84 @@ Resources: # Dashboard API + DashboardGetCsvFunction: + Type: AWS::Serverless::Function + Properties: + FunctionName: !Sub 'CumulusAggDashboardGetCsv-${DeployStage}' + Handler: src/handlers/dashboard/get_csv.get_csv_handler + Runtime: python3.10 + MemorySize: 128 + Timeout: 100 + Description: Redirect to presigned URL for download of aggregate CSVs + Environment: + Variables: + BUCKET_NAME: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + Events: + GetAggregateAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /aggregate/{study}/{subscription}/{version}/{filename} + Method: GET + GetLastValidAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: 
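# Note on the GetAggregateAPI / GetLastValidAPI events above: the
# /aggregate/{study}/{subscription}/{version}/{filename} route resolves to the
# csv_aggregates/ prefix in the bucket, while the /last_valid route (with its
# extra {site} segment) resolves to the per-site last_valid/ prefix, matching
# the keys built by _format_key in src/handlers/dashboard/get_csv.py.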
/last_valid/{study}/{subscription}/{site}/{version}/{filename} + Method: GET + Policies: + - S3ReadPolicy: + BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + + DashboardGetCsvLogGroup: + Type: AWS::Logs::LogGroup + Properties: + LogGroupName: !Sub "/aws/lambda/${DashboardGetCsvFunction}" + RetentionInDays: !Ref RetentionTime + + + DashboardGetCsvListFunction: + Type: AWS::Serverless::Function + Properties: + FunctionName: !Sub 'CumulusAggDashboardGetCsv-${DeployStage}' + Handler: src/handlers/dashboard/get_csv.get_csv_list_handler + Runtime: python3.10 + MemorySize: 128 + Timeout: 100 + Description: List all available csvs from the aggregator + Environment: + Variables: + BUCKET_NAME: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + Events: + GetAggregateAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /aggregate/ + Method: GET + GetLastValidAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /last_valid/ + Method: GET + Policies: + - S3ReadPolicy: + BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + + DashboardGetCsvLogGroup: + Type: AWS::Logs::LogGroup + Properties: + LogGroupName: !Sub "/aws/lambda/${DashboardGetCsvFunction}" + RetentionInDays: !Ref RetentionTime + DashboardGetChartDataFunction: Type: AWS::Serverless::Function Properties: FunctionName: !Sub 'CumulusAggDashboardGetChartData-${DeployStage}' Layers: [arn:aws:lambda:us-east-1:336392948345:layer:AWSSDKPandas-Python39:1] Handler: src/handlers/dashboard/get_chart_data.chart_data_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 2048 Timeout: 100 Description: Retrieve data for chart display in Cumulus Dashboard @@ -296,7 +367,7 @@ Resources: Properties: FunctionName: !Sub 'CumulusAggDashboardGetMetadata-${DeployStage}' Handler: src/handlers/dashboard/get_metadata.metadata_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 128 Timeout: 100 Description: Retrieve data about site uploads @@ -328,6 +399,12 @@ Resources: RestApiId: !Ref DashboardApiGateway Path: /metadata/{site}/{study}/{subscription} Method: GET + GetMetadataSiteStudySubscriptionVersionAPI: + Type: Api + Properties: + RestApiId: !Ref DashboardApiGateway + Path: /metadata/{site}/{study}/{subscription}/{version} + Method: GET Policies: - S3ReadPolicy: BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' @@ -344,7 +421,7 @@ Resources: FunctionName: !Sub 'CumulusAggDashboardDataPackages-${DeployStage}' Layers: [arn:aws:lambda:us-east-1:336392948345:layer:AWSSDKPandas-Python39:1] Handler: src/handlers/dashboard/get_data_packages.data_packages_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 512 Timeout: 100 Description: Retrieve data for chart display in Cumulus Dashboard @@ -400,7 +477,7 @@ Resources: Properties: FunctionName: !Sub 'CumulusAggDashboardStudyPeriods-${DeployStage}' Handler: src/handlers/dashboard/get_study_periods.study_periods_handler - Runtime: python3.9 + Runtime: python3.10 MemorySize: 128 Timeout: 100 Description: Retrieve data about the study period diff --git a/tests/conftest.py b/tests/conftest.py index dfcc345..bdf4510 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,20 +27,10 @@ import boto3 import pytest from moto import mock_athena, mock_s3, mock_sns -from scripts.credential_management import create_auth, create_meta - -from src.handlers.shared.enums import BucketPath, JsonFilename -from src.handlers.shared.functions import write_metadata -from 
tests.utils import ( - EXISTING_DATA_P, - EXISTING_STUDY, - EXISTING_VERSION, - ITEM_COUNT, - MOCK_ENV, - OTHER_STUDY, - get_mock_metadata, - get_mock_study_metadata, -) + +from scripts import credential_management +from src.handlers.shared import enums, functions +from tests import mock_utils def _init_mock_data(s3_client, bucket, study, data_package, version): @@ -57,25 +47,25 @@ def _init_mock_data(s3_client, bucket, study, data_package, version): s3_client.upload_file( "./tests/test_data/count_synthea_patient_agg.parquet", bucket, - f"{BucketPath.AGGREGATE.value}/{study}/" + f"{enums.BucketPath.AGGREGATE.value}/{study}/" f"{study}__{data_package}/{version}/{study}__{data_package}__aggregate.parquet", ) s3_client.upload_file( "./tests/test_data/count_synthea_patient_agg.csv", bucket, - f"{BucketPath.CSVAGGREGATE.value}/{study}/" + f"{enums.BucketPath.CSVAGGREGATE.value}/{study}/" f"{study}__{data_package}/{version}/{study}__{data_package}__aggregate.csv", ) s3_client.upload_file( "./tests/test_data/data_packages_cache.json", bucket, - f"{BucketPath.CACHE.value}/{JsonFilename.DATA_PACKAGES.value}.json", + f"{enums.BucketPath.CACHE.value}/{enums.JsonFilename.DATA_PACKAGES.value}.json", ) @pytest.fixture(autouse=True) def mock_env(): - with mock.patch.dict(os.environ, MOCK_ENV): + with mock.patch.dict(os.environ, mock_utils.MOCK_ENV): yield @@ -89,23 +79,49 @@ def mock_bucket(): bucket = os.environ["BUCKET_NAME"] s3_client.create_bucket(Bucket=bucket) aggregate_params = [ - [EXISTING_STUDY, EXISTING_DATA_P, EXISTING_VERSION], - [OTHER_STUDY, EXISTING_DATA_P, EXISTING_VERSION], + [ + mock_utils.EXISTING_STUDY, + mock_utils.EXISTING_DATA_P, + mock_utils.EXISTING_VERSION, + ], + [ + mock_utils.OTHER_STUDY, + mock_utils.EXISTING_DATA_P, + mock_utils.EXISTING_VERSION, + ], ] for param_list in aggregate_params: _init_mock_data(s3_client, bucket, *param_list) - create_auth(s3_client, bucket, "ppth_1", "test_1", "ppth") - create_meta(s3_client, bucket, "ppth", "princeton_plainsboro_teaching_hospital") - create_auth(s3_client, bucket, "elsewhere_2", "test_2", "elsewhere") - create_meta(s3_client, bucket, "elsewhere", "st_elsewhere") - create_auth(s3_client, bucket, "hope_3", "test_3", "hope") - create_meta(s3_client, bucket, "hope", "chicago_hope") - - metadata = get_mock_metadata() - write_metadata(s3_client, bucket, metadata) - study_metadata = get_mock_study_metadata() - write_metadata(s3_client, bucket, study_metadata, meta_type="study_periods") + credential_management.create_auth(s3_client, bucket, "ppth_1", "test_1", "ppth") + credential_management.create_meta( + s3_client, bucket, "ppth", "princeton_plainsboro_teaching_hospital" + ) + credential_management.create_auth( + s3_client, bucket, "elsewhere_2", "test_2", "elsewhere" + ) + credential_management.create_meta(s3_client, bucket, "elsewhere", "st_elsewhere") + credential_management.create_auth(s3_client, bucket, "hope_3", "test_3", "hope") + credential_management.create_meta(s3_client, bucket, "hope", "chicago_hope") + + metadata = mock_utils.get_mock_metadata() + functions.write_metadata( + s3_client=s3_client, s3_bucket_name=bucket, metadata=metadata + ) + study_metadata = mock_utils.get_mock_study_metadata() + functions.write_metadata( + s3_client=s3_client, + s3_bucket_name=bucket, + metadata=study_metadata, + meta_type=enums.JsonFilename.STUDY_PERIODS.value, + ) + column_types_metadata = mock_utils.get_mock_column_types_metadata() + functions.write_metadata( + s3_client=s3_client, + s3_bucket_name=bucket, + 
metadata=column_types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) yield s3.stop() @@ -114,7 +130,7 @@ def mock_bucket(): def mock_notification(): """Mocks for SNS topics. - Make sure the topic name matches the end of the ARN defined in utils.py""" + Make sure the topic name matches the end of the ARN defined in mock_utils.py""" sns = mock_sns() sns.start() sns_client = boto3.client("sns", region_name="us-east-1") @@ -152,4 +168,4 @@ def mock_db(): def test_mock_bucket(): s3_client = boto3.client("s3", region_name="us-east-1") item = s3_client.list_objects_v2(Bucket=os.environ["TEST_BUCKET"]) - assert (len(item["Contents"])) == ITEM_COUNT + assert (len(item["Contents"])) == mock_utils.ITEM_COUNT diff --git a/tests/dashboard/test_filter_config.py b/tests/dashboard/test_filter_config.py index a4ce9dc..c804c2b 100644 --- a/tests/dashboard/test_filter_config.py +++ b/tests/dashboard/test_filter_config.py @@ -4,7 +4,7 @@ @pytest.mark.parametrize( - "input,output", + "input_str,output_str", [ # Checking individual conversions (["col:strEq:str"], "col LIKE 'str'"), @@ -141,5 +141,5 @@ ), ], ) -def test_filter_string(input, output): - assert get_filter_string(input) == output +def test_filter_string(input_str, output_str): + assert get_filter_string(input_str) == output_str diff --git a/tests/dashboard/test_get_chart_data.py b/tests/dashboard/test_get_chart_data.py index 6ccb89b..ec11a9a 100644 --- a/tests/dashboard/test_get_chart_data.py +++ b/tests/dashboard/test_get_chart_data.py @@ -2,19 +2,15 @@ import os from unittest import mock -import boto3 import pandas import pytest from src.handlers.dashboard import get_chart_data -from tests.utils import ( +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_STUDY, - EXISTING_VERSION, MOCK_ENV, - TEST_BUCKET, TEST_GLUE_DB, - TEST_WORKGROUP, ) @@ -22,9 +18,9 @@ def mock_get_table_cols(name): return ["cnt", "gender", "race"] -def mock_data_frame(filter): +def mock_data_frame(filter_param): df = pandas.read_csv("tests/test_data/cube_simple_example.csv", na_filter=False) - if filter != []: + if filter_param != []: df = df[df["gender"] == "female"] return df @@ -109,8 +105,6 @@ def test_format_payload(query_params, filters, expected_payload): def test_get_data_cols(mock_bucket): - s3_client = boto3.client("s3", region_name="us-east-1") - s3_res = s3_client.list_objects_v2(Bucket=TEST_BUCKET) table_name = f"{EXISTING_STUDY}__{EXISTING_DATA_P}" res = get_chart_data._get_table_cols(table_name) cols = pandas.read_csv("./tests/test_data/count_synthea_patient_agg.csv").columns diff --git a/tests/dashboard/test_get_csv.py b/tests/dashboard/test_get_csv.py new file mode 100644 index 0000000..d32bcfc --- /dev/null +++ b/tests/dashboard/test_get_csv.py @@ -0,0 +1,148 @@ +import json +import os +from contextlib import nullcontext as does_not_raise +from unittest import mock + +import boto3 +import pytest + +from src.handlers.dashboard import get_csv +from src.handlers.shared import enums +from tests import mock_utils + +# data matching these params is created via conftest +site = mock_utils.EXISTING_SITE +study = mock_utils.EXISTING_STUDY +subscription = mock_utils.EXISTING_DATA_P +version = mock_utils.EXISTING_VERSION +filename = filename = f"{study}__{subscription}__aggregate.csv" + + +def _mock_last_valid(): + bucket = os.environ["BUCKET_NAME"] + s3_client = boto3.client("s3", region_name="us-east-1") + s3_client.upload_file( + "./tests/test_data/count_synthea_patient_agg.csv", + bucket, + f"{enums.BucketPath.LAST_VALID.value}/{study}/" + 
f"{study}__{subscription}/{site}/{version}/{filename}", + ) + + +@pytest.mark.parametrize( + "params,status,expected", + [ + ( + { + "site": None, + "study": study, + "subscription": subscription, + "version": version, + "filename": filename, + }, + 302, + None, + ), + ( + { + "site": site, + "study": study, + "subscription": subscription, + "version": version, + "filename": filename, + }, + 302, + None, + ), + ( + { + "study": study, + "subscription": subscription, + "version": version, + "filename": filename, + }, + 302, + None, + ), + ( + { + "site": site, + "study": study, + "subscription": subscription, + "version": version, + "filename": "foo", + }, + 500, + None, + ), + ( + { + "site": None, + "study": None, + "subscription": None, + "version": None, + "filename": None, + }, + 500, + None, + ), + ], +) +@mock.patch.dict(os.environ, mock_utils.MOCK_ENV) +def test_get_csv(mock_bucket, params, status, expected): + event = {"pathParameters": params} + if "site" in params and params["site"] is not None: + _mock_last_valid() + res = get_csv.get_csv_handler(event, {}) + assert res["statusCode"] == status + if status == 302: + if "site" not in params or params["site"] is None: + url = ( + "https://cumulus-aggregator-site-counts-test.s3.amazonaws.com/csv_aggregates/" + f"{study}/{study}__{subscription}/{version}/{filename}" + ) + else: + url = ( + "https://cumulus-aggregator-site-counts-test.s3.amazonaws.com/last_valid/" + f"{study}/{study}__{subscription}/{site}/{version}/{filename}" + ) + assert ( + res["headers"]["x-column-types"] == "integer,string,integer,string,string" + ) + assert res["headers"]["Location"].startswith(url) + + +@pytest.mark.parametrize( + "path,status,expected,raises", + [ + ( + "/aggregates", + 200, + [ + "aggregates/other_study/encounter/099/other_study__encounter__aggregate.csv", + "aggregates/study/encounter/099/study__encounter__aggregate.csv", + ], + does_not_raise(), + ), + ( + "/last_valid", + 200, + [ + "last_valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv" + ], + does_not_raise(), + ), + ("/some_other_endpoint", 500, [], does_not_raise()), + ], +) +@mock.patch.dict(os.environ, mock_utils.MOCK_ENV) +def test_get_csv_list(mock_bucket, path, status, expected, raises): + with raises: + if path.startswith("/last_valid"): + _mock_last_valid() + event = {"path": path} + res = get_csv.get_csv_list_handler(event, {}) + keys = json.loads(res["body"]) + assert res["statusCode"] == status + if status == 200: + assert keys == expected diff --git a/tests/dashboard/test_get_metadata.py b/tests/dashboard/test_get_metadata.py index 586815f..959a95f 100644 --- a/tests/dashboard/test_get_metadata.py +++ b/tests/dashboard/test_get_metadata.py @@ -1,53 +1,52 @@ import json -import os -from datetime import datetime, timezone -from unittest import mock -import boto3 import pytest from src.handlers.dashboard.get_metadata import metadata_handler -from src.handlers.shared.enums import BucketPath -from src.handlers.shared.functions import read_metadata, write_metadata -from tests.utils import ( - EXISTING_DATA_P, - EXISTING_SITE, - EXISTING_STUDY, - EXISTING_VERSION, - NEW_SITE, - NEW_STUDY, - OTHER_SITE, - OTHER_STUDY, - TEST_BUCKET, - get_mock_metadata, -) +from tests import mock_utils @pytest.mark.parametrize( "params,status,expected", [ - (None, 200, get_mock_metadata()), + (None, 200, mock_utils.get_mock_metadata()), + ( + {"site": mock_utils.EXISTING_SITE}, + 200, + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE], + ), ( 
- {"site": EXISTING_SITE}, + {"site": mock_utils.EXISTING_SITE, "study": mock_utils.EXISTING_STUDY}, 200, - get_mock_metadata()[EXISTING_SITE], + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE][ + mock_utils.EXISTING_STUDY + ], ), ( - {"site": EXISTING_SITE, "study": EXISTING_STUDY}, + { + "site": mock_utils.EXISTING_SITE, + "study": mock_utils.EXISTING_STUDY, + "data_package": mock_utils.EXISTING_DATA_P, + }, 200, - get_mock_metadata()[EXISTING_SITE][EXISTING_STUDY], + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE][ + mock_utils.EXISTING_STUDY + ][mock_utils.EXISTING_DATA_P], ), ( { - "site": EXISTING_SITE, - "study": EXISTING_STUDY, - "data_package": EXISTING_DATA_P, + "site": mock_utils.EXISTING_SITE, + "study": mock_utils.EXISTING_STUDY, + "data_package": mock_utils.EXISTING_DATA_P, + "version": mock_utils.EXISTING_VERSION, }, 200, - get_mock_metadata()[EXISTING_SITE][EXISTING_STUDY][EXISTING_DATA_P], + mock_utils.get_mock_metadata()[mock_utils.EXISTING_SITE][ + mock_utils.EXISTING_STUDY + ][mock_utils.EXISTING_DATA_P][mock_utils.EXISTING_VERSION], ), - ({"site": NEW_SITE, "study": EXISTING_STUDY}, 500, None), - ({"site": EXISTING_SITE, "study": NEW_STUDY}, 500, None), + ({"site": mock_utils.NEW_SITE, "study": mock_utils.EXISTING_STUDY}, 500, None), + ({"site": mock_utils.EXISTING_SITE, "study": mock_utils.NEW_STUDY}, 500, None), ], ) def test_get_metadata(mock_bucket, params, status, expected): diff --git a/tests/dashboard/test_get_study_periods.py b/tests/dashboard/test_get_study_periods.py index 408a4e2..9bd6833 100644 --- a/tests/dashboard/test_get_study_periods.py +++ b/tests/dashboard/test_get_study_periods.py @@ -1,23 +1,13 @@ import json -import os -from datetime import datetime, timezone -from unittest import mock -import boto3 import pytest from src.handlers.dashboard.get_study_periods import study_periods_handler -from src.handlers.shared.enums import BucketPath -from src.handlers.shared.functions import read_metadata, write_metadata -from tests.utils import ( - EXISTING_DATA_P, +from tests.mock_utils import ( EXISTING_SITE, EXISTING_STUDY, - EXISTING_VERSION, NEW_SITE, NEW_STUDY, - OTHER_SITE, - OTHER_STUDY, get_mock_study_metadata, ) @@ -42,8 +32,9 @@ ) def test_get_study_periods(mock_bucket, params, status, expected): event = {"pathParameters": params} - res = study_periods_handler(event, {}) + print(res["body"]) + print(expected) assert res["statusCode"] == status if status == 200: assert json.loads(res["body"]) == expected diff --git a/tests/dashboard/test_get_subscriptions.py b/tests/dashboard/test_get_subscriptions.py index 5b928ca..7a4ec13 100644 --- a/tests/dashboard/test_get_subscriptions.py +++ b/tests/dashboard/test_get_subscriptions.py @@ -1,12 +1,8 @@ import os from unittest import mock -import awswrangler -import pandas -from pytest_mock import MockerFixture - from src.handlers.dashboard.get_data_packages import data_packages_handler -from tests.utils import DATA_PACKAGE_COUNT, MOCK_ENV, get_mock_metadata +from tests.mock_utils import DATA_PACKAGE_COUNT, MOCK_ENV @mock.patch.dict(os.environ, MOCK_ENV) diff --git a/tests/utils.py b/tests/mock_utils.py similarity index 88% rename from tests/utils.py rename to tests/mock_utils.py index 6259e67..301d35f 100644 --- a/tests/utils.py +++ b/tests/mock_utils.py @@ -6,7 +6,7 @@ TEST_PROCESS_COUNTS_ARN = "arn:aws:sns:us-east-1:123456789012:test-counts" TEST_PROCESS_STUDY_META_ARN = "arn:aws:sns:us-east-1:123456789012:test-meta" TEST_CACHE_API_ARN = "arn:aws:sns:us-east-1:123456789012:test-cache" 
-ITEM_COUNT = 9 +ITEM_COUNT = 10 DATA_PACKAGE_COUNT = 2 EXISTING_SITE = "princeton_plainsboro_teaching_hospital" @@ -110,6 +110,26 @@ def get_mock_study_metadata(): } +def get_mock_column_types_metadata(): + return { + EXISTING_STUDY: { + EXISTING_DATA_P: { + EXISTING_VERSION: { + "column_types_format_version": "1", + "columns": { + "cnt": "integer", + "gender": "string", + "age": "integer", + "race_display": "string", + "site": "string", + }, + "last_data_update": "2023-02-24T15:08:07.771080+00:00", + } + } + } + } + + def get_mock_auth(): return { # u/a: ppth_1 test_1 diff --git a/tests/site_upload/test_api_gateway_authorizer.py b/tests/site_upload/test_api_gateway_authorizer.py index 94bdc82..8d7a7ec 100644 --- a/tests/site_upload/test_api_gateway_authorizer.py +++ b/tests/site_upload/test_api_gateway_authorizer.py @@ -1,21 +1,17 @@ -import json -import os from contextlib import nullcontext as does_not_raise -from unittest import mock import pytest -from pytest_mock import MockerFixture -from src.handlers.site_upload.api_gateway_authorizer import lambda_handler -from tests.utils import TEST_BUCKET, get_mock_auth +from src.handlers.site_upload import api_gateway_authorizer +from tests import mock_utils @pytest.mark.parametrize( "auth,expects", [ - (f"Basic {list(get_mock_auth().keys())[0]}", does_not_raise()), - ("Basic other_auth", pytest.raises(Exception)), - (None, pytest.raises(Exception)), + (f"Basic {next(iter(mock_utils.get_mock_auth().keys()))}", does_not_raise()), + ("Basic other_auth", pytest.raises(api_gateway_authorizer.AuthError)), + (None, pytest.raises(AttributeError)), ], ) def test_validate_pw(auth, expects, mock_bucket): @@ -25,4 +21,5 @@ def test_validate_pw(auth, expects, mock_bucket): "methodArn": "arn:aws:execute-api:us-east-1:11223:123/Prod/post/lambda", } with expects: - res = lambda_handler(event, {}) + res = api_gateway_authorizer.lambda_handler(event, {}) + assert res["policyDocument"]["Statement"][0]["Effect"] == "Allow" diff --git a/tests/site_upload/test_cache_api.py b/tests/site_upload/test_cache_api.py index 836677b..f50dc24 100644 --- a/tests/site_upload/test_cache_api.py +++ b/tests/site_upload/test_cache_api.py @@ -1,12 +1,11 @@ import os from unittest import mock -import awswrangler import pandas import pytest from src.handlers.site_upload.cache_api import cache_api_handler -from tests.utils import MOCK_ENV, get_mock_data_packages_cache +from tests.mock_utils import MOCK_ENV, get_mock_data_packages_cache def mock_data_packages(*args, **kwargs): @@ -24,7 +23,6 @@ def mock_data_packages(*args, **kwargs): ], ) def test_cache_api(mocker, mock_bucket, subject, message, mock_result, status): - mock_query_result = mocker.patch("awswrangler.athena.read_sql_query") mock_query_result.side_effect = mock_result event = {"Records": [{"Sns": {"Subject": subject, "Message": message}}]} diff --git a/tests/site_upload/test_fetch_upload_url.py b/tests/site_upload/test_fetch_upload_url.py index 766c57b..08d997e 100644 --- a/tests/site_upload/test_fetch_upload_url.py +++ b/tests/site_upload/test_fetch_upload_url.py @@ -1,19 +1,14 @@ import json -import os -from unittest import mock -import boto3 import pytest from src.handlers.shared.enums import BucketPath from src.handlers.site_upload.fetch_upload_url import upload_url_handler -from tests.utils import ( +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY, EXISTING_VERSION, - TEST_BUCKET, - get_mock_metadata, ) diff --git a/tests/site_upload/test_powerset_merge.py 
b/tests/site_upload/test_powerset_merge.py index 52fbed8..20c51c5 100644 --- a/tests/site_upload/test_powerset_merge.py +++ b/tests/site_upload/test_powerset_merge.py @@ -10,28 +10,20 @@ from freezegun import freeze_time from pandas import DataFrame, read_parquet -from src.handlers.shared.enums import BucketPath -from src.handlers.shared.functions import read_metadata, write_metadata -from src.handlers.site_upload.powerset_merge import ( - MergeError, - expand_and_concat_sets, - generate_csv_from_parquet, - powerset_merge_handler, -) -from tests.utils import ( +from src.handlers.shared import enums, functions +from src.handlers.site_upload import powerset_merge +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY, EXISTING_VERSION, ITEM_COUNT, MOCK_ENV, - NEW_DATA_P, NEW_SITE, NEW_STUDY, NEW_VERSION, - OTHER_SITE, - OTHER_STUDY, TEST_BUCKET, + get_mock_column_types_metadata, get_mock_metadata, ) @@ -148,26 +140,26 @@ def test_powerset_merge_single_upload( s3_client.upload_file( upload_file, TEST_BUCKET, - f"{BucketPath.LATEST.value}{upload_path}", + f"{enums.BucketPath.LATEST.value}{upload_path}", ) elif upload_path is not None: with io.BytesIO(DataFrame().to_parquet()) as upload_fileobj: s3_client.upload_fileobj( upload_fileobj, TEST_BUCKET, - f"{BucketPath.LATEST.value}{upload_path}", + f"{enums.BucketPath.LATEST.value}{upload_path}", ) if archives: s3_client.upload_file( upload_file, TEST_BUCKET, - f"{BucketPath.LAST_VALID.value}{upload_path}", + f"{enums.BucketPath.LAST_VALID.value}{upload_path}", ) event = { "Records": [ { - "Sns": {"Message": f"{BucketPath.LATEST.value}{event_key}"}, + "Sns": {"Message": f"{enums.BucketPath.LATEST.value}{event_key}"}, } ] } @@ -178,13 +170,13 @@ def test_powerset_merge_single_upload( data_package = event_list[2] site = event_list[3] version = event_list[4] - res = powerset_merge_handler(event, {}) + res = powerset_merge.powerset_merge_handler(event, {}) assert res["statusCode"] == status s3_res = s3_client.list_objects_v2(Bucket=TEST_BUCKET) assert len(s3_res["Contents"]) == expected_contents for item in s3_res["Contents"]: if item["Key"].endswith("aggregate.parquet"): - assert item["Key"].startswith(BucketPath.AGGREGATE.value) + assert item["Key"].startswith(enums.BucketPath.AGGREGATE.value) # This finds the aggregate that was created/updated - ie it skips mocks if study in item["Key"] and status == 200: agg_df = awswrangler.s3.read_parquet( @@ -192,10 +184,10 @@ def test_powerset_merge_single_upload( ) assert (agg_df["site"].eq(site)).any() elif item["Key"].endswith("aggregate.csv"): - assert item["Key"].startswith(BucketPath.CSVAGGREGATE.value) + assert item["Key"].startswith(enums.BucketPath.CSVAGGREGATE.value) elif item["Key"].endswith("transactions.json"): - assert item["Key"].startswith(BucketPath.META.value) - metadata = read_metadata(s3_client, TEST_BUCKET) + assert item["Key"].startswith(enums.BucketPath.META.value) + metadata = functions.read_metadata(s3_client, TEST_BUCKET) if res["statusCode"] == 200: assert ( metadata[site][study][data_package.split("__")[1]][version][ @@ -213,7 +205,7 @@ def test_powerset_merge_single_upload( "study" ]["encounter"]["099"]["last_aggregation"] ) - if upload_file is not None: + if upload_file is not None and study != NEW_STUDY: # checking to see that merge powerset didn't touch last upload assert ( metadata[site][study][data_package.split("__")[1]][version][ @@ -221,20 +213,41 @@ def test_powerset_merge_single_upload( ] != datetime.now(timezone.utc).isoformat() ) + elif 
item["Key"].endswith("column_types.json"): + assert item["Key"].startswith(enums.BucketPath.META.value) + metadata = functions.read_metadata( + s3_client, TEST_BUCKET, meta_type=enums.JsonFilename.COLUMN_TYPES.value + ) + if res["statusCode"] == 200: + assert ( + metadata[study][data_package.split("__")[1]][version][ + "last_data_update" + ] + == datetime.now(timezone.utc).isoformat() + ) - elif item["Key"].startswith(BucketPath.LAST_VALID.value): + else: + assert ( + metadata["study"]["encounter"]["099"]["last_data_update"] + == get_mock_column_types_metadata()["study"]["encounter"]["099"][ + "last_data_update" + ] + ) + elif item["Key"].startswith(enums.BucketPath.LAST_VALID.value): if item["Key"].endswith(".parquet"): - assert item["Key"] == (f"{BucketPath.LAST_VALID.value}{upload_path}") + assert item["Key"] == ( + f"{enums.BucketPath.LAST_VALID.value}{upload_path}" + ) elif item["Key"].endswith(".csv"): assert f"{upload_path.replace('.parquet','.csv')}" in item["Key"] else: raise Exception("Invalid csv found at " f"{item['Key']}") else: assert ( - item["Key"].startswith(BucketPath.ARCHIVE.value) - or item["Key"].startswith(BucketPath.ERROR.value) - or item["Key"].startswith(BucketPath.ADMIN.value) - or item["Key"].startswith(BucketPath.CACHE.value) + item["Key"].startswith(enums.BucketPath.ARCHIVE.value) + or item["Key"].startswith(enums.BucketPath.ERROR.value) + or item["Key"].startswith(enums.BucketPath.ADMIN.value) + or item["Key"].startswith(enums.BucketPath.CACHE.value) or item["Key"].endswith("study_periods.json") ) if archives: @@ -243,7 +256,7 @@ def test_powerset_merge_single_upload( keys.append(resource["Key"]) date_str = datetime.now(timezone.utc).isoformat() archive_path = f".{date_str}.".join(upload_path.split(".")) - assert f"{BucketPath.ARCHIVE.value}{archive_path}" in keys + assert f"{enums.BucketPath.ARCHIVE.value}{archive_path}" in keys @freeze_time("2020-01-01") @@ -263,12 +276,11 @@ def test_powerset_merge_join_study_data( mock_bucket, mock_notification, ): - s3_client = boto3.client("s3", region_name="us-east-1") s3_client.upload_file( upload_file, TEST_BUCKET, - f"{BucketPath.LATEST.value}/{EXISTING_STUDY}/" + f"{enums.BucketPath.LATEST.value}/{EXISTING_STUDY}/" f"{EXISTING_STUDY}__{EXISTING_DATA_P}/{NEW_SITE}/" f"{EXISTING_VERSION}/encounter.parquet", ) @@ -276,7 +288,7 @@ def test_powerset_merge_join_study_data( s3_client.upload_file( "./tests/test_data/count_synthea_patient.parquet", TEST_BUCKET, - f"{BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" + f"{enums.BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" f"{EXISTING_STUDY}__{EXISTING_DATA_P}/{EXISTING_SITE}/" f"{EXISTING_VERSION}/encounter.parquet", ) @@ -285,7 +297,7 @@ def test_powerset_merge_join_study_data( s3_client.upload_file( "./tests/test_data/count_synthea_patient.parquet", TEST_BUCKET, - f"{BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" + f"{enums.BucketPath.LAST_VALID.value}/{EXISTING_STUDY}/" f"{EXISTING_STUDY}__{EXISTING_DATA_P}/{NEW_SITE}/" f"{EXISTING_VERSION}/encounter.parquet", ) @@ -294,21 +306,21 @@ def test_powerset_merge_join_study_data( "Records": [ { "Sns": { - "Message": f"{BucketPath.LATEST.value}/{EXISTING_STUDY}" + "Message": f"{enums.BucketPath.LATEST.value}/{EXISTING_STUDY}" f"/{EXISTING_STUDY}__{EXISTING_DATA_P}/{NEW_SITE}" f"/{EXISTING_VERSION}/encounter.parquet" }, } ] } - res = powerset_merge_handler(event, {}) + res = powerset_merge.powerset_merge_handler(event, {}) assert res["statusCode"] == 200 errors = 0 s3_res = s3_client.list_objects_v2(Bucket=TEST_BUCKET) for item 
in s3_res["Contents"]: - if item["Key"].startswith(BucketPath.ERROR.value): + if item["Key"].startswith(enums.BucketPath.ERROR.value): errors += 1 - elif item["Key"].startswith(f"{BucketPath.AGGREGATE.value}/study"): + elif item["Key"].startswith(f"{enums.BucketPath.AGGREGATE.value}/study"): agg_df = awswrangler.s3.read_parquet(f"s3://{TEST_BUCKET}/{item['Key']}") # if a file cant be merged and there's no fallback, we expect # [, site_name], otherwise, [, site_name, uploading_site_name] @@ -328,26 +340,28 @@ def test_powerset_merge_join_study_data( ( "./tests/test_data/cube_simple_example.parquet", False, - pytest.raises(MergeError), + pytest.raises(powerset_merge.MergeError), ), ( "./tests/test_data/count_synthea_empty.parquet", True, - pytest.raises(MergeError), + pytest.raises(powerset_merge.MergeError), ), ], ) def test_expand_and_concat(mock_bucket, upload_file, load_empty, raises): with raises: df = read_parquet("./tests/test_data/count_synthea_patient_agg.parquet") - s3_path = f"/test/uploaded.parquet" + s3_path = "/test/uploaded.parquet" s3_client = boto3.client("s3", region_name="us-east-1") s3_client.upload_file( upload_file, TEST_BUCKET, s3_path, ) - expand_and_concat_sets(df, f"s3://{TEST_BUCKET}/{s3_path}", EXISTING_STUDY) + powerset_merge.expand_and_concat_sets( + df, f"s3://{TEST_BUCKET}/{s3_path}", EXISTING_STUDY + ) def test_parquet_to_csv(mock_bucket): @@ -359,7 +373,7 @@ def test_parquet_to_csv(mock_bucket): TEST_BUCKET, f"{bucket_root}/{subbucket_path}", ) - generate_csv_from_parquet(TEST_BUCKET, bucket_root, subbucket_path) + powerset_merge.generate_csv_from_parquet(TEST_BUCKET, bucket_root, subbucket_path) df = awswrangler.s3.read_csv( f"s3://{TEST_BUCKET}/{bucket_root}/{subbucket_path.replace('.parquet','.csv')}" ) diff --git a/tests/site_upload/test_process_upload.py b/tests/site_upload/test_process_upload.py index 787ca05..0b4f3e9 100644 --- a/tests/site_upload/test_process_upload.py +++ b/tests/site_upload/test_process_upload.py @@ -1,4 +1,3 @@ -import os from datetime import datetime, timezone import boto3 @@ -6,9 +5,9 @@ from freezegun import freeze_time from src.handlers.shared.enums import BucketPath -from src.handlers.shared.functions import read_metadata, write_metadata +from src.handlers.shared.functions import read_metadata from src.handlers.site_upload.process_upload import process_upload_handler -from tests.utils import ( +from tests.mock_utils import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY, @@ -16,10 +15,7 @@ ITEM_COUNT, NEW_DATA_P, NEW_SITE, - NEW_STUDY, NEW_VERSION, - OTHER_SITE, - OTHER_STUDY, TEST_BUCKET, ) @@ -179,6 +175,7 @@ def test_process_upload( or item["Key"].startswith(BucketPath.ADMIN.value) or item["Key"].startswith(BucketPath.CACHE.value) or item["Key"].endswith("study_periods.json") + or item["Key"].endswith("column_types.json") ) if found_archive: assert "template" in upload_path diff --git a/tests/site_upload/test_study_period.py b/tests/site_upload/test_study_period.py index 2657161..036dcda 100644 --- a/tests/site_upload/test_study_period.py +++ b/tests/site_upload/test_study_period.py @@ -1,5 +1,4 @@ import csv -import os from datetime import datetime, timezone import boto3 @@ -7,21 +6,17 @@ from freezegun import freeze_time from src.handlers.shared.enums import BucketPath -from src.handlers.shared.functions import read_metadata, write_metadata +from src.handlers.shared.functions import read_metadata from src.handlers.site_upload.study_period import study_period_handler -from tests.utils import ( +from tests.mock_utils 
import ( EXISTING_DATA_P, EXISTING_SITE, EXISTING_STUDY, EXISTING_VERSION, - NEW_DATA_P, NEW_SITE, NEW_STUDY, NEW_VERSION, - OTHER_SITE, - OTHER_STUDY, TEST_BUCKET, - get_mock_study_metadata, ) @@ -114,7 +109,7 @@ def test_process_upload( metadata[site][study][version]["last_data_update"] == datetime.now(timezone.utc).isoformat() ) - with open("./tests/test_data/meta_date.csv", "r") as file: + with open("./tests/test_data/meta_date.csv") as file: reader = csv.reader(file) # discarding CSV header row next(reader)
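
Editor's note on the metadata helper pattern: the refactors in this patch move every call to the shared metadata helpers onto explicit keyword arguments (write_metadata(s3_client=..., s3_bucket_name=..., metadata=..., meta_type=...)) and add column_types.json alongside transactions.json and study_periods.json. The sketch below shows the general shape of that pattern as a self-contained example; it is a simplified stand-in, not the project's src.handlers.shared.functions module, and it skips the per-site nesting and meta_type-specific handling the real helpers perform. The bucket and prefix names are placeholders.

import json
from datetime import datetime, timezone

import boto3

META_PATH = "metadata"  # placeholder for the metadata prefix (BucketPath.META in the project)


def update_metadata(*, metadata: dict, study: str, data_package: str, version: str,
                    target: str, dt: datetime | None = None) -> dict:
    """Set a timestamp field on the study/data_package/version node, creating parents as needed."""
    node = (
        metadata.setdefault(study, {})
        .setdefault(data_package, {})
        .setdefault(version, {})
    )
    node[target] = (dt or datetime.now(timezone.utc)).isoformat()
    return metadata


def write_metadata(*, s3_client, s3_bucket_name: str, metadata: dict,
                   meta_type: str = "transactions") -> None:
    """Persist a metadata dict as <meta_type>.json under the metadata prefix."""
    s3_client.put_object(
        Bucket=s3_bucket_name,
        Key=f"{META_PATH}/{meta_type}.json",
        Body=json.dumps(metadata, default=str),
    )


if __name__ == "__main__":
    # Illustrative usage against a hypothetical bucket; the real handlers pass
    # enums.JsonFilename.COLUMN_TYPES.value rather than a bare string.
    client = boto3.client("s3")
    column_types = update_metadata(
        metadata={}, study="study", data_package="encounter",
        version="099", target="last_data_update",
    )
    write_metadata(
        s3_client=client, s3_bucket_name="example-bucket",
        metadata=column_types, meta_type="column_types",
    )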
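
Editor's note on the CSV endpoints: the new DashboardGetCsv functions in template.yaml answer dashboard requests with a 302 redirect to a presigned S3 URL instead of streaming the CSV through Lambda, and test_get_csv also checks an x-column-types header derived from the column-types metadata. A minimal sketch of that redirect pattern follows; the prefixes, the 600-second expiry, and the error handling are assumptions for illustration. This is not the project's actual get_csv module, which additionally validates the requested object and emits the x-column-types header.

import os

import boto3

CSV_AGGREGATE_PATH = "csv_aggregates"  # assumed prefix for merged aggregate CSVs
LAST_VALID_PATH = "last_valid"         # assumed prefix for per-site fallback CSVs


def get_csv_handler(event: dict, context) -> dict:
    """Redirect the caller to a short-lived presigned URL for the requested CSV."""
    params = event.get("pathParameters") or {}
    study = params.get("study")
    subscription = params.get("subscription")
    version = params.get("version")
    filename = params.get("filename")
    site = params.get("site")
    if not all([study, subscription, version, filename]):
        return {"statusCode": 500, "body": "Missing path parameters"}
    if site:
        key = f"{LAST_VALID_PATH}/{study}/{study}__{subscription}/{site}/{version}/{filename}"
    else:
        key = f"{CSV_AGGREGATE_PATH}/{study}/{study}__{subscription}/{version}/{filename}"
    s3_client = boto3.client("s3")
    url = s3_client.generate_presigned_url(
        "get_object",
        Params={"Bucket": os.environ["BUCKET_NAME"], "Key": key},
        ExpiresIn=600,
    )
    # The project's handler also attaches an x-column-types header built from
    # column_types.json; that lookup is omitted here.
    return {"statusCode": 302, "headers": {"Location": url}}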
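
Editor's note on the test fixtures: conftest.py stands up the whole S3 surface with moto for every test, seeds it through the same shared helpers the handlers use, and tears it down afterwards. A pared-down version of that fixture pattern is shown below, using the moto 4-style mock_s3 object already imported in conftest; the bucket name and seed object are placeholders, and the real fixture also mocks Athena and SNS and loads parquet/CSV test data.

import os
from unittest import mock

import boto3
import pytest
from moto import mock_s3


@pytest.fixture
def mock_bucket():
    """Create an in-memory S3 bucket, seed it, and clean up after the test."""
    with mock.patch.dict(os.environ, {"BUCKET_NAME": "example-test-bucket"}):
        s3 = mock_s3()
        s3.start()
        s3_client = boto3.client("s3", region_name="us-east-1")
        bucket = os.environ["BUCKET_NAME"]
        s3_client.create_bucket(Bucket=bucket)
        # Seed whatever objects the tests expect to find, e.g. metadata JSON
        # or aggregate parquet/CSV files.
        s3_client.put_object(Bucket=bucket, Key="metadata/transactions.json", Body=b"{}")
        yield bucket
        s3.stop()


def test_bucket_is_seeded(mock_bucket):
    s3_client = boto3.client("s3", region_name="us-east-1")
    listing = s3_client.list_objects_v2(Bucket=mock_bucket)
    assert listing["KeyCount"] == 1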
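
Editor's note on the parquet-to-CSV step: test_parquet_to_csv asserts that generate_csv_from_parquet mirrors a parquet artifact as a CSV at the same key with only the extension swapped. One plausible implementation that satisfies that contract, using the awswrangler calls already exercised elsewhere in these tests, is sketched here; it is not necessarily the project's actual powerset_merge code, which may add CSV-specific quoting or column handling.

import awswrangler


def generate_csv_from_parquet(bucket_name: str, bucket_root: str, subbucket_path: str) -> None:
    """Write a CSV copy of a parquet file alongside it, swapping only the extension."""
    df = awswrangler.s3.read_parquet(
        f"s3://{bucket_name}/{bucket_root}/{subbucket_path}"
    )
    awswrangler.s3.to_csv(
        df,
        f"s3://{bucket_name}/{bucket_root}/{subbucket_path.replace('.parquet', '.csv')}",
        index=False,
    )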