diff --git a/dags/veda_data_pipeline/utils/build_stac/handler.py b/dags/veda_data_pipeline/utils/build_stac/handler.py
index 9eb017b0..a13350b0 100644
--- a/dags/veda_data_pipeline/utils/build_stac/handler.py
+++ b/dags/veda_data_pipeline/utils/build_stac/handler.py
@@ -4,12 +4,30 @@
 import smart_open
 from veda_data_pipeline.utils.build_stac.utils import events
 from veda_data_pipeline.utils.build_stac.utils import stac
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
 
 class S3LinkOutput(TypedDict):
     stac_file_url: str
 
 
+def using_pool(objects, workers_count: int):
+    returned_results = []
+    with ThreadPoolExecutor(max_workers=workers_count) as executor:
+        # Submit tasks to the executor
+        futures = {executor.submit(handler, obj): obj for obj in objects}
+
+        for future in as_completed(futures):
+            try:
+                result = future.result()  # Get result from future
+                returned_results.append(result)
+            except Exception as nex:
+                print(f"Error {nex} with object {futures[future]}")
+
+    return returned_results
+
+
 class StacItemOutput(TypedDict):
     stac_item: Dict[str, Any]
 
@@ -81,7 +99,12 @@ def stac_handler(payload_src: dict, bucket_output):
         s3_event_read = _file.read()
         event_received = json.loads(s3_event_read)
         objects = event_received["objects"]
-        payloads = sequential_processing(objects)
+        use_multithreading = payload_event.get("use_multithreading", True)
+        payloads = (
+            using_pool(objects, workers_count=4)
+            if use_multithreading
+            else sequential_processing(objects)
+        )
         for payload in payloads:
             stac_item = payload["stac_item"]
             if "error" in stac_item:
diff --git a/dags/veda_data_pipeline/veda_discover_pipeline.py b/dags/veda_data_pipeline/veda_discover_pipeline.py
index edd6780d..49d790b2 100644
--- a/dags/veda_data_pipeline/veda_discover_pipeline.py
+++ b/dags/veda_data_pipeline/veda_discover_pipeline.py
@@ -1,7 +1,6 @@
 import pendulum
 from airflow import DAG
 from airflow.operators.dummy_operator import DummyOperator
-from airflow.utils.trigger_rule import TriggerRule
 from airflow.decorators import task
 from airflow.models.variable import Variable
 import json
@@ -73,7 +72,7 @@
 }
 
 
-@task(max_active_tis_per_dag=3)
+@task(max_active_tis_per_dag=5)
 def build_stac_task(payload):
     from veda_data_pipeline.utils.build_stac.handler import stac_handler
     airflow_vars = Variable.get("aws_dags_variables")
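
Editor's note: the `using_pool` helper added above is the standard `ThreadPoolExecutor` / `as_completed` fan-out idiom. Below is a minimal, self-contained sketch of the same pattern for reference; the stand-in `handler` and the sample objects are illustrative assumptions, not code from this PR.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def handler(obj):
    # Stand-in for the real per-object STAC build step (assumption for
    # illustration only); the real handler returns a StacItemOutput dict.
    return {"stac_item": {"id": obj["id"]}}


def using_pool(objects, workers_count: int):
    returned_results = []
    with ThreadPoolExecutor(max_workers=workers_count) as executor:
        # Map each future back to the object it was submitted with, so a
        # failure can be reported against the input that caused it.
        futures = {executor.submit(handler, obj): obj for obj in objects}
        for future in as_completed(futures):
            try:
                returned_results.append(future.result())
            except Exception as exc:
                print(f"Error {exc} with object {futures[future]}")
    return returned_results


if __name__ == "__main__":
    # Results arrive in completion order, not submission order.
    print(using_pool([{"id": 1}, {"id": 2}, {"id": 3}], workers_count=4))
```

One behavioral difference worth noting: because `as_completed` yields futures in completion order, `payloads` may come back in a different order than `objects`, unlike the `sequential_processing` path.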