
Commit

Added study results download script (#135)
* Added study results download script

* trim prints
dogversioning authored Oct 25, 2024
1 parent e90c6fa commit 4c92ff1
Showing 5 changed files with 71 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ auth.json
 metadata.json
 dashboard_api.yaml
 .DS_Store
+*.zip
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
1 change: 1 addition & 0 deletions pyproject.toml
@@ -10,6 +10,7 @@ dependencies= [
     "awswrangler >=3.5, <4",
     "boto3",
     "pandas >=2, <3",
+    "requests", # scripts only
     "rich",
 ]
 authors = [
65 changes: 65 additions & 0 deletions scripts/bulk_csv_download.py
@@ -0,0 +1,65 @@
import argparse
import io
import os
import pathlib
import sys
import zipfile

import requests
from rich.progress import track


def bulk_csv_download(args):
    # Fall back to the environment variable if no API key was passed on the CLI
    if args.api_key is None:
        args.api_key = os.getenv("CUMULUS_AGGREGATOR_API_KEY")
    args.type = args.type.replace("_", "-")
    if args.type not in ["last-valid", "aggregates"]:
        sys.exit('Invalid type. Expected "last-valid" or "aggregates"')
    dp_url = f"https://{args.domain}/{args.type}"
    try:
        res = requests.get(dp_url, headers={"x-api-key": args.api_key}, timeout=300)
    except requests.exceptions.ConnectionError:
        sys.exit("Invalid domain name")
    if res.status_code == 403:
        sys.exit("Invalid API key")
    file_urls = res.json()
    urls = []
    version = 0
    # Keep only URLs for the requested study at the highest version seen;
    # last-valid URLs carry a site segment, shifting the version index by one
    for file_url in file_urls:
        file_array = file_url.split("/")
        dp_version = int(file_array[4 if args.type == "last-valid" else 3])
        if file_array[1] == args.study:
            if dp_version > version:
                version = int(dp_version)
                urls = []
            elif int(dp_version) == version:
                if (
                    args.type == "last-valid" and args.site == file_array[3]
                ) or args.type == "aggregates":
                    urls.append(file_url)
    if len(urls) == 0:
        sys.exit(f"No aggregates matching {args.study} found")
    # Download each CSV into an in-memory zip, then write the archive out once
    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w") as zip_archive:
        for file in track(urls, description=f"Downloading {args.study} aggregates"):
            csv_url = f"https://{args.domain}/{file}"
            res = requests.get(
                csv_url, headers={"x-api-key": args.api_key}, allow_redirects=True, timeout=300
            )
            with zip_archive.open(file.split("/")[-1], "w") as f:
                f.write(bytes(res.text, "UTF-8"))
    with open(pathlib.Path.cwd() / f"{args.study}.zip", "wb") as output:
        output.write(archive.getbuffer())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="""Fetches all data for a given study""")
    parser.add_argument("-s", "--study", help="Name of study to download", required=True)
    parser.add_argument("-i", "--site", help="Name of site to download (last-valid only)")
    parser.add_argument(
        "-d", "--domain", help="Domain of aggregator", default="api.smartcumulus.org"
    )
    parser.add_argument("-t", "--type", help="type of aggregate", default="last-valid")
    parser.add_argument("-a", "--apikey", dest="api_key", help="API key of aggregator")
    args = parser.parse_args()
    bulk_csv_download(args)
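
A minimal invocation sketch (the study name "core" and site name "example_site" are hypothetical placeholders; per the code above, the key can instead come from the CUMULUS_AGGREGATOR_API_KEY environment variable):

    python scripts/bulk_csv_download.py --study core --type last-valid --site example_site --apikey YOUR_KEY

On success this writes core.zip to the current working directory, containing one CSV per matching file.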
6 changes: 3 additions & 3 deletions src/handlers/dashboard/get_csv.py
@@ -84,7 +84,7 @@ def get_csv_list_handler(event, context):
     s3_client = boto3.client("s3")
     if event["path"].startswith("/last-valid"):
         key_prefix = "last_valid"
-        url_prefix = "last_valid"
+        url_prefix = "last-valid"
     elif event["path"].startswith("/aggregates"):
         key_prefix = "csv_aggregates"
         url_prefix = "aggregates"
@@ -104,9 +104,9 @@ def get_csv_list_handler(event, context):
             data_package = key_parts[2].split("__")[1]
             version = key_parts[-2]
             filename = key_parts[-1]
-            site = key_parts[3] if url_prefix == "last_valid" else None
+            site = key_parts[3] if url_prefix == "last-valid" else None
             url_parts = [url_prefix, study, data_package, version, filename]
-            if url_prefix == "last_valid":
+            if url_prefix == "last-valid":
                 url_parts.insert(3, site)
             urls.append("/".join(url_parts))
         if not s3_objs["IsTruncated"]:
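
With both hunks applied, the two endpoints emit distinct URL layouts (a sketch inferred from the url_parts construction above; the angle-bracket segment names are illustrative, not literal API output):

    aggregates/<study>/<data_package>/<version>/<filename>
    last-valid/<study>/<data_package>/<site>/<version>/<filename>

This is the layout the download script assumes when it reads the version from index 4 for last-valid and index 3 for aggregates.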
2 changes: 1 addition & 1 deletion tests/dashboard/test_get_csv.py
@@ -128,7 +128,7 @@ def test_get_csv(mock_bucket, params, status, expected):
             "/last-valid",
             200,
             [
-                "last_valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv"
+                "last-valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv"
             ],
             does_not_raise(),
         ),
