Skip to content

Commit

Permalink
Switching to PythonOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
amarouane-ABDELHAK committed Oct 24, 2024
1 parent 6d3e29a commit 686c4c8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
25 changes: 24 additions & 1 deletion dags/veda_data_pipeline/utils/build_stac/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,30 @@
import smart_open
from veda_data_pipeline.utils.build_stac.utils import events
from veda_data_pipeline.utils.build_stac.utils import stac
from concurrent.futures import ThreadPoolExecutor, as_completed



class S3LinkOutput(TypedDict):
    # Output payload variant carrying a URL string; per the field name this is
    # presumably the S3 location of a written STAC file — confirm against the
    # writer in this module (not visible in this chunk).
    stac_file_url: str


def using_pool(objects, workers_count: int):
    """Run ``handler`` concurrently over *objects* using a thread pool.

    Each object is submitted to ``handler`` on its own worker thread and the
    results are gathered in completion order (not submission order).  An
    object whose ``handler`` call raises is reported on stdout and simply
    omitted from the returned list — processing of the remaining objects
    continues.

    Args:
        objects: iterable of items, each passed as the sole argument to
            ``handler``.
        workers_count: maximum number of worker threads for the pool.

    Returns:
        list: the successful ``handler`` return values, in completion order.
    """
    collected = []
    with ThreadPoolExecutor(max_workers=workers_count) as executor:
        # Map each in-flight future back to the object it was built from so
        # failures can be reported with their originating input.
        futures = {executor.submit(handler, obj): obj for obj in objects}
        for future in as_completed(futures):
            try:
                collected.append(future.result())
            except Exception as nex:
                # Best-effort: log the failure and keep going.
                print(f"Error {nex} with object {futures[future]}")
    return collected


class StacItemOutput(TypedDict):
    # Output payload variant carrying the STAC item inline as a JSON-like
    # dict (rather than an S3 link).
    stac_item: Dict[str, Any]

Expand Down Expand Up @@ -81,7 +99,12 @@ def stac_handler(payload_src: dict, bucket_output):
s3_event_read = _file.read()
event_received = json.loads(s3_event_read)
objects = event_received["objects"]
payloads = sequential_processing(objects)
use_multithreading = payload_event.get("use_multithreading", True)
payloads = (
using_pool(objects, workers_count=4)
if use_multithreading
else sequential_processing(objects)
)
for payload in payloads:
stac_item = payload["stac_item"]
if "error" in stac_item:
Expand Down
3 changes: 1 addition & 2 deletions dags/veda_data_pipeline/veda_discover_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pendulum
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.trigger_rule import TriggerRule
from airflow.decorators import task
from airflow.models.variable import Variable
import json
Expand Down Expand Up @@ -73,7 +72,7 @@
}


@task(max_active_tis_per_dag=3)
@task(max_active_tis_per_dag=5)
def build_stac_task(payload):
from veda_data_pipeline.utils.build_stac.handler import stac_handler
airflow_vars = Variable.get("aws_dags_variables")
Expand Down

0 comments on commit 686c4c8

Please sign in to comment.