Adding Params to the automation DAG
amarouane-ABDELHAK committed Sep 11, 2024
1 parent b0a6524 commit d21333c
Showing 3 changed files with 83 additions and 61 deletions.
75 changes: 40 additions & 35 deletions dags/automated_transformation/automation_dag.py
@@ -1,20 +1,32 @@
from __future__ import annotations

import importlib

from airflow import DAG
from airflow.decorators import task
from airflow.models.param import Param
from airflow.operators.dummy_operator import DummyOperator
import importlib

DAG_ID = "automate-cog-transformation"

# Custom validation function


dag_run_config = {
"data_acquisition_method": "s3",
"data_acquisition_method": Param(
"s3", enum=["s3"]
), # To add Other protocols (HTTP, SFTP...)
"raw_data_bucket": "ghgc-data-store-develop",
"raw_data_prefix": "delivery/tm54dvar-ch4flux-mask-monthgrid-v5",
"raw_data_prefix": Param(
"delivery/tm54dvar-ch4flux-mask-monthgrid-v5",
type="string",
pattern="^[^/].*[^/]$",
),
"dest_data_bucket": "ghgc-data-store-develop",
"data_prefix": "transformed_cogs",
"collection_name":"tm54dvar-ch4flux-mask-monthgrid-v5",
"nodata":-9999,
"ext": ".nc" # .nc, .nc4, .tif, .tiff
"data_prefix": Param("transformed_cogs", type="string", pattern="^[^/].*[^/]$"),
"collection_name": "tm54dvar-ch4flux-mask-monthgrid-v5",
"nodata": -9999,
"ext": Param(".nc", type="string", pattern="^\\..*$"),
}

with DAG(
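The Param entries above validate the trigger-time configuration before a run starts. A minimal sketch of that behavior, assuming Airflow 2.x, where Param.resolve() checks a value against the declared JSON-schema constraints (the values below are illustrative):

from airflow.models.param import Param

# Param.resolve() raises ParamValidationError when a value violates the schema.
prefix = Param(
    "delivery/tm54dvar-ch4flux-mask-monthgrid-v5",
    type="string",
    pattern="^[^/].*[^/]$",  # rejects leading or trailing slashes
)
prefix.resolve()                        # default value passes
prefix.resolve("delivery/another-set")  # passes the pattern
prefix.resolve("/delivery/bad/")        # raises ParamValidationError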
@@ -31,21 +43,26 @@
def check_function_exists(ti):
config = ti.dag_run.conf.copy()
collection_name = config.get("collection_name")
module = importlib.import_module('automated_transformation.transformation_functions')
module = importlib.import_module(
"automated_transformation.transformation_functions"
)
function_name = f'{collection_name.replace("-", "_")}_transformation'

if not hasattr(module, function_name):
raise Exception(f"The function {function_name} does not exist in the module {module}.")
return (f"The function {function_name} does not exist in the module {module}.")

raise Exception(
f"The function {function_name} does not exist in the module {module}."
)
return f"The function {function_name} does not exist in the module {module}."

@task
def discover_files(ti):
from dags.automated_transformation.transformation_functions import get_all_s3_keys
from dags.automated_transformation.transformation_functions import (
get_all_s3_keys,
)

config = ti.dag_run.conf.copy()
bucket = config.get("raw_data_bucket")
model_name = config.get("raw_data_prefix")
ext = config.get("ext") # .nc as well
ext = config.get("ext") # .nc as well
# return get_all_s3_keys(bucket, model_name, ext)
generated_list = get_all_s3_keys(bucket, model_name, ext)
chunk_size = int(len(generated_list) / 900) + 1
@@ -54,7 +71,6 @@ def discover_files(ti):
for i in range(0, len(generated_list), chunk_size)
]
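Chunking the discovered keys bounds the fan-out: integer division plus one guarantees at most 900 chunks for any list length, which stays under Airflow's default dynamic-task-mapping limit (max_map_length, 1024 by default). A quick illustration with a hypothetical key list:

# Hypothetical discovery result of 2000 keys.
generated_list = [f"file_{i}.nc" for i in range(2000)]
chunk_size = int(len(generated_list) / 900) + 1  # -> 3
chunks = [
    generated_list[i : i + chunk_size]
    for i in range(0, len(generated_list), chunk_size)
]
assert len(chunks) <= 900  # holds for any length of generated_list
print(len(chunks))  # 667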


@task(max_active_tis_per_dag=10)
def process_files(file_url, **kwargs):
dag_run = kwargs.get("dag_run")
@@ -70,31 +86,20 @@ def process_files(file_url, **kwargs):
print("len of files", len(file_url))
file_status = transform_cog(
file_url,
nodata = nodata,
nodata=nodata,
raw_data_bucket=raw_bucket_name,
dest_data_bucket=dest_data_bucket,
data_prefix=data_prefix,
collection_name=collection_name,
)
return file_status

# @task
# def generate_report(report_data, json_filename, **kwargs):
# from odiac_processing.processing import upload_json_report

# dag_run = kwargs.get("dag_run")
# config = dag_run.conf.copy()
# bucket_name = config.get("cog_data_bucket")
# s3_destination_folder_name = config.get("cog_data_prefix")
# report_json_filename = config.get("report_json_filename")
# return upload_json_report(
# report_data=report_data,
# bucket_name=bucket_name,
# s3_folder_name=s3_destination_folder_name,
# json_filename=report_json_filename,
# )
@task
def generate_report(reports, **kwargs):
dag_run = kwargs.get("dag_run")
collection_name = dag_run.conf.get("collection_name")
return {"collection": collection_name, "successes": len(reports)}

urls = start >>check_function_exists()>> discover_files()
report_data = process_files.expand(file_url=urls) >> end
# generate_report(report_data=report_data[0],
# json_filename=report_data[1]) >> end
urls = start >> check_function_exists() >> discover_files()
report_data = process_files.expand(file_url=urls)
generate_report(reports=report_data) >> end
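With dynamic task mapping, process_files runs once per chunk (at most 10 at a time, per max_active_tis_per_dag=10) and generate_report receives the per-chunk status dicts as a single list. Since process_files returns a status dict on failure rather than raising, the reported count covers attempted chunks, not only successful ones. An illustrative fan-in with hypothetical values:

# Hypothetical XCom payload seen by generate_report.
reports = [
    {"transformed_filename": "fluxes_200001.tif", "status": "success"},
    {"transformed_filename": "fluxes_200002.tif", "status": "failed", "reason": "Error: ..."},
]
summary = {"collection": "tm54dvar-ch4flux-mask-monthgrid-v5", "successes": len(reports)}
print(summary)  # {'collection': '...', 'successes': 2}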
32 changes: 18 additions & 14 deletions dags/automated_transformation/transformation_functions.py
@@ -1,8 +1,10 @@
import boto3
import xarray
import re
from datetime import datetime

import boto3
import xarray


def get_all_s3_keys(bucket, model_name, ext):
"""Function fetches all the s3 keys from the given bucket and model name.
@@ -12,7 +14,7 @@ def get_all_s3_keys(bucket, model_name, ext):
ext (str): extension of the file that is to be fetched.
Returns:
list : List of all the keys that match the given criteria
list : List of all the keys that match the given criteria
"""
session = boto3.session.Session()
s3_client = session.client("s3")
@@ -34,6 +36,7 @@ def get_all_s3_keys(bucket, model_name, ext):
print(f"Discovered {len(keys)}")
return keys


"""
The naming convention for the transformation function is as follows:
collectionname_transformation
@@ -45,6 +48,8 @@ def get_all_s3_keys(bucket, model_name, ext):
differentiate the transformation functions from any other functions in
file.
"""


def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata):
"""Tranformation function for the tm5 ch4 influx dataset
@@ -56,7 +61,7 @@ def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata):
Returns:
dict: Dictionary with the COG name and its corresponding data array.
"""

var_data_netcdf = {}
xds = xarray.open_dataset(file_obj)
xds = xds.rename({"latitude": "lat", "longitude": "lon"})
@@ -70,7 +75,7 @@ def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata):
for var in variable:
data = getattr(xds.isel(months=time_increment), var)
data = data.isel(lat=slice(None, None, -1))
data = data.where(data==nodata, -9999)
data = data.where(data == nodata, -9999)
data.rio.set_spatial_dims("lon", "lat", inplace=True)
data.rio.write_crs("epsg:4326", inplace=True)
data.rio.write_nodata(-9999, inplace=True)
@@ -86,6 +91,7 @@ def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata):

return var_data_netcdf


def gpw_transformation(file_obj, name, nodata):
"""Tranformation function for the gridded population dataset
@@ -99,14 +105,14 @@
"""

var_data_netcdf = {}
xds = xarray.open_dataarray(file_obj, engine='rasterio')
xds = xarray.open_dataarray(file_obj, engine="rasterio")

filename = name.split("/")[-1]
filename_elements = re.split("[_ .]", filename)
# # insert date of generated COG into filename
filename_elements.pop()
filename_elements.append(filename_elements[-3])
xds = xds.where(xds==nodata, -9999)
xds = xds.where(xds == nodata, -9999)
xds.rio.set_spatial_dims("x", "y", inplace=True)
xds.rio.write_crs("epsg:4326", inplace=True)
xds.rio.write_nodata(-9999, inplace=True)
@@ -117,6 +123,7 @@ def gpw_transformation(file_obj, name, nodata):
var_data_netcdf[cog_filename] = xds
return var_data_netcdf


def geos_oco2_transformation(file_obj, name, nodata):
"""Tranformation function for the oco2 geos dataset
@@ -132,20 +139,16 @@ def geos_oco2_transformation(file_obj, name, nodata):
xds = xarray.open_dataset(file_obj)
xds = xds.assign_coords(lon=(((xds.lon + 180) % 360) - 180)).sortby("lon")
variable = [var for var in xds.data_vars]
filename = name.split("/")[-1]
filename_elements = re.split("[_ .]", filename)

for time_increment in range(0, len(xds.time)):
for var in variable:
filename = name.split("/ ")[-1]
filename_elements = re.split("[_ .]", filename)
data = getattr(xds.isel(time=time_increment), var)
data = data.isel(lat=slice(None, None, -1))
data = data.where(data==nodata, -9999)
data = data.where(data == nodata, -9999)
data.rio.set_spatial_dims("lon", "lat", inplace=True)
data.rio.write_crs("epsg:4326", inplace=True)
data.rio.write_nodata(-9999, inplace=True)

# # insert date of generated COG into filename
filename_elements[-1] = filename_elements[-3]
filename_elements.insert(2, var)
@@ -157,6 +160,7 @@ def geos_oco2_transformation(file_obj, name, nodata):

return var_data_netcdf


def ecco_darwin_transformation(file_obj, name, nodata):
"""Tranformation function for the ecco darwin dataset
@@ -187,7 +191,7 @@ def ecco_darwin_transformation(file_obj, name, nodata):
data = xds[var]

data = data.reindex(latitude=list(reversed(data.latitude)))
data = data.where(data==nodata, -9999)
data = data.where(data == nodata, -9999)
data.rio.set_spatial_dims("longitude", "latitude", inplace=True)
data.rio.write_crs("epsg:4326", inplace=True)
data.rio.write_nodata(-9999, inplace=True)
@@ -200,4 +204,4 @@ def ecco_darwin_transformation(file_obj, name, nodata):
# # add extension
cog_filename = f"{cog_filename}.tif"
var_data_netcdf[cog_filename] = data
return var_data_netcdf
return var_data_netcdf
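The naming convention documented above is what lets the DAG and the pipeline resolve a plugin dynamically with importlib and getattr. A hypothetical helper (not part of this commit) that mirrors the lookup in check_function_exists and transform_cog:

import importlib

def resolve_transformation(collection_name):
    """Resolve a collection's transformation function by naming convention."""
    module = importlib.import_module(
        "automated_transformation.transformation_functions"
    )
    function_name = f'{collection_name.replace("-", "_")}_transformation'
    if not hasattr(module, function_name):
        raise AttributeError(f"{function_name} not found in {module.__name__}")
    return getattr(module, function_name)

# resolve_transformation("tm54dvar-ch4flux-mask-monthgrid-v5")
# -> tm54dvar_ch4flux_mask_monthgrid_v5_transformation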
37 changes: 25 additions & 12 deletions dags/automated_transformation/transformation_pipeline.py
@@ -1,13 +1,16 @@
import importlib
import tempfile

import boto3
import pandas as pd
import rasterio
import s3fs
import importlib

files_processed = pd.DataFrame(columns=["file_name", "COGs_created"])

def transform_cog(name_list, nodata, raw_data_bucket, dest_data_bucket, data_prefix, collection_name):

def transform_cog(
name_list, nodata, raw_data_bucket, dest_data_bucket, data_prefix, collection_name
):
"""This function calls the plugins (dataset specific transformation functions) and
generalizes the transformation of dataset to COGs.
@@ -22,22 +25,24 @@ def transform_cog(name_list, nodata, raw_data_bucket, dest_data_bucket, data_pre
Returns:
dict: Status and name of the file that is transformed
"""

session = boto3.session.Session()
s3_client = session.client("s3")
module = importlib.import_module('automated_transformation.transformation_functions')
module = importlib.import_module(
"automated_transformation.transformation_functions"
)
function_name = f'{collection_name.replace("-", "_")}_transformation'
for name in name_list:
url = f"s3://{raw_data_bucket}/{name}"
fs = s3fs.S3FileSystem()
print('the url is', url)
with fs.open(url, mode='rb') as file_obj:
print("the url is", url)
with fs.open(url, mode="rb") as file_obj:
try:
transform_func = getattr(module, function_name)
var_data_netcdf = transform_func(file_obj, name, nodata)

for cog_filename, data in var_data_netcdf.items():
# generate COG
# generate COG
COG_PROFILE = {"driver": "COG", "compress": "DEFLATE"}
with tempfile.NamedTemporaryFile() as temp_file:
data.rio.to_raster(temp_file.name, **COG_PROFILE)
@@ -46,7 +51,15 @@ def transform_cog(name_list, nodata, raw_data_bucket, dest_data_bucket, data_pre
Bucket=dest_data_bucket,
Key=f"{data_prefix}/{collection_name}/{cog_filename}",
)
status = {'transformed filename':f'{cog_filename}', 'S3uri':f's3://{dest_data_bucket}/{data_prefix}/{collection_name}/{cog_filename}', 'status':'success'}
except:
status = {'transformed filename':f'{name}', 'status':'failed'}
status = {
"transformed_filename": cog_filename,
"s3uri": f"s3://{dest_data_bucket}/{data_prefix}/{collection_name}/{cog_filename}",
"status": "success",
}
except Exception as ex:
status = {
"transformed_filename": name,
"status": "failed",
"reason": f"Error: {ex}",
}
return status
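For reference, a hypothetical direct call to transform_cog outside Airflow (for example, in a smoke test); bucket names and keys below are illustrative and mirror the DAG defaults. Note that only the status of the last file in name_list is returned, since status is overwritten on each iteration:

status = transform_cog(
    name_list=["delivery/tm54dvar-ch4flux-mask-monthgrid-v5/fluxes_200001.nc"],
    nodata=-9999,
    raw_data_bucket="ghgc-data-store-develop",
    dest_data_bucket="ghgc-data-store-develop",
    data_prefix="transformed_cogs",
    collection_name="tm54dvar-ch4flux-mask-monthgrid-v5",
)
print(status)
# {"transformed_filename": "...", "s3uri": "s3://ghgc-data-store-develop/transformed_cogs/...",
#  "status": "success"}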
