diff --git a/dags/automated_transformation/automation_dag.py b/dags/automated_transformation/automation_dag.py index 8d01988a..ae72a256 100644 --- a/dags/automated_transformation/automation_dag.py +++ b/dags/automated_transformation/automation_dag.py @@ -1,20 +1,32 @@ from __future__ import annotations + +import importlib + from airflow import DAG from airflow.decorators import task +from airflow.models.param import Param from airflow.operators.dummy_operator import DummyOperator -import importlib DAG_ID = "automate-cog-transformation" +# Custom validation function + + dag_run_config = { - "data_acquisition_method": "s3", + "data_acquisition_method": Param( + "s3", enum=["s3"] + ), # To add Other protocols (HTTP, SFTP...) "raw_data_bucket": "ghgc-data-store-develop", - "raw_data_prefix": "delivery/tm54dvar-ch4flux-mask-monthgrid-v5", + "raw_data_prefix": Param( + "delivery/tm54dvar-ch4flux-mask-monthgrid-v5", + type="string", + pattern="^[^/].*[^/]$", + ), "dest_data_bucket": "ghgc-data-store-develop", - "data_prefix": "transformed_cogs", - "collection_name":"tm54dvar-ch4flux-mask-monthgrid-v5", - "nodata":-9999, - "ext": ".nc" # .nc, .nc4, .tif, .tiff + "data_prefix": Param("transformed_cogs", type="string", pattern="^[^/].*[^/]$"), + "collection_name": "tm54dvar-ch4flux-mask-monthgrid-v5", + "nodata": -9999, + "ext": Param(".nc", type="string", pattern="^\\..*$"), } with DAG( @@ -31,21 +43,26 @@ def check_function_exists(ti): config = ti.dag_run.conf.copy() collection_name = config.get("collection_name") - module = importlib.import_module('automated_transformation.transformation_functions') + module = importlib.import_module( + "automated_transformation.transformation_functions" + ) function_name = f'{collection_name.replace("-", "_")}_transformation' - if not hasattr(module, function_name): - raise Exception(f"The function {function_name} does not exist in the module {module}.") - return (f"The function {function_name} does not exist in the module {module}.") - + raise Exception( + f"The function {function_name} does not exist in the module {module}." + ) + return f"The function {function_name} does not exist in the module {module}." + @task def discover_files(ti): - from dags.automated_transformation.transformation_functions import get_all_s3_keys + from dags.automated_transformation.transformation_functions import ( + get_all_s3_keys, + ) config = ti.dag_run.conf.copy() bucket = config.get("raw_data_bucket") model_name = config.get("raw_data_prefix") - ext = config.get("ext") # .nc as well + ext = config.get("ext") # .nc as well # return get_all_s3_keys(bucket, model_name, ext) generated_list = get_all_s3_keys(bucket, model_name, ext) chunk_size = int(len(generated_list) / 900) + 1 @@ -54,7 +71,6 @@ def discover_files(ti): for i in range(0, len(generated_list), chunk_size) ] - @task(max_active_tis_per_dag=10) def process_files(file_url, **kwargs): dag_run = kwargs.get("dag_run") @@ -70,7 +86,7 @@ def process_files(file_url, **kwargs): print("len of files", len(file_url)) file_status = transform_cog( file_url, - nodata = nodata, + nodata=nodata, raw_data_bucket=raw_bucket_name, dest_data_bucket=dest_data_bucket, data_prefix=data_prefix, @@ -78,23 +94,12 @@ def process_files(file_url, **kwargs): ) return file_status - # @task - # def generate_report(report_data, json_filename, **kwargs): - # from odiac_processing.processing import upload_json_report - - # dag_run = kwargs.get("dag_run") - # config = dag_run.conf.copy() - # bucket_name = config.get("cog_data_bucket") - # s3_destination_folder_name = config.get("cog_data_prefix") - # report_json_filename = config.get("report_json_filename") - # return upload_json_report( - # report_data=report_data, - # bucket_name=bucket_name, - # s3_folder_name=s3_destination_folder_name, - # json_filename=report_json_filename, - # ) + @task + def generate_report(reports, **kwargs): + dag_run = kwargs.get("dag_run") + collection_name = dag_run.conf.get("collection_name") + return {"collection": collection_name, "successes": len(reports)} - urls = start >>check_function_exists()>> discover_files() - report_data = process_files.expand(file_url=urls) >> end - # generate_report(report_data=report_data[0], - # json_filename=report_data[1]) >> end + urls = start >> check_function_exists() >> discover_files() + report_data = process_files.expand(file_url=urls) + generate_report(reports=report_data) >> end diff --git a/dags/automated_transformation/transformation_functions.py b/dags/automated_transformation/transformation_functions.py index 6debebd0..5b3205d9 100644 --- a/dags/automated_transformation/transformation_functions.py +++ b/dags/automated_transformation/transformation_functions.py @@ -1,8 +1,10 @@ -import boto3 -import xarray import re from datetime import datetime +import boto3 +import xarray + + def get_all_s3_keys(bucket, model_name, ext): """Function fetches all the s3 keys from the given bucket and model name. @@ -12,7 +14,7 @@ def get_all_s3_keys(bucket, model_name, ext): ext (str): extension of the file that is to be fetched. Returns: - list : List of all the keys that match the given criteria + list : List of all the keys that match the given criteria """ session = boto3.session.Session() s3_client = session.client("s3") @@ -34,6 +36,7 @@ def get_all_s3_keys(bucket, model_name, ext): print(f"Discovered {len(keys)}") return keys + """ The naming convention for the transformation function is as follows: collectionname_transformation @@ -45,6 +48,8 @@ def get_all_s3_keys(bucket, model_name, ext): differentiate the transformation functions from any other functions in file. """ + + def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata): """Tranformation function for the tm5 ch4 influx dataset @@ -56,7 +61,7 @@ def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata): Returns: dict: Dictionary with the COG name and its corresponding data array. """ - + var_data_netcdf = {} xds = xarray.open_dataset(file_obj) xds = xds.rename({"latitude": "lat", "longitude": "lon"}) @@ -70,7 +75,7 @@ def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata): for var in variable: data = getattr(xds.isel(months=time_increment), var) data = data.isel(lat=slice(None, None, -1)) - data = data.where(data==nodata, -9999) + data = data.where(data == nodata, -9999) data.rio.set_spatial_dims("lon", "lat", inplace=True) data.rio.write_crs("epsg:4326", inplace=True) data.rio.write_nodata(-9999, inplace=True) @@ -86,6 +91,7 @@ def tm54dvar_ch4flux_mask_monthgrid_v5_transformation(file_obj, name, nodata): return var_data_netcdf + def gpw_transformation(file_obj, name, nodata): """Tranformation function for the gridded population dataset @@ -99,14 +105,14 @@ def gpw_transformation(file_obj, name, nodata): """ var_data_netcdf = {} - xds = xarray.open_dataarray(file_obj, engine='rasterio') + xds = xarray.open_dataarray(file_obj, engine="rasterio") filename = name.split("/")[-1] filename_elements = re.split("[_ .]", filename) # # insert date of generated COG into filename filename_elements.pop() filename_elements.append(filename_elements[-3]) - xds = xds.where(xds==nodata, -9999) + xds = xds.where(xds == nodata, -9999) xds.rio.set_spatial_dims("x", "y", inplace=True) xds.rio.write_crs("epsg:4326", inplace=True) xds.rio.write_nodata(-9999, inplace=True) @@ -117,6 +123,7 @@ def gpw_transformation(file_obj, name, nodata): var_data_netcdf[cog_filename] = xds return var_data_netcdf + def geos_oco2_transformation(file_obj, name, nodata): """Tranformation function for the oco2 geos dataset @@ -132,20 +139,16 @@ def geos_oco2_transformation(file_obj, name, nodata): xds = xarray.open_dataset(file_obj) xds = xds.assign_coords(lon=(((xds.lon + 180) % 360) - 180)).sortby("lon") variable = [var for var in xds.data_vars] - filename = name.split("/ ")[-1] - filename_elements = re.split("[_ .]", filename) - for time_increment in range(0, len(xds.time)): for var in variable: filename = name.split("/ ")[-1] filename_elements = re.split("[_ .]", filename) data = getattr(xds.isel(time=time_increment), var) data = data.isel(lat=slice(None, None, -1)) - data = data.where(data==nodata, -9999) + data = data.where(data == nodata, -9999) data.rio.set_spatial_dims("lon", "lat", inplace=True) data.rio.write_crs("epsg:4326", inplace=True) data.rio.write_nodata(-9999, inplace=True) - # # insert date of generated COG into filename filename_elements[-1] = filename_elements[-3] filename_elements.insert(2, var) @@ -157,6 +160,7 @@ def geos_oco2_transformation(file_obj, name, nodata): return var_data_netcdf + def ecco_darwin_transformation(file_obj, name, nodata): """Tranformation function for the ecco darwin dataset @@ -187,7 +191,7 @@ def ecco_darwin_transformation(file_obj, name, nodata): data = xds[var] data = data.reindex(latitude=list(reversed(data.latitude))) - data = data.where(data==nodata, -9999) + data = data.where(data == nodata, -9999) data.rio.set_spatial_dims("longitude", "latitude", inplace=True) data.rio.write_crs("epsg:4326", inplace=True) data.rio.write_nodata(-9999, inplace=True) @@ -200,4 +204,4 @@ def ecco_darwin_transformation(file_obj, name, nodata): # # add extension cog_filename = f"{cog_filename}.tif" var_data_netcdf[cog_filename] = data - return var_data_netcdf \ No newline at end of file + return var_data_netcdf diff --git a/dags/automated_transformation/transformation_pipeline.py b/dags/automated_transformation/transformation_pipeline.py index bd73e6ec..bf50baa7 100644 --- a/dags/automated_transformation/transformation_pipeline.py +++ b/dags/automated_transformation/transformation_pipeline.py @@ -1,13 +1,16 @@ +import importlib import tempfile + import boto3 import pandas as pd -import rasterio import s3fs -import importlib files_processed = pd.DataFrame(columns=["file_name", "COGs_created"]) -def transform_cog(name_list, nodata, raw_data_bucket, dest_data_bucket, data_prefix, collection_name): + +def transform_cog( + name_list, nodata, raw_data_bucket, dest_data_bucket, data_prefix, collection_name +): """This function calls the plugins (dataset specific transformation functions) and generalizes the transformation of dataset to COGs. @@ -22,22 +25,24 @@ def transform_cog(name_list, nodata, raw_data_bucket, dest_data_bucket, data_pre Returns: dict: Status and name of the file that is transformed """ - + session = boto3.session.Session() s3_client = session.client("s3") - module = importlib.import_module('automated_transformation.transformation_functions') + module = importlib.import_module( + "automated_transformation.transformation_functions" + ) function_name = f'{collection_name.replace("-", "_")}_transformation' for name in name_list: url = f"s3://{raw_data_bucket}/{name}" fs = s3fs.S3FileSystem() - print('the url is', url) - with fs.open(url, mode='rb') as file_obj: + print("the url is", url) + with fs.open(url, mode="rb") as file_obj: try: transform_func = getattr(module, function_name) var_data_netcdf = transform_func(file_obj, name, nodata) - + for cog_filename, data in var_data_netcdf.items(): - # generate COG + # generate COG COG_PROFILE = {"driver": "COG", "compress": "DEFLATE"} with tempfile.NamedTemporaryFile() as temp_file: data.rio.to_raster(temp_file.name, **COG_PROFILE) @@ -46,7 +51,15 @@ def transform_cog(name_list, nodata, raw_data_bucket, dest_data_bucket, data_pre Bucket=dest_data_bucket, Key=f"{data_prefix}/{collection_name}/{cog_filename}", ) - status = {'transformed filename':f'{cog_filename}', 'S3uri':f's3://{dest_data_bucket}/{data_prefix}/{collection_name}/{cog_filename}', 'status':'success'} - except: - status = {'transformed filename':f'{name}', 'status':'failed'} + status = { + "transformed_filename": cog_filename, + "s3uri": f"s3://{dest_data_bucket}/{data_prefix}/{collection_name}/{cog_filename}", + "status": "success", + } + except Exception as ex: + status = { + "transformed_filename": name, + "status": "failed", + "reason": f"Error: {ex}", + } return status