From 6d3e29a0201d58a62fb2d17d73de5f9523b634dc Mon Sep 17 00:00:00 2001 From: Abdelhak Marouane Date: Thu, 24 Oct 2024 09:06:22 -0500 Subject: [PATCH] Using pythonoperators instead of ECS operator --- dags/generate_dags.py | 6 +- .../groups/collection_group.py | 20 +- .../groups/discover_group.py | 43 +- .../groups/processing_tasks.py | 209 +--------- .../groups/transfer_group.py | 65 +-- .../utils/build_stac/handler.py | 110 +++++ .../utils/build_stac/utils/__init__.py | 0 .../utils/build_stac/utils/events.py | 19 + .../utils/build_stac/utils/regex.py | 91 +++++ .../utils/build_stac/utils/role.py | 10 + .../utils/build_stac/utils/stac.py | 142 +++++++ .../utils/cogify_transfer/handler.py | 86 ++++ .../utils/cogify_transfer/requirements.txt | 11 + .../utils/vector_ingest/handler.py | 377 ++++++++++++++++++ .../utils/vector_ingest/requirements.txt | 7 + .../veda_dataset_pipeline.py | 63 ++- .../veda_discover_pipeline.py | 52 ++- .../veda_generic_vector_pipeline.py | 38 +- .../veda_vector_pipeline.py | 40 +- sm2a/airflow_worker/requirements.txt | 6 + sm2a/infrastructure/main.tf | 9 +- sm2a/infrastructure/variables.tf | 24 +- 22 files changed, 1041 insertions(+), 387 deletions(-) create mode 100644 dags/veda_data_pipeline/utils/build_stac/handler.py create mode 100644 dags/veda_data_pipeline/utils/build_stac/utils/__init__.py create mode 100644 dags/veda_data_pipeline/utils/build_stac/utils/events.py create mode 100644 dags/veda_data_pipeline/utils/build_stac/utils/regex.py create mode 100644 dags/veda_data_pipeline/utils/build_stac/utils/role.py create mode 100644 dags/veda_data_pipeline/utils/build_stac/utils/stac.py create mode 100644 dags/veda_data_pipeline/utils/cogify_transfer/handler.py create mode 100644 dags/veda_data_pipeline/utils/cogify_transfer/requirements.txt create mode 100644 dags/veda_data_pipeline/utils/vector_ingest/handler.py create mode 100644 dags/veda_data_pipeline/utils/vector_ingest/requirements.txt diff --git a/dags/generate_dags.py b/dags/generate_dags.py index 9eeab714..7bd2c31b 100644 --- a/dags/generate_dags.py +++ b/dags/generate_dags.py @@ -15,9 +15,9 @@ def generate_dags(): from pathlib import Path - - mwaa_stac_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True) - bucket = mwaa_stac_conf["EVENT_BUCKET"] + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + bucket = airflow_vars_json.get("EVENT_BUCKET") try: client = boto3.client("s3") diff --git a/dags/veda_data_pipeline/groups/collection_group.py b/dags/veda_data_pipeline/groups/collection_group.py index de4f2dd1..0a561175 100644 --- a/dags/veda_data_pipeline/groups/collection_group.py +++ b/dags/veda_data_pipeline/groups/collection_group.py @@ -32,28 +32,40 @@ def ingest_collection_task(ti): dataset (Dict[str, Any]): dataset dictionary (JSON) role_arn (str): role arn for Zarr collection generation """ + import json collection = ti.xcom_pull(task_ids='Collection.generate_collection') + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + cognito_app_secret = airflow_vars_json.get("COGNITO_APP_SECRET") + stac_ingestor_api_url = airflow_vars_json.get("STAC_INGESTOR_API_URL") return submission_handler( event=collection, endpoint="/collections", - cognito_app_secret=Variable.get("COGNITO_APP_SECRET"), - stac_ingestor_api_url=Variable.get("STAC_INGESTOR_API_URL"), + cognito_app_secret=cognito_app_secret, + stac_ingestor_api_url=stac_ingestor_api_url ) # NOTE unused, but useful for item ingests, since collections 
are a dependency for items def check_collection_exists_task(ti): + import json config = ti.dag_run.conf + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + stac_url = airflow_vars_json.get("STAC_URL") return check_collection_exists( - endpoint=Variable.get("STAC_URL", default_var=None), + endpoint=stac_url, collection_id=config.get("collection"), ) def generate_collection_task(ti): + import json config = ti.dag_run.conf - role_arn = Variable.get("ASSUME_ROLE_READ_ARN", default_var=None) + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + role_arn = airflow_vars_json.get("ASSUME_ROLE_READ_ARN") # TODO it would be ideal if this also works with complete collections where provided - this would make the collection ingest more re-usable collection = generator.generate_stac( diff --git a/dags/veda_data_pipeline/groups/discover_group.py b/dags/veda_data_pipeline/groups/discover_group.py index 8bf0c07e..63f41dd6 100644 --- a/dags/veda_data_pipeline/groups/discover_group.py +++ b/dags/veda_data_pipeline/groups/discover_group.py @@ -1,23 +1,17 @@ from datetime import timedelta -import time +import json import uuid from airflow.models.variable import Variable from airflow.models.xcom import LazyXComAccess -from airflow.operators.dummy_operator import DummyOperator as EmptyOperator -from airflow.decorators import task_group, task -from airflow.models.baseoperator import chain -from airflow.operators.python import BranchPythonOperator, PythonOperator, ShortCircuitOperator -from airflow.utils.trigger_rule import TriggerRule -from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator +from airflow.decorators import task from veda_data_pipeline.utils.s3_discovery import ( s3_discovery_handler, EmptyFileListError ) -from veda_data_pipeline.groups.processing_tasks import build_stac_kwargs, submit_to_stac_ingestor_task - group_kwgs = {"group_id": "Discover", "tooltip": "Discover"} + @task(retries=1, retry_delay=timedelta(minutes=1)) def discover_from_s3_task(ti=None, event={}, **kwargs): """Discover grouped assets/files from S3 in batches of 2800. Produce a list of such files stored on S3 to process. 
@@ -32,8 +26,11 @@ def discover_from_s3_task(ti=None, event={}, **kwargs): if event.get("schedule") and last_successful_execution: config["last_successful_execution"] = last_successful_execution.isoformat() # (event, chunk_size=2800, role_arn=None, bucket_output=None): - MWAA_STAC_CONF = Variable.get("MWAA_STACK_CONF", deserialize_json=True) - read_assume_arn = Variable.get("ASSUME_ROLE_READ_ARN", default_var=None) + + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + event_bucket = airflow_vars_json.get("EVENT_BUCKET") + read_assume_arn = airflow_vars_json.get("ASSUME_ROLE_READ_ARN") # Making the chunk size small, this helped us process large data faster than # passing a large chunk of 500 chunk_size = config.get("chunk_size", 500) @@ -41,7 +38,7 @@ def discover_from_s3_task(ti=None, event={}, **kwargs): return s3_discovery_handler( event=config, role_arn=read_assume_arn, - bucket_output=MWAA_STAC_CONF["EVENT_BUCKET"], + bucket_output=event_bucket, chunk_size=chunk_size ) except EmptyFileListError as ex: @@ -49,22 +46,24 @@ def discover_from_s3_task(ti=None, event={}, **kwargs): # TODO test continued short circuit operator behavior (no files -> skip remaining tasks) return {} + @task def get_files_to_process(payload, ti=None): """Get files from S3 produced by the discovery task. Used as part of both the parallel_run_process_rasters and parallel_run_process_vectors tasks. """ - if isinstance(payload, LazyXComAccess): # if used as part of a dynamic task mapping + if isinstance(payload, LazyXComAccess): # if used as part of a dynamic task mapping payloads_xcom = payload[0].pop("payload", []) payload = payload[0] else: payloads_xcom = payload.pop("payload", []) dag_run_id = ti.dag_run.run_id return [{ - "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}", - **payload, - "payload": payload_xcom, - } for indx, payload_xcom in enumerate(payloads_xcom)] + "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}", + **payload, + "payload": payload_xcom, + } for indx, payload_xcom in enumerate(payloads_xcom)] + @task def get_dataset_files_to_process(payload, ti=None): @@ -75,7 +74,7 @@ def get_dataset_files_to_process(payload, ti=None): result = [] for x in payload: - if isinstance(x, LazyXComAccess): # if used as part of a dynamic task mapping + if isinstance(x, LazyXComAccess): # if used as part of a dynamic task mapping payloads_xcom = x[0].pop("payload", []) payload_0 = x[0] else: @@ -83,8 +82,8 @@ def get_dataset_files_to_process(payload, ti=None): payload_0 = x for indx, payload_xcom in enumerate(payloads_xcom): result.append({ - "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}", - **payload_0, - "payload": payload_xcom, - }) + "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}", + **payload_0, + "payload": payload_xcom, + }) return result diff --git a/dags/veda_data_pipeline/groups/processing_tasks.py b/dags/veda_data_pipeline/groups/processing_tasks.py index a7bbe8c4..48758fcc 100644 --- a/dags/veda_data_pipeline/groups/processing_tasks.py +++ b/dags/veda_data_pipeline/groups/processing_tasks.py @@ -1,12 +1,9 @@ from datetime import timedelta import json import logging - import smart_open from airflow.models.variable import Variable -from airflow.operators.python import PythonOperator -from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator -from airflow.decorators import task_group, task +from airflow.decorators import task from veda_data_pipeline.utils.submit_stac import submission_handler group_kwgs = {"group_id": "Process", "tooltip": 
"Process"} @@ -15,11 +12,17 @@ def log_task(text: str): logging.info(text) + @task(retries=1, retry_delay=timedelta(minutes=1)) -def submit_to_stac_ingestor_task(built_stac:str): +def submit_to_stac_ingestor_task(built_stac: dict): """Submit STAC items to the STAC ingestor API.""" - event = json.loads(built_stac) + event = built_stac.copy() success_file = event["payload"]["success_event_key"] + + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + cognito_app_secret = airflow_vars_json.get("COGNITO_APP_SECRET") + stac_ingestor_api_url = airflow_vars_json.get("STAC_INGESTOR_API_URL") with smart_open.open(success_file, "r") as _file: stac_items = json.loads(_file.read()) @@ -27,201 +30,11 @@ def submit_to_stac_ingestor_task(built_stac:str): submission_handler( event=item, endpoint="/ingestions", - cognito_app_secret=Variable.get("COGNITO_APP_SECRET"), - stac_ingestor_api_url=Variable.get("STAC_INGESTOR_API_URL"), + cognito_app_secret=cognito_app_secret, + stac_ingestor_api_url=stac_ingestor_api_url, ) return event -@task -def build_stac_kwargs(event={}): - """Build kwargs for the ECS operator.""" - mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True) - if event: - intermediate = { - **event - } # this is dumb but it resolves the MappedArgument to a dict that can be JSON serialized - payload = json.dumps(intermediate) - else: - payload = "{{ task_instance.dag_run.conf }}" - - return { - "overrides": { - "containerOverrides": [ - { - "name": f"{mwaa_stack_conf.get('PREFIX')}-veda-stac-build", - "command": [ - "/usr/local/bin/python", - "handler.py", - "--payload", - payload, - ], - "environment": [ - { - "name": "EXTERNAL_ROLE_ARN", - "value": Variable.get( - "ASSUME_ROLE_READ_ARN", default_var="" - ), - }, - { - "name": "BUCKET", - "value": "veda-data-pipelines-staging-lambda-ndjson-bucket", - }, - { - "name": "EVENT_BUCKET", - "value": mwaa_stack_conf.get("EVENT_BUCKET"), - } - ], - "memory": 2048, - "cpu": 1024, - }, - ], - }, - "network_configuration": { - "awsvpcConfiguration": { - "securityGroups": mwaa_stack_conf.get("SECURITYGROUPS"), - "subnets": mwaa_stack_conf.get("SUBNETS"), - }, - }, - "awslogs_group": mwaa_stack_conf.get("LOG_GROUP_NAME"), - "awslogs_stream_prefix": f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-stac-build", - } - -@task -def build_vector_kwargs(event={}): - """Build kwargs for the ECS operator.""" - mwaa_stack_conf = Variable.get( - "MWAA_STACK_CONF", default_var={}, deserialize_json=True - ) - vector_ecs_conf = Variable.get( - "VECTOR_ECS_CONF", default_var={}, deserialize_json=True - ) - - if event: - intermediate = { - **event - } - payload = json.dumps(intermediate) - else: - payload = "{{ task_instance.dag_run.conf }}" - - return { - "overrides": { - "containerOverrides": [ - { - "name": f"{mwaa_stack_conf.get('PREFIX')}-veda-vector_ingest", - "command": [ - "/var/lang/bin/python", - "handler.py", - "--payload", - payload, - ], - "environment": [ - { - "name": "EXTERNAL_ROLE_ARN", - "value": Variable.get( - "ASSUME_ROLE_READ_ARN", default_var="" - ), - }, - { - "name": "AWS_REGION", - "value": mwaa_stack_conf.get("AWS_REGION"), - }, - { - "name": "VECTOR_SECRET_NAME", - "value": Variable.get("VECTOR_SECRET_NAME"), - }, - { - "name": "AWS_STS_REGIONAL_ENDPOINTS", - "value": "regional", # to override this behavior, make sure AWS_REGION is set to `aws-global` - } - ], - }, - ], - }, - "network_configuration": { - "awsvpcConfiguration": { - "securityGroups": 
vector_ecs_conf.get("VECTOR_SECURITY_GROUP"), - "subnets": vector_ecs_conf.get("VECTOR_SUBNETS"), - }, - }, - "awslogs_group": mwaa_stack_conf.get("LOG_GROUP_NAME"), - "awslogs_stream_prefix": f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-vector_ingest", - } - - -@task -def build_generic_vector_kwargs(event={}): - """Build kwargs for the ECS operator.""" - mwaa_stack_conf = Variable.get( - "MWAA_STACK_CONF", default_var={}, deserialize_json=True - ) - vector_ecs_conf = Variable.get( - "VECTOR_ECS_CONF", default_var={}, deserialize_json=True - ) - - if event: - intermediate = { - **event - } - payload = json.dumps(intermediate) - else: - payload = "{{ task_instance.dag_run.conf }}" - - return { - "overrides":{ - "containerOverrides": [ - { - "name": f"{mwaa_stack_conf.get('PREFIX')}-veda-generic_vector_ingest", - "command": [ - "/var/lang/bin/python", - "handler.py", - "--payload", - payload, - ], - "environment": [ - { - "name": "EXTERNAL_ROLE_ARN", - "value": Variable.get( - "ASSUME_ROLE_READ_ARN", default_var="" - ), - }, - { - "name": "AWS_REGION", - "value": mwaa_stack_conf.get("AWS_REGION"), - }, - { - "name": "VECTOR_SECRET_NAME", - "value": Variable.get("VECTOR_SECRET_NAME"), - }, - { - "name": "AWS_STS_REGIONAL_ENDPOINTS", - "value": "regional", # to override this behavior, make sure AWS_REGION is set to `aws-global` - } - ], - }, - ], - }, - "network_configuration":{ - "awsvpcConfiguration": { - "securityGroups": vector_ecs_conf.get("VECTOR_SECURITY_GROUP") + mwaa_stack_conf.get("SECURITYGROUPS"), - "subnets": vector_ecs_conf.get("VECTOR_SUBNETS"), - }, - }, - "awslogs_group":mwaa_stack_conf.get("LOG_GROUP_NAME"), - "awslogs_stream_prefix":f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-generic-vector_ingest", # prefix with container name - } - - -@task_group -def subdag_process(event={}): - build_stac = EcsRunTaskOperator.partial( - task_id="build_stac" - ).expand_kwargs(build_stac_kwargs(event=event)) - submit_to_stac_ingestor = PythonOperator( - task_id="submit_to_stac_ingestor", - python_callable=submit_to_stac_ingestor_task, - ) - build_stac >> submit_to_stac_ingestor \ No newline at end of file diff --git a/dags/veda_data_pipeline/groups/transfer_group.py b/dags/veda_data_pipeline/groups/transfer_group.py index a4235496..3bf30e1f 100644 --- a/dags/veda_data_pipeline/groups/transfer_group.py +++ b/dags/veda_data_pipeline/groups/transfer_group.py @@ -2,12 +2,9 @@ from airflow.models.variable import Variable from airflow.operators.python import BranchPythonOperator, PythonOperator -from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator +import json from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule -from veda_data_pipeline.utils.transfer import ( - data_transfer_handler, -) group_kwgs = {"group_id": "Transfer", "tooltip": "Transfer"} @@ -22,12 +19,27 @@ def cogify_choice(ti): return f"{group_kwgs['group_id']}.copy_data" +def cogify_copy_task(ti): + from veda_data_pipeline.utils.cogify_transfer.handler import cogify_transfer_handler + config = ti.dag_run.conf + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + external_role_arn = airflow_vars_json.get("ASSUME_ROLE_WRITE_ARN") + return cogify_transfer_handler(event_src=config, external_role_arn=external_role_arn) + + def transfer_data(ti): """Transfer data from one S3 bucket to another; s3 copy, no need for docker""" + from veda_data_pipeline.utils.transfer import ( + data_transfer_handler, + ) config = ti.dag_run.conf - 
role_arn = Variable.get("ASSUME_ROLE_READ_ARN", default_var="") + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + external_role_arn = airflow_vars_json.get("ASSUME_ROLE_WRITE_ARN") # (event, chunk_size=2800, role_arn=None, bucket_output=None): - return data_transfer_handler(event=config, role_arn=role_arn) + return data_transfer_handler(event=config, role_arn=external_role_arn) + # TODO: cogify_transfer handler is missing arg parser so this subdag will not work def subdag_transfer(): @@ -43,47 +55,10 @@ def subdag_transfer(): python_callable=transfer_data, op_kwargs={"text": "Copy files on S3"}, ) - - mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True) - run_cogify_copy = EcsRunTaskOperator( + run_cogify_copy = PythonOperator( task_id="cogify_and_copy_data", trigger_rule="none_failed", - cluster=f"{mwaa_stack_conf.get('PREFIX')}-cluster", - task_definition=f"{mwaa_stack_conf.get('PREFIX')}-transfer-tasks", - launch_type="FARGATE", - do_xcom_push=True, - execution_timeout=timedelta(minutes=120), - overrides={ - "containerOverrides": [ - { - "name": f"{mwaa_stack_conf.get('PREFIX')}-veda-cogify-transfer", - "command": [ - "/usr/local/bin/python", - "handler.py", - "--payload", - "{}".format("{{ task_instance.dag_run.conf }}"), - ], - "environment": [ - { - "name": "EXTERNAL_ROLE_ARN", - "value": Variable.get( - "ASSUME_ROLE_READ_ARN", default_var="" - ), - }, - ], - "memory": 2048, - "cpu": 1024, - }, - ], - }, - network_configuration={ - "awsvpcConfiguration": { - "securityGroups": mwaa_stack_conf.get("SECURITYGROUPS"), - "subnets": mwaa_stack_conf.get("SUBNETS"), - }, - }, - awslogs_group=mwaa_stack_conf.get("LOG_GROUP_NAME"), - awslogs_stream_prefix=f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-cogify-transfer", # prefix with container name + python_callable=cogify_copy_task ) (cogify_branching >> [run_copy, run_cogify_copy]) diff --git a/dags/veda_data_pipeline/utils/build_stac/handler.py b/dags/veda_data_pipeline/utils/build_stac/handler.py new file mode 100644 index 00000000..9eb017b0 --- /dev/null +++ b/dags/veda_data_pipeline/utils/build_stac/handler.py @@ -0,0 +1,110 @@ +import json +from typing import Any, Dict, TypedDict, Union +from uuid import uuid4 +import smart_open +from veda_data_pipeline.utils.build_stac.utils import events +from veda_data_pipeline.utils.build_stac.utils import stac + + +class S3LinkOutput(TypedDict): + stac_file_url: str + + +class StacItemOutput(TypedDict): + stac_item: Dict[str, Any] + + +def handler(event: Dict[str, Any]) -> Union[S3LinkOutput, StacItemOutput]: + """ + Handler for STAC Collection Item generation + + Arguments: + event - object with event parameters + { + "collection": "OMDOAO3e", + "id_regex": "_(.*).tif", + "assets": { + "OMDOAO3e_LUT": { + "title": "OMDOAO3e_LUT", + "description": "OMDOAO3e_LUT, described", + "href": "s3://climatedashboard-data/OMDOAO3e/OMDOAO3e_LUT.tif", + }, + "OMDOAO3e_LUT": { + "title": "OMDOAO3e_LUT", + "description": "OMDOAO3e_LUT, described", + "href": "s3://climatedashboard-data/OMDOAO3e/OMDOAO3e_LUT.tif", + } + } + } + + """ + + parsed_event = events.RegexEvent.parse_obj(event) + try: + stac_item = stac.generate_stac(parsed_event).to_dict() + except Exception as ex: + out_err: StacItemOutput = {"stac_item": {"error": f"{ex}", "event": event}} + return out_err + + output: StacItemOutput = {"stac_item": stac_item} + return output + + +def sequential_processing(objects): + returned_results = [] + for _object in objects: + result = 
handler(_object) + returned_results.append(result) + return returned_results + + +def write_outputs_to_s3(key, payload_success, payload_failures): + success_key = f"{key}/build_stac_output_{uuid4()}.json" + with smart_open.open(success_key, "w") as _file: + _file.write(json.dumps(payload_success)) + dead_letter_key = "" + if payload_failures: + dead_letter_key = f"{key}/dead_letter_events/build_stac_failed_{uuid4()}.json" + with smart_open.open(dead_letter_key, "w") as _file: + _file.write(json.dumps(payload_failures)) + return [success_key, dead_letter_key] + + +def stac_handler(payload_src: dict, bucket_output): + payload_event = payload_src.copy() + s3_event = payload_event.pop("payload") + collection = payload_event.get("collection", "not_provided") + key = f"s3://{bucket_output}/events/{collection}" + payload_success = [] + payload_failures = [] + with smart_open.open(s3_event, "r") as _file: + s3_event_read = _file.read() + event_received = json.loads(s3_event_read) + objects = event_received["objects"] + payloads = sequential_processing(objects) + for payload in payloads: + stac_item = payload["stac_item"] + if "error" in stac_item: + payload_failures.append(stac_item) + else: + payload_success.append(stac_item) + success_key, dead_letter_key = write_outputs_to_s3( + key=key, payload_success=payload_success, payload_failures=payload_failures + ) + + # Silent dead letters are nice, but we want the Airflow UI to quickly alert us if something went wrong. + if len(payload_failures) != 0: + raise ValueError( + f"Some items failed to be processed. Failures logged here: {dead_letter_key}" + ) + + return { + "payload": { + "success_event_key": success_key, + "failed_event_key": dead_letter_key, + "status": { + "successes": len(payload_success), + "failures": len(payload_failures), + }, + } + } diff --git a/dags/veda_data_pipeline/utils/build_stac/utils/__init__.py b/dags/veda_data_pipeline/utils/build_stac/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dags/veda_data_pipeline/utils/build_stac/utils/events.py b/dags/veda_data_pipeline/utils/build_stac/utils/events.py new file mode 100644 index 00000000..c7e835b9 --- /dev/null +++ b/dags/veda_data_pipeline/utils/build_stac/utils/events.py @@ -0,0 +1,19 @@ +from datetime import datetime +from typing import Dict, Literal, Optional + +from pydantic import BaseModel, Field + +INTERVAL = Literal["month", "year", "day"] + + +class RegexEvent(BaseModel, frozen=True): + collection: str + item_id: str + assets: Dict + + start_datetime: Optional[datetime] = None + end_datetime: Optional[datetime] = None + single_datetime: Optional[datetime] = None + + properties: Optional[Dict] = Field(default_factory=dict) + datetime_range: Optional[INTERVAL] = None diff --git a/dags/veda_data_pipeline/utils/build_stac/utils/regex.py b/dags/veda_data_pipeline/utils/build_stac/utils/regex.py new file mode 100644 index 00000000..471832ee --- /dev/null +++ b/dags/veda_data_pipeline/utils/build_stac/utils/regex.py @@ -0,0 +1,91 @@ +import re +from datetime import datetime +from typing import Callable, Dict, Tuple, Union + +from dateutil.relativedelta import relativedelta + +from . 
import events + +DATERANGE = Tuple[datetime, datetime] + + +def _calculate_year_range(datetime_obj: datetime) -> DATERANGE: + start_datetime = datetime_obj.replace(month=1, day=1) + end_datetime = datetime_obj.replace(month=12, day=31) + return start_datetime, end_datetime + + +def _calculate_month_range(datetime_obj: datetime) -> DATERANGE: + start_datetime = datetime_obj.replace(day=1) + end_datetime = datetime_obj + relativedelta(day=31) + return start_datetime, end_datetime + + +def _calculate_day_range(datetime_obj: datetime) -> DATERANGE: + start_datetime = datetime_obj + end_datetime = datetime_obj + relativedelta(hour=23, minute=59, second=59) + return start_datetime, end_datetime + + +DATETIME_RANGE_METHODS: Dict[events.INTERVAL, Callable[[datetime], DATERANGE]] = { + "month": _calculate_month_range, + "year": _calculate_year_range, + "day": _calculate_day_range, +} + + +def extract_dates( + filename: str, datetime_range: events.INTERVAL +) -> Union[Tuple[datetime, datetime, None], Tuple[None, None, datetime]]: + """ + Extracts start & end or single date string from filename. + """ + DATE_REGEX_STRATEGIES = [ + (r"_(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})", "%Y-%m-%dT%H:%M:%S"), + (r"_(\d{8}T\d{6})", "%Y%m%dT%H%M%S"), + (r"_(\d{4}_\d{2}_\d{2})", "%Y_%m_%d"), + (r"_(\d{4}-\d{2}-\d{2})", "%Y-%m-%d"), + (r"_(\d{8})", "%Y%m%d"), + (r"_(\d{6})", "%Y%m"), + (r"_(\d{4})", "%Y"), + ] + + # Find dates in filename + dates = [] + for pattern, dateformat in DATE_REGEX_STRATEGIES: + dates_found = re.compile(pattern).findall(filename) + if not dates_found: + continue + + for date_str in dates_found: + dates.append(datetime.strptime(date_str, dateformat)) + + break + + num_dates_found = len(dates) + + # No dates found + if not num_dates_found: + raise Exception( + f"No dates provided in {filename=}. " + "At least one date in format yyyy-mm-dd is required." + ) + + # Many dates found + if num_dates_found > 1: + dates.sort() + start_datetime, *_, end_datetime = dates + return start_datetime, end_datetime, None + + # Single date found + single_datetime = dates[0] + + # Convert single date to range + if datetime_range: + start_datetime, end_datetime = DATETIME_RANGE_METHODS[datetime_range]( + single_datetime + ) + return start_datetime, end_datetime, None + + # Return single date + return None, None, single_datetime diff --git a/dags/veda_data_pipeline/utils/build_stac/utils/role.py b/dags/veda_data_pipeline/utils/build_stac/utils/role.py new file mode 100644 index 00000000..817c0ad3 --- /dev/null +++ b/dags/veda_data_pipeline/utils/build_stac/utils/role.py @@ -0,0 +1,10 @@ +import boto3 + + +def assume_role(role_arn, session_name): + sts = boto3.client("sts") + creds = sts.assume_role( + RoleArn=role_arn, + RoleSessionName=session_name, + ) + return creds["Credentials"] diff --git a/dags/veda_data_pipeline/utils/build_stac/utils/stac.py b/dags/veda_data_pipeline/utils/build_stac/utils/stac.py new file mode 100644 index 00000000..9ec69f14 --- /dev/null +++ b/dags/veda_data_pipeline/utils/build_stac/utils/stac.py @@ -0,0 +1,142 @@ +import os + +import pystac +import rasterio +from pystac.utils import datetime_to_str +from rasterio.session import AWSSession +from rio_stac import stac +from rio_stac.stac import PROJECTION_EXT_VERSION, RASTER_EXT_VERSION + + +from . 
import events, regex, role + + +def get_sts_session(): + if role_arn := os.environ.get("EXTERNAL_ROLE_ARN"): + creds = role.assume_role(role_arn, "veda-data-pipelines_build-stac") + return AWSSession( + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + ) + return + + +def create_item( + item_id, + bbox, + properties, + datetime, + collection, + assets, +) -> pystac.Item: + """ + Function to create a stac item from a COG using rio_stac + """ + # item + item = pystac.Item( + id=item_id, + geometry=stac.bbox_to_geom(bbox), + bbox=bbox, + collection=collection, + stac_extensions=[ + f"https://stac-extensions.github.io/raster/{RASTER_EXT_VERSION}/schema.json", + f"https://stac-extensions.github.io/projection/{PROJECTION_EXT_VERSION}/schema.json", + ], + datetime=datetime, + properties=properties, + ) + + # if we add a collection we MUST add a link + if collection: + item.add_link( + pystac.Link( + pystac.RelType.COLLECTION, + collection, + media_type=pystac.MediaType.JSON, + ) + ) + + for key, asset in assets.items(): + item.add_asset(key=key, asset=asset) + return item + + +def generate_stac(event: events.RegexEvent) -> pystac.Item: + """ + Generate STAC item from user provided datetime range or regex & filename + """ + start_datetime = end_datetime = single_datetime = None + if event.start_datetime and event.end_datetime: + start_datetime = event.start_datetime + end_datetime = event.end_datetime + single_datetime = None + elif single_datetime := event.single_datetime: + start_datetime = end_datetime = None + single_datetime = single_datetime + else: + # Having multiple assets, we try against all filenames. + for asset_name, asset in event.assets.items(): + try: + filename = asset["href"].split("/")[-1] + start_datetime, end_datetime, single_datetime = regex.extract_dates( + filename, event.datetime_range + ) + break + except Exception: + continue + # Raise if dates can't be found + if not (start_datetime or end_datetime or single_datetime): + raise ValueError("No dates found in event config or by regex") + + properties = event.properties or {} + if start_datetime and end_datetime: + properties["start_datetime"] = datetime_to_str(start_datetime) + properties["end_datetime"] = datetime_to_str(end_datetime) + single_datetime = None + assets = {} + + rasterio_kwargs = {} + rasterio_kwargs["session"] = get_sts_session() + with rasterio.Env( + session=rasterio_kwargs.get("session"), + options={**rasterio_kwargs}, + ): + bboxes = [] + for asset_name, asset_definition in event.assets.items(): + with rasterio.open(asset_definition["href"]) as src: + # Get BBOX and Footprint + dataset_geom = stac.get_dataset_geom(src, densify_pts=0, precision=-1) + bboxes.append(dataset_geom["bbox"]) + + media_type = stac.get_media_type(src) + proj_info = { + f"proj:{name}": value + for name, value in stac.get_projection_info(src).items() + } + raster_info = {"raster:bands": stac.get_raster_info(src, max_size=1024)} + + # The default asset name for cogs is "cog_default", so we need to intercept 'default' + if asset_name == "default": + asset_name = "cog_default" + assets[asset_name] = pystac.Asset( + title=asset_definition["title"], + description=asset_definition["description"], + href=asset_definition["href"], + media_type=media_type, + roles=["data", "layer"], + extra_fields={**proj_info, **raster_info}, + ) + + minx, miny, maxx, maxy = zip(*bboxes) + bbox = [min(minx), min(miny), max(maxx), max(maxy)] + + create_item_response = 
create_item( + item_id=event.item_id, + bbox=bbox, + properties=properties, + datetime=single_datetime, + collection=event.collection, + assets=assets, + ) + return create_item_response diff --git a/dags/veda_data_pipeline/utils/cogify_transfer/handler.py b/dags/veda_data_pipeline/utils/cogify_transfer/handler.py new file mode 100644 index 00000000..0e9db5eb --- /dev/null +++ b/dags/veda_data_pipeline/utils/cogify_transfer/handler.py @@ -0,0 +1,86 @@ +import re +import tempfile + +import boto3 +from rio_cogeo.cogeo import cog_translate + + +def assume_role(role_arn, session_name="veda-airflow-pipelines_transfer_files"): + sts = boto3.client("sts") + credentials = sts.assume_role( + RoleArn=role_arn, + RoleSessionName=session_name, + ) + creds = credentials["Credentials"] + return { + "aws_access_key_id": creds["AccessKeyId"], + "aws_secret_access_key": creds.get("SecretAccessKey"), + "aws_session_token": creds.get("SessionToken"), + } + + +def get_matching_files(s3_client, bucket, prefix, regex_pattern): + matching_files = [] + + response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix) + while True: + for obj in response["Contents"]: + file_key = obj["Key"] + if re.match(regex_pattern, file_key): + matching_files.append(file_key) + + if "NextContinuationToken" in response: + response = s3_client.list_objects_v2( + Bucket=bucket, + Prefix=prefix, + ContinuationToken=response["NextContinuationToken"], + ) + else: + break + + return matching_files + + +def transfer_file(s3_client, file_key, local_file_path, destination_bucket, collection): + filename = file_key.split("/")[-1] + target_key = f"{collection}/{filename}" + s3_client.upload_file(local_file_path, destination_bucket, target_key) + + +def cogify_transfer_handler(event_src, external_role_arn=None): + event = event_src.copy() + kwargs = {} + if external_role_arn: + creds = assume_role(external_role_arn, "veda-data-pipelines_data-transfer") + kwargs = { + "aws_access_key_id": creds["AccessKeyId"], + "aws_secret_access_key": creds["SecretAccessKey"], + "aws_session_token": creds["SessionToken"], + } + source_s3 = boto3.client("s3") + target_s3 = boto3.client("s3", **kwargs) + + origin_bucket = event.get("origin_bucket") + origin_prefix = event.get("origin_prefix") + regex_pattern = event.get("filename_regex") + target_bucket = event.get("target_bucket", "veda-data-store-staging") + collection = event.get("collection") + + matching_files = get_matching_files( + source_s3, origin_bucket, origin_prefix, regex_pattern + ) + if not event.get("dry_run"): + for origin_key in matching_files: + with tempfile.NamedTemporaryFile() as local_tif, tempfile.NamedTemporaryFile() as local_cog: + local_tif_path = local_tif.name + local_cog_path = local_cog.name + source_s3.download_file(origin_bucket, origin_key, local_tif_path) + cog_translate(local_tif_path, local_cog_path, quiet=True) + filename = origin_key.split("/")[-1] + destination_key = f"{collection}/{filename}" + target_s3.upload_file(local_cog_path, target_bucket, destination_key) + else: + print( + f"Would have copied {len(matching_files)} files from {origin_bucket} to {target_bucket}" + ) + print(f"Files matched: {matching_files}") diff --git a/dags/veda_data_pipeline/utils/cogify_transfer/requirements.txt b/dags/veda_data_pipeline/utils/cogify_transfer/requirements.txt new file mode 100644 index 00000000..56e091b1 --- /dev/null +++ b/dags/veda_data_pipeline/utils/cogify_transfer/requirements.txt @@ -0,0 +1,11 @@ +aws-lambda-powertools +awslambdaric +boto3 +pystac==1.4.0 
+python-cmr +rasterio==1.3.3 +rio-cogeo==4.0.0 +shapely +smart-open==6.3.0 +pydantic==1.10.7 +typing-extensions==4.5.0 diff --git a/dags/veda_data_pipeline/utils/vector_ingest/handler.py b/dags/veda_data_pipeline/utils/vector_ingest/handler.py new file mode 100644 index 00000000..09e7d437 --- /dev/null +++ b/dags/veda_data_pipeline/utils/vector_ingest/handler.py @@ -0,0 +1,377 @@ +import base64 +from argparse import ArgumentParser +import boto3 +import os +import subprocess +import json +import smart_open +from urllib.parse import urlparse +import psycopg2 +import geopandas as gpd +from shapely import wkb +from geoalchemy2 import Geometry +import sqlalchemy +from sqlalchemy import create_engine, MetaData, Table, Column, inspect +import concurrent.futures +from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, INTEGER, VARCHAR, TIMESTAMP + + +def download_file(file_uri: str, role_arn:[str, None]): + session = boto3.Session() + if role_arn: + sts = boto3.client("sts") + response = sts.assume_role( + RoleArn=role_arn, + RoleSessionName="airflow_vector_ingest", + ) + session = boto3.Session( + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + ) + s3 = session.client("s3") + + url_parse = urlparse(file_uri) + + bucket = url_parse.netloc + path = url_parse.path[1:] + filename = url_parse.path.split("/")[-1] + target_filepath = os.path.join("/tmp", filename) + + s3.download_file(bucket, path, target_filepath) + + print(f"downloaded {target_filepath}") + + + return target_filepath + + +def get_connection_string(secret: dict, as_uri: bool = False) -> str: + if as_uri: + return f"postgresql://{secret['username']}:{secret['password']}@{secret['host']}:5432/{secret['dbname']}" + else: + return f"PG:host={secret['host']} dbname={secret['dbname']} user={secret['username']} password={secret['password']}" + + +def get_gdf_schema(gdf, target_projection): + """map GeoDataFrame columns into a table schema + + :param gdf: GeoDataFrame from geopandas + :param target_projection: srid for the target table geometry column + :return: + """ + # map geodatafrome dtypes to sqlalchemy types + dtype_map = { + "int64": INTEGER, + "float64": DOUBLE_PRECISION, + "object": VARCHAR, + "datetime64": TIMESTAMP, + } + schema = [] + for column, dtype in zip(gdf.columns, gdf.dtypes): + if str(dtype) == "geometry": + # do not inpsect to retrieve geom type, just use generic GEOMETRY + # geom_type = str(gdf[column].geom_type.unique()[0]).upper() + geom_type = str(dtype).upper() + # do not taKe SRID from existing file for target table + # we always want to transform from file EPSG to Table EPSG() + column_type = Geometry(geometry_type=geom_type, srid=target_projection) + else: + dtype_str = str(dtype) + column_type = dtype_map.get(dtype_str.split("[")[0], VARCHAR) + + if column == "primarykey": + schema.append(Column(column.lower(), column_type, unique=True)) + else: + schema.append(Column(column.lower(), column_type)) + return schema + + +def ensure_table_exists( + db_metadata: MetaData, gpkg_file: str, target_projection: int, table_name: str +): + """create a table if it doesn't exist or just + validate GeoDataFrame columns against existing table + + :param db_metadata: instance of sqlalchemy.MetaData + :param gpkg_file: file path to GPKG + :param target_projection: srid for target DB table geometry column + :param table_name: name of table to create + :return: None + """ + gdf = 
gpd.read_file(gpkg_file) + gdf_schema = get_gdf_schema(gdf, target_projection) + engine = db_metadata.bind + try: + Table(table_name, db_metadata, autoload_with=engine) + except sqlalchemy.exc.NoSuchTableError: + Table(table_name, db_metadata, *gdf_schema) + db_metadata.create_all(engine) + + # validate gdf schema against existing table schema + insp = inspect(engine) + existing_columns = insp.get_columns(table_name) + existing_column_names = [col["name"] for col in existing_columns] + for column in gdf_schema: + if column.name not in existing_column_names: + raise ValueError( + f"your .gpkg seems to have a column={column.name} that does not exist in the existing table columns={existing_column_names}" + ) + + +def delete_region( + engine, + gpkg_path: str, + table_name: str, +): + gdf = gpd.read_file(gpkg_path) + if 'region' in gdf.columns: + region_name = gdf["region"].iloc[0] + with engine.connect() as conn: + with conn.begin(): + delete_sql = sqlalchemy.text( + f""" + DELETE FROM {table_name} WHERE region=:region_name + """ + ) + conn.execute(delete_sql, {'region_name': region_name}) + else: + print(f"'region' column not found in {gpkg_path}. No records deleted.") + + +def upsert_to_postgis( + engine, + gpkg_path: str, + target_projection: int, + table_name: str, + batch_size: int = 10000, +): + """batch the GPKG file and upsert via threads + + :param engine: instance of sqlalchemy.Engine + :param gpkg_path: file path to GPKG + :param table_name: name of the target table + :param batch_size: upper limit of batch size + :return: + """ + gdf = gpd.read_file(gpkg_path) + source_epsg_code = gdf.crs.to_epsg() + if not source_epsg_code: + # assume NAD27 Equal Area for now :shrug: + # since that's what the default is for Fire Atlas team exports + # that's what PROJ4 does under the hood for 9311 :wethinksmirk: + source_epsg_code = 2163 + + # convert the `t` column to something suitable for sql insertion otherwise we get 'Timestamp()' + gdf["t"] = gdf["t"].dt.strftime("%Y-%m-%d %H:%M:%S") + # convert to WKB + gdf["geometry"] = gdf["geometry"].apply(lambda geom: wkb.dumps(geom, hex=True)) + + def upsert_batch(batch): + with engine.connect() as conn: + with conn.begin(): + for row in batch.to_dict(orient="records"): + # make sure all column names are lower case for keys and values + row = {k.lower(): v for k, v in row.items()} + columns = [col.lower() for col in batch.columns] + + non_geom_placeholders = ", ".join( + [f":{col}" for col in columns[:-1]] + ) + # NOTE: we need to escape `::geometry` so parameterized statements don't try to replace it + # because parametrized statements in sqlalchemy are `:` + geom_placeholder = f"ST_Transform(ST_SetSRID(ST_GeomFromWKB(:geometry\:\:geometry), {source_epsg_code}), {target_projection})" # noqa: W605 + upsert_sql = sqlalchemy.text( + f""" + INSERT INTO {table_name} ({', '.join([col for col in columns])}) + VALUES ({non_geom_placeholders},{geom_placeholder}) + ON CONFLICT (primarykey) + DO UPDATE SET {', '.join(f"{col}=EXCLUDED.{col}" for col in columns if col != 'primarykey')} + """ + ) + + # logging.debug(f"[ UPSERT SQL ]:\n{str(upsert_sql)}") + conn.execute(upsert_sql, row) + + batches = [gdf.iloc[i : i + batch_size] for i in range(0, len(gdf), batch_size)] + # set `max_workers` to something below max concurrent connections for postgresql + # https://www.postgresql.org/docs/14/runtime-config-connection.html + with concurrent.futures.ThreadPoolExecutor(max_workers=75) as executor: + executor.map(upsert_batch, batches) + + +def get_secret(secret_name: 
str, region_name: str = "us-west-2") -> None: + """Retrieve secrets from AWS Secrets Manager + + Args: + secret_name (str): name of aws secrets manager secret containing database connection secrets + + Returns: + secrets (dict): decrypted secrets in dict + """ + + # Create a Secrets Manager client + session = boto3.session.Session(region_name=region_name) + client = session.client(service_name="secretsmanager") + + # In this sample we only handle the specific exceptions for the 'GetSecretValue' API. + # See https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html + # We rethrow the exception by default. + + get_secret_value_response = client.get_secret_value(SecretId=secret_name) + + # Decrypts secret using the associated KMS key. + # Depending on whether the secret is a string or binary, one of these fields will be populated. + if "SecretString" in get_secret_value_response: + return json.loads(get_secret_value_response["SecretString"]) + else: + return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"])) + + +def load_to_featuresdb( + filename: str, + collection: str, + vector_secret_name: str, + extra_flags: list = None, + target_projection: str = "EPSG:4326", +): + if extra_flags is None: + extra_flags = ["-overwrite", "-progress"] + + secret_name = vector_secret_name + + con_secrets = get_secret(secret_name) + connection = get_connection_string(con_secrets) + + print(f"running ogr2ogr import for collection: {collection}") + options = [ + "ogr2ogr", + "-f", + "PostgreSQL", + connection, + "-t_srs", + target_projection, + filename, + "-nln", + collection, + *extra_flags, + ] + out = subprocess.run( + options, + check=False, + capture_output=True, + ) + + if out.stderr: + error_description = f"Error: {out.stderr}" + print(error_description) + return {"status": "failure", "reason": error_description} + + return {"status": "success"} + + +def load_to_featuresdb_eis( + filename: str, + collection: str, + vector_secret_name: str, + target_projection: int = 4326, +): + """create table if not exists and upload GPKG + + :param filename: the file path to the downloaded GPKG + :param collection: the name of the collection + :param target_projection: srid for the target table + :return: None + """ + secret_name = vector_secret_name + conn_secrets = get_secret(secret_name) + connection_string = get_connection_string(conn_secrets, as_uri=True) + + # NOTE: about `collection.rsplit` below: + # + # EIS Fire team naming convention for outputs + # Snapshots: "snapshot_{layer_name}_nrt_{region_name}.gpkg" + # Lf_archive: "lf_{layer_name}_archive_{region_name}.gpkg" + # Lf_nrt: "lf_{layer_name}_nrt_{region_name}.gpkg" + # + # Insert/Alter on table call everything except the region name: + # e.g. 
`snapshot_perimeter_nrt_conus` this gets inserted into the table `eis_fire_snapshot_perimeter_nrt` + collection = collection.rsplit("_", 1)[0] + target_table_name = f"eis_fire_{collection}" + + engine = create_engine(connection_string) + metadata = MetaData() + metadata.bind = engine + + ensure_table_exists(metadata, filename, target_projection, target_table_name) + delete_region(engine, filename, target_table_name) + upsert_to_postgis(engine, filename, target_projection, target_table_name) + return {"status": "success"} + + +def alter_datetime_add_indexes_eis(collection: str,vector_secret_name: str ): + # NOTE: about `collection.rsplit` below: + # + # EIS Fire team naming convention for outputs + # Snapshots: "snapshot_{layer_name}_nrt_{region_name}.gpkg" + # Lf_archive: "lf_{layer_name}_archive_{region_name}.gpkg" + # Lf_nrt: "lf_{layer_name}_nrt_{region_name}.gpkg" + # + # Insert/Alter on table call everything except the region name: + # e.g. `snapshot_perimeter_nrt_conus` this gets inserted into the table `eis_fire_snapshot_perimeter_nrt` + collection = collection.rsplit("_", 1)[0] + + secret_name = vector_secret_name + conn_secrets = get_secret(secret_name) + conn = psycopg2.connect( + host=conn_secrets["host"], + dbname=conn_secrets["dbname"], + user=conn_secrets["username"], + password=conn_secrets["password"], + ) + + cur = conn.cursor() + cur.execute( + f"ALTER table eis_fire_{collection} " + f"ALTER COLUMN t TYPE TIMESTAMP USING t::timestamp without time zone; " + f"CREATE INDEX IF NOT EXISTS idx_eis_fire_{collection}_datetime ON eis_fire_{collection}(t);" + f"CREATE INDEX IF NOT EXISTS idx_eis_fire_{collection}_primarykey ON eis_fire_{collection}(primarykey);" + f"CREATE INDEX IF NOT EXISTS idx_eis_fire_{collection}_region ON eis_fire_{collection}(region);" + ) + conn.commit() + + +def handler(payload_src: dict, vector_secret_name: str, assume_role_arn: [str, None]): + + payload_event = payload_src.copy() + s3_event = payload_event.pop("payload") + with smart_open.open(s3_event, "r") as _file: + s3_event_read = _file.read() + event_received = json.loads(s3_event_read) + s3_objects = event_received["objects"] + status = list() + for s3_object in s3_objects: + href = s3_object["assets"]["default"]["href"] + collection = s3_object["collection"] + downloaded_filepath = download_file(href, assume_role_arn) + print(f"[ DOWNLOAD FILEPATH ]: {downloaded_filepath}") + print(f"[ COLLECTION ]: {collection}") + + s3_object_prefix = event_received["prefix"] + if s3_object_prefix.startswith("EIS/"): + coll_status = load_to_featuresdb_eis(downloaded_filepath, collection, vector_secret_name) + else: + coll_status = load_to_featuresdb(downloaded_filepath, collection, vector_secret_name) + + status.append(coll_status) + # delete file after ingest + os.remove(downloaded_filepath) + + if coll_status["status"] == "success" and s3_object_prefix.startswith("EIS/"): + alter_datetime_add_indexes_eis(collection, vector_secret_name) + elif coll_status["status"] != "success": + # bubble exception so Airflow shows it as a failure + raise Exception(coll_status["reason"]) + return status + + diff --git a/dags/veda_data_pipeline/utils/vector_ingest/requirements.txt b/dags/veda_data_pipeline/utils/vector_ingest/requirements.txt new file mode 100644 index 00000000..35d23946 --- /dev/null +++ b/dags/veda_data_pipeline/utils/vector_ingest/requirements.txt @@ -0,0 +1,7 @@ +smart-open==6.3.0 +psycopg2-binary==2.9.9 +requests==2.30.0 +boto3==1.26.129 +GeoAlchemy2==0.14.2 +geopandas==0.14.4 +SQLAlchemy==2.0.23 diff 
--git a/dags/veda_data_pipeline/veda_dataset_pipeline.py b/dags/veda_data_pipeline/veda_dataset_pipeline.py index faaa3d6a..1c8746cb 100644 --- a/dags/veda_data_pipeline/veda_dataset_pipeline.py +++ b/dags/veda_data_pipeline/veda_dataset_pipeline.py @@ -1,16 +1,12 @@ import pendulum -from datetime import timedelta - from airflow import DAG from airflow.decorators import task from airflow.operators.dummy_operator import DummyOperator as EmptyOperator -from airflow.utils.trigger_rule import TriggerRule from airflow.models.variable import Variable -from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator - +import json from veda_data_pipeline.groups.collection_group import collection_task_group from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_dataset_files_to_process -from veda_data_pipeline.groups.processing_tasks import build_stac_kwargs, submit_to_stac_ingestor_task +from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task dag_doc_md = """ ### Dataset Pipeline @@ -48,6 +44,7 @@ "tags": ["collection", "discovery"], } + @task def extract_discovery_items(**kwargs): ti = kwargs.get("ti") @@ -55,52 +52,50 @@ def extract_discovery_items(**kwargs): print(discovery_items) return discovery_items + +@task(max_active_tis_per_dag=3) +def build_stac_task(payload): + from veda_data_pipeline.utils.build_stac.handler import stac_handler + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + event_bucket = airflow_vars_json.get("EVENT_BUCKET") + return stac_handler(payload_src=payload, bucket_output=event_bucket) + + template_dag_run_conf = { - "collection": "", - "data_type": "cog", - "description": "", - "discovery_items": + "collection": "", + "data_type": "cog", + "description": "", + "discovery_items": [ { - "bucket": "", - "datetime_range": "", - "discovery": "s3", - "filename_regex": "", + "bucket": "", + "datetime_range": "", + "discovery": "s3", + "filename_regex": "", "prefix": "" } - ], - "is_periodic": "", - "license": "", - "time_density": "", + ], + "is_periodic": "", + "license": "", + "time_density": "", "title": "" } with DAG("veda_dataset_pipeline", params=template_dag_run_conf, **dag_args) as dag: # ECS dependency variable - mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True) start = EmptyOperator(task_id="start", dag=dag) end = EmptyOperator(task_id="end", dag=dag) collection_grp = collection_task_group() discover = discover_from_s3_task.expand(event=extract_discovery_items()) - discover.set_upstream(collection_grp) # do not discover until collection exists + discover.set_upstream(collection_grp) # do not discover until collection exists get_files = get_dataset_files_to_process(payload=discover) - build_stac_kwargs_task = build_stac_kwargs.expand(event=get_files) - # partial() is needed for the operator to be used with taskflow inputs - build_stac = EcsRunTaskOperator.partial( - task_id="build_stac", - execution_timeout=timedelta(minutes=60), - trigger_rule=TriggerRule.NONE_FAILED, - cluster=f"{mwaa_stack_conf.get('PREFIX')}-cluster", - task_definition=f"{mwaa_stack_conf.get('PREFIX')}-tasks", - launch_type="FARGATE", - do_xcom_push=True - ).expand_kwargs(build_stac_kwargs_task) + + build_stac = build_stac_task.expand(payload=get_files) # .output is needed coming from a non-taskflow operator - submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac.output) + submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac) 
collection_grp.set_upstream(start) submit_stac.set_downstream(end) - - diff --git a/dags/veda_data_pipeline/veda_discover_pipeline.py b/dags/veda_data_pipeline/veda_discover_pipeline.py index a51e2fab..edd6780d 100644 --- a/dags/veda_data_pipeline/veda_discover_pipeline.py +++ b/dags/veda_data_pipeline/veda_discover_pipeline.py @@ -1,15 +1,12 @@ import pendulum - -from datetime import timedelta from airflow import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.utils.trigger_rule import TriggerRule +from airflow.decorators import task from airflow.models.variable import Variable -from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator - +import json from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_to_process -from veda_data_pipeline.groups.processing_tasks import build_stac_kwargs, submit_to_stac_ingestor_task - +from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task dag_doc_md = """ ### Discover files from S3 @@ -76,42 +73,41 @@ } -def get_discover_dag(id, event={}): +@task(max_active_tis_per_dag=3) +def build_stac_task(payload): + from veda_data_pipeline.utils.build_stac.handler import stac_handler + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + event_bucket = airflow_vars_json.get("EVENT_BUCKET") + return stac_handler(payload_src=payload, bucket_output=event_bucket) + + +def get_discover_dag(id, event=None): + if not event: + event = {} params_dag_run_conf = event or template_dag_run_conf with DAG( - id, - schedule_interval=event.get("schedule"), - params=params_dag_run_conf, - **dag_args + id, + schedule_interval=event.get("schedule"), + params=params_dag_run_conf, + **dag_args ) as dag: - # ECS dependency variable - mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True) - start = DummyOperator(task_id="Start", dag=dag) end = DummyOperator( - task_id="End", trigger_rule=TriggerRule.ONE_SUCCESS, dag=dag + task_id="End", dag=dag ) # define DAG using taskflow notation - + discover = discover_from_s3_task(event=event) get_files = get_files_to_process(payload=discover) - build_stac_kwargs_task = build_stac_kwargs.expand(event=get_files) - # partial() is needed for the operator to be used with taskflow inputs - build_stac = EcsRunTaskOperator.partial( - task_id="build_stac", - execution_timeout=timedelta(minutes=60), - trigger_rule=TriggerRule.NONE_FAILED, - cluster=f"{mwaa_stack_conf.get('PREFIX')}-cluster", - task_definition=f"{mwaa_stack_conf.get('PREFIX')}-tasks", - launch_type="FARGATE", - do_xcom_push=True - ).expand_kwargs(build_stac_kwargs_task) + build_stac = build_stac_task.expand(payload=get_files) # .output is needed coming from a non-taskflow operator - submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac.output) + submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac) discover.set_upstream(start) submit_stac.set_downstream(end) return dag + get_discover_dag("veda_discover") diff --git a/dags/veda_data_pipeline/veda_generic_vector_pipeline.py b/dags/veda_data_pipeline/veda_generic_vector_pipeline.py index 5d0119e7..b359d6dd 100644 --- a/dags/veda_data_pipeline/veda_generic_vector_pipeline.py +++ b/dags/veda_data_pipeline/veda_generic_vector_pipeline.py @@ -1,17 +1,12 @@ import pendulum -from datetime import timedelta - +from airflow.decorators import task from airflow import DAG from airflow.models.variable import Variable -from 
airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator +import json from airflow.operators.dummy_operator import DummyOperator from airflow.utils.trigger_rule import TriggerRule -from airflow.models.variable import Variable - -from veda_data_pipeline.groups.processing_tasks import build_generic_vector_kwargs from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_to_process - dag_doc_md = """ ### Generic Ingest Vector #### Purpose @@ -66,27 +61,24 @@ "doc_md": dag_doc_md, } + +@task +def ingest_vector_task(payload): + from veda_data_pipeline.utils.vector_ingest.handler import handler + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + read_role_arn = airflow_vars_json.get("ASSUME_ROLE_READ_ARN") + vector_secret_name = airflow_vars_json.get("VECTOR_SECRET_NAME") + return handler(payload_src=payload, vector_secret_name=vector_secret_name, + assume_role_arn=read_role_arn) + + with DAG(dag_id="veda_generic_ingest_vector", params=template_dag_run_conf, **dag_args) as dag: - # ECS dependency variable - mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True) start = DummyOperator(task_id="Start", dag=dag) end = DummyOperator(task_id="End", trigger_rule=TriggerRule.ONE_SUCCESS, dag=dag) - discover = discover_from_s3_task() get_files = get_files_to_process(payload=discover) - build_generic_vector_kwargs_task = build_generic_vector_kwargs.expand(event=get_files) - vector_ingest = EcsRunTaskOperator.partial( - task_id="generic_ingest_vector", - execution_timeout=timedelta(minutes=60), - trigger_rule=TriggerRule.NONE_FAILED, - cluster=f"{mwaa_stack_conf.get('PREFIX')}-cluster", - task_definition=f"{mwaa_stack_conf.get('PREFIX')}-generic_vector-tasks", - launch_type="FARGATE", - do_xcom_push=True - ).expand_kwargs(build_generic_vector_kwargs_task) - + vector_ingest = ingest_vector_task.expand(payload=get_files) discover.set_upstream(start) vector_ingest.set_downstream(end) - - diff --git a/dags/veda_data_pipeline/veda_vector_pipeline.py b/dags/veda_data_pipeline/veda_vector_pipeline.py index 46f17c4f..c264ea92 100644 --- a/dags/veda_data_pipeline/veda_vector_pipeline.py +++ b/dags/veda_data_pipeline/veda_vector_pipeline.py @@ -1,16 +1,11 @@ import pendulum -from datetime import timedelta - +from airflow.decorators import task from airflow import DAG -from airflow.models.variable import Variable -from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator from airflow.operators.dummy_operator import DummyOperator from airflow.utils.trigger_rule import TriggerRule from airflow.models.variable import Variable - -from veda_data_pipeline.groups.processing_tasks import build_vector_kwargs from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_to_process - +import json dag_doc_md = """ ### Build and submit stac @@ -53,27 +48,24 @@ "doc_md": dag_doc_md, } -with DAG(dag_id="veda_ingest_vector", params=template_dag_run_conf, **dag_args) as dag: - # ECS dependency variable - mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True) +@task +def ingest_vector_task(payload): + from veda_data_pipeline.utils.vector_ingest.handler import handler + + airflow_vars = Variable.get("aws_dags_variables") + airflow_vars_json = json.loads(airflow_vars) + read_role_arn = airflow_vars_json.get("ASSUME_ROLE_READ_ARN") + vector_secret_name = airflow_vars_json.get("VECTOR_SECRET_NAME") + return handler(payload_src=payload, vector_secret_name=vector_secret_name, + 
assume_role_arn=read_role_arn) + + +with DAG(dag_id="veda_ingest_vector", params=template_dag_run_conf, **dag_args) as dag: start = DummyOperator(task_id="Start", dag=dag) end = DummyOperator(task_id="End", trigger_rule=TriggerRule.ONE_SUCCESS, dag=dag) - discover = discover_from_s3_task() get_files = get_files_to_process(payload=discover) - build_vector_kwargs_task = build_vector_kwargs.expand(event=get_files) - vector_ingest = EcsRunTaskOperator.partial( - task_id="ingest_vector", - execution_timeout=timedelta(minutes=60), - trigger_rule=TriggerRule.NONE_FAILED, - cluster=f"{mwaa_stack_conf.get('PREFIX')}-cluster", - task_definition=f"{mwaa_stack_conf.get('PREFIX')}-vector-tasks", - launch_type="FARGATE", - do_xcom_push=True - ).expand_kwargs(build_vector_kwargs_task) - + vector_ingest = ingest_vector_task.expand(payload=get_files) discover.set_upstream(start) vector_ingest.set_downstream(end) - - diff --git a/sm2a/airflow_worker/requirements.txt b/sm2a/airflow_worker/requirements.txt index 7c66fbda..4796ad5f 100644 --- a/sm2a/airflow_worker/requirements.txt +++ b/sm2a/airflow_worker/requirements.txt @@ -21,4 +21,10 @@ fsspec s3fs xarray xstac +pystac +rasterio +rio-stac +GeoAlchemy2 +geopandas==0.14.4 +fiona==1.9.6 diff --git a/sm2a/infrastructure/main.tf b/sm2a/infrastructure/main.tf index f7b6aa22..09f9e372 100644 --- a/sm2a/infrastructure/main.tf +++ b/sm2a/infrastructure/main.tf @@ -44,7 +44,6 @@ module "sma-base" { rds_allocated_storage = var.rds_configuration[var.stage].rds_allocated_storage rds_max_allocated_storage = var.rds_configuration[var.stage].rds_max_allocated_storage workers_logs_retention_days = var.workers_configuration[var.stage].workers_logs_retention_days - airflow_custom_variables = var.airflow_custom_variables extra_airflow_task_common_environment = [ { @@ -83,5 +82,13 @@ module "sma-base" { stage = var.stage subdomain = var.subdomain worker_cmd = ["/home/airflow/.local/bin/airflow", "celery", "worker"] + + airflow_custom_variables = { + EVENT_BUCKET = var.event_bucket + COGNITO_APP_SECRET = var.workflows_client_secret + STAC_INGESTOR_API_URL = var.stac_ingestor_api_url + STAC_URL = var.stac_url + VECTOR_SECRET_NAME = var.vector_secret_name + } } diff --git a/sm2a/infrastructure/variables.tf b/sm2a/infrastructure/variables.tf index e2b7f54a..1c1429ea 100644 --- a/sm2a/infrastructure/variables.tf +++ b/sm2a/infrastructure/variables.tf @@ -194,13 +194,27 @@ variable "custom_worker_policy_statement" { ] } -variable "airflow_custom_variables" { - description = "Airflow custom variables" - type = map(string) - default = {} -} + variable "project_name" { type = string default = "SM2A" } + +variable "event_bucket" { + default = "veda-pipeline-sit-mwaa-853558080719" +} +variable "workflows_client_secret" { + default = "veda-auth-stack-dev/workflows-client" +} +variable "stac_ingestor_api_url" { + default = "https://dev.openveda.cloud/api/ingest/" +} + +variable "stac_url" { + default = "https://dev.openveda.cloud/api/stac/" +} + +variable "vector_secret_name" { + default = "veda-features-api-dev/features-tipg-db/7c4b47e4" +} \ No newline at end of file
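
Note on configuration: every task touched by this patch now reads a single Airflow Variable, aws_dags_variables, as a JSON string (the repeated Variable.get("aws_dags_variables") / json.loads pattern above). The Terraform airflow_custom_variables map seeds EVENT_BUCKET, COGNITO_APP_SECRET, STAC_INGESTOR_API_URL, STAC_URL and VECTOR_SECRET_NAME, while the DAG code also looks up ASSUME_ROLE_READ_ARN and ASSUME_ROLE_WRITE_ARN, which are not in that map and will come back as None unless added. The sketch below is illustrative only and shows one way to seed the variable for local testing; all values are placeholders, not real buckets, ARNs or endpoints.

    # Illustrative sketch: seed the consolidated "aws_dags_variables" Variable.
    # Key names come from this patch; every value below is a placeholder.
    import json

    from airflow.models.variable import Variable

    aws_dags_variables = {
        "EVENT_BUCKET": "my-event-bucket",                                    # placeholder
        "COGNITO_APP_SECRET": "my-auth-stack/workflows-client",               # placeholder
        "STAC_INGESTOR_API_URL": "https://example.com/api/ingest/",           # placeholder
        "STAC_URL": "https://example.com/api/stac/",                          # placeholder
        "VECTOR_SECRET_NAME": "my-features-api/features-tipg-db",             # placeholder
        "ASSUME_ROLE_READ_ARN": "arn:aws:iam::123456789012:role/read-role",   # placeholder
        "ASSUME_ROLE_WRITE_ARN": "arn:aws:iam::123456789012:role/write-role", # placeholder
    }

    # Store the whole map as one JSON string; this is how the tasks read it back.
    Variable.set("aws_dags_variables", json.dumps(aws_dags_variables))

    # Read-back pattern used throughout the DAGs in this patch:
    airflow_vars_json = json.loads(Variable.get("aws_dags_variables"))
    event_bucket = airflow_vars_json.get("EVENT_BUCKET")

Consolidating the keys this way gives each task a single metastore lookup instead of one Variable.get per key, and replaces the per-container environment blocks that the deleted ECS operator kwargs builders used to carry.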