diff --git a/.gitignore b/.gitignore index 169ef8c..9420dfe 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ auth.json metadata.json dashboard_api.yaml .DS_Store +*.zip # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/pyproject.toml b/pyproject.toml index 719472c..cc609ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies= [ "awswrangler >=3.5, <4", "boto3", "pandas >=2, <3", + "requests", # scripts only "rich", ] authors = [ diff --git a/scripts/bulk_csv_download.py b/scripts/bulk_csv_download.py new file mode 100644 index 0000000..416515b --- /dev/null +++ b/scripts/bulk_csv_download.py @@ -0,0 +1,65 @@ +import argparse +import io +import os +import pathlib +import sys +import zipfile + +import requests +from rich.progress import track + + +def bulk_csv_download(args): + if args.api_key is None: + args.api_key = os.getenv("CUMULUS_AGGREGATOR_API_KEY") + args.type = args.type.replace("_", "-") + if args.type not in ["last-valid", "aggregates"]: + sys.exit('Invalid type. Expected "last-valid" or "aggregates"') + dp_url = f"https://{args.domain}/{args.type}" + try: + res = requests.get(dp_url, headers={"x-api-key": args.api_key}, timeout=300) + except requests.exceptions.ConnectionError: + sys.exit("Invalid domain name") + if res.status_code == 403: + sys.exit("Invalid API key") + file_urls = res.json() + urls = [] + version = 0 + for file_url in file_urls: + file_array = file_url.split("/") + dp_version = int(file_array[4 if args.type == "last-valid" else 3]) + if file_array[1] == args.study: + if dp_version > version: + version = int(dp_version) + urls = [] + elif int(dp_version) == version: + if ( + args.type == "last-valid" and args.site == file_array[3] + ) or args.type == "aggregates": + urls.append(file_url) + if len(urls) == 0: + sys.exit(f"No aggregates matching {args.study} found") + archive = io.BytesIO() + with zipfile.ZipFile(archive, "w") as zip_archive: + for file in track(urls, description=f"Downloading {args.study} aggregates"): + csv_url = f"https://{args.domain}/{file}" + res = requests.get( + csv_url, headers={"x-api-key": args.api_key}, allow_redirects=True, timeout=300 + ) + with zip_archive.open(file.split("/")[-1], "w") as f: + f.write(bytes(res.text, "UTF-8")) + with open(pathlib.Path.cwd() / f"{args.study}.zip", "wb") as output: + output.write(archive.getbuffer()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="""Fetches all data for a given study""") + parser.add_argument("-s", "--study", help="Name of study to download", required=True) + parser.add_argument("-i", "--site", help="Name of site to download (last-valid only)") + parser.add_argument( + "-d", "--domain", help="Domain of aggregator", default="api.smartcumulus.org" + ) + parser.add_argument("-t", "--type", help="type of aggregate", default="last-valid") + parser.add_argument("-a", "--apikey", dest="api_key", help="API key of aggregator") + args = parser.parse_args() + bulk_csv_download(args) diff --git a/src/handlers/dashboard/get_csv.py b/src/handlers/dashboard/get_csv.py index 47fd19e..916d3b7 100644 --- a/src/handlers/dashboard/get_csv.py +++ b/src/handlers/dashboard/get_csv.py @@ -84,7 +84,7 @@ def get_csv_list_handler(event, context): s3_client = boto3.client("s3") if event["path"].startswith("/last-valid"): key_prefix = "last_valid" - url_prefix = "last_valid" + url_prefix = "last-valid" elif event["path"].startswith("/aggregates"): key_prefix = "csv_aggregates" url_prefix = "aggregates" @@ -104,9 +104,9 @@ def get_csv_list_handler(event, context): data_package = key_parts[2].split("__")[1] version = key_parts[-2] filename = key_parts[-1] - site = key_parts[3] if url_prefix == "last_valid" else None + site = key_parts[3] if url_prefix == "last-valid" else None url_parts = [url_prefix, study, data_package, version, filename] - if url_prefix == "last_valid": + if url_prefix == "last-valid": url_parts.insert(3, site) urls.append("/".join(url_parts)) if not s3_objs["IsTruncated"]: diff --git a/tests/dashboard/test_get_csv.py b/tests/dashboard/test_get_csv.py index 33eea44..c055b36 100644 --- a/tests/dashboard/test_get_csv.py +++ b/tests/dashboard/test_get_csv.py @@ -128,7 +128,7 @@ def test_get_csv(mock_bucket, params, status, expected): "/last-valid", 200, [ - "last_valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv" + "last-valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv" ], does_not_raise(), ),