
Refactor: Merge templates
Signed-off-by: Simon Brugman <sfbbrugman@gmail.com>
sbrugman committed Jul 11, 2023
1 parent 8e2ca88 commit df20823
Showing 5 changed files with 98 additions and 121 deletions.
22 changes: 7 additions & 15 deletions kedro-airflow/README.md
@@ -56,12 +56,6 @@ You can use the additional command line argument `--jinja-file` (alias `-j`) to
 kedro airflow create --jinja-file=./custom/template.j2
 ```
 
-Similarly, a custom template for the DAG `kwargs` can be provided:
-
-```bash
-kedro airflow create --kwargs-template=./custom/kwargs_template.j2
-```
-
 #### How can I pass arguments to the Airflow DAGs dynamically?
 
 `kedro-airflow` picks up configuration from `airflow.yml` files. Arguments can be specified globally, or per pipeline:
@@ -75,18 +69,16 @@ default:
     schedule_interval: "@once"
     catchup: false
     # Default settings applied to all tasks
-    default_args:
-        owner: "airflow"
-        depends_on_past: false
-        email_on_failure: false
-        email_on_retry: false
-        retries: 1
-        retry_delay: 5
+    owner: "airflow"
+    depends_on_past: false
+    email_on_failure: false
+    email_on_retry: false
+    retries: 1
+    retry_delay: 5
 
 # Arguments specific to the pipeline (overrides the parameters above)
 data_science:
-    default_args:
-        owner: "airflow-ds"
+    owner: "airflow-ds"
 ```
 
 Arguments can also be passed via `--params` in the command line:
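The concrete invocation is collapsed in this diff view; as an illustrative sketch, assuming the `key:value` syntax that Kedro's `_split_params` accepts (the value shown is made up):

```bash
kedro airflow create --params "schedule_interval:'@weekly'"
```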
14 changes: 0 additions & 14 deletions kedro-airflow/kedro_airflow/airflow_dag_kwargs_template.j2

This file was deleted.

18 changes: 14 additions & 4 deletions kedro-airflow/kedro_airflow/airflow_dag_template.j2
@@ -44,10 +44,20 @@ package_name = "{{ package_name }}"
 # Using a DAG context manager, you don't have to specify the dag property of each task
 with DAG(
     dag_id="{{ dag_name | safe | slugify }}",
-    {% if kwargs_template -%}
-    {% filter indent(width=4) %} {% include kwargs_template with context -%}
-    {% endfilter %}
-    {%- endif %}
+    start_date=datetime(2023, 1, 1),
+    max_active_runs=3,
+    # https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
+    schedule_interval="{{ schedule_interval | default('@once') }}",
+    catchup=False,
+    # Default settings applied to all tasks
+    default_args=dict(
+        owner="{{ owner | default('airflow') }}",
+        depends_on_past=False,
+        email_on_failure=False,
+        email_on_retry=False,
+        retries=1,
+        retry_delay=timedelta(minutes=5)
+    )
 ) as dag:
     tasks = {
     {% for node in pipeline.nodes %} "{{ node.name | safe | slugify }}": KedroOperator(
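For context, with no `airflow.yml` overrides the merged template above renders a DAG header along these lines. This is a sketch, not the plugin's literal output: the `dag_id` is a made-up package name, and the real file also defines `KedroOperator` and the task dictionary.

```python
# Sketch of the rendered output under default values (illustrative only).
from datetime import datetime, timedelta

from airflow import DAG

with DAG(
    dag_id="my-kedro-package",       # from `dag_name | safe | slugify` (hypothetical name)
    start_date=datetime(2023, 1, 1),
    max_active_runs=3,
    schedule_interval="@once",       # `default('@once')` with no override
    catchup=False,
    default_args=dict(
        owner="airflow",             # `default('airflow')` with no override
        depends_on_past=False,
        email_on_failure=False,
        email_on_retry=False,
        retries=1,
        retry_delay=timedelta(minutes=5),
    ),
) as dag:
    ...  # tasks = { ... } follows in the generated file
```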
160 changes: 76 additions & 84 deletions kedro-airflow/kedro_airflow/plugin.py
@@ -2,13 +2,15 @@
 
 from collections import defaultdict
 from pathlib import Path
+from typing import Any
 
 import click
 import jinja2
 from click import secho
 from kedro.config import MissingConfigException
 from kedro.framework.cli.project import PARAMS_ARG_HELP
 from kedro.framework.cli.utils import ENV_HELP, KedroCliError, _split_params
+from kedro.framework.context import KedroContext
 from kedro.framework.project import pipelines
 from kedro.framework.session import KedroSession
 from kedro.framework.startup import ProjectMetadata, bootstrap_project
@@ -30,6 +32,21 @@ def airflow_commands():
     pass
 
 
+def load_config(context: KedroContext, pipeline_name: str) -> dict[str, Any]:
+    try:
+        config_airflow = context.config_loader.get("airflow*", "airflow/**")
+        dag_config = {}
+        # Load the default config if specified
+        if "default" in config_airflow:
+            dag_config.update(config_airflow["default"])
+        # Update with pipeline-specific config if present
+        if pipeline_name in config_airflow:
+            dag_config.update(config_airflow[pipeline_name])
+    except MissingConfigException:
+        dag_config = {}
+    return dag_config
+
+
 @airflow_commands.command()
 @click.option(
     "-p", "--pipeline", "pipeline_name", default="__default__", help=PIPELINE_ARG_HELP
@@ -52,15 +69,6 @@ def airflow_commands():
     default=Path(__file__).parent / "airflow_dag_template.j2",
     help="The template file for the generated Airflow dags",
 )
-@click.option(
-    "-k",
-    "--kwargs-template",
-    type=click.Path(
-        exists=True, readable=True, resolve_path=True, file_okay=True, dir_okay=False
-    ),
-    default=Path(__file__).parent / "airflow_dag_kwargs_template.j2",
-    help="The template file for the kwargs in the Airflow dags",
-)
 @click.option(
     "--params",
     type=click.UNPROCESSED,
@@ -75,89 +83,73 @@ def create(
     env,
     target_path,
     jinja_file,
-    kwargs_template,
     params,
 ):  # pylint: disable=too-many-locals,too-many-arguments
     """Create an Airflow DAG for a project"""
     project_path = Path().cwd()
     bootstrap_project(project_path)
     with KedroSession.create(project_path=project_path, env=env) as session:
         context = session.load_context()
-        try:
-            config_airflow = context.config_loader.get("airflow*", "airflow/**")
-            dag_config = {}
-            # Load the default config if specified
-            if "default" in config_airflow:
-                dag_config.update(config_airflow["default"])
-            # Update with pipeline-specific config if present
-            if pipeline_name in config_airflow:
-                dag_config.update(config_airflow[pipeline_name])
-        except MissingConfigException:
-            dag_config = {}
+        dag_config = load_config(context, pipeline_name)
 
         # Update with params if provided
         dag_config.update(params)
 
-        jinja_file = Path(jinja_file).resolve()
-        kwargs_template = Path(kwargs_template).resolve()
-        if jinja_file.parent != kwargs_template.parent:
-            raise KedroCliError(f"Templates should be placed in the same directory.")
-        loader = jinja2.FileSystemLoader(jinja_file.parent)
-        jinja_env = jinja2.Environment(
-            autoescape=True, loader=loader, lstrip_blocks=True  # , trim_blocks=True
-        )
-        jinja_env.filters["slugify"] = slugify
-        template = jinja_env.get_template(jinja_file.name)
-        kwargs_template = kwargs_template.name
-
-        package_name = metadata.package_name
-        dag_filename = f"{package_name}_{pipeline_name}_dag.py"
-
-        target_path = Path(target_path)
-        target_path = target_path / dag_filename
-
-        target_path.parent.mkdir(parents=True, exist_ok=True)
-
-        pipeline = pipelines.get(pipeline_name)
-        if pipeline is None:
-            raise KedroCliError(f"Pipeline {pipeline_name} not found.")
-
-        dependencies = defaultdict(list)
-        for node, parent_nodes in pipeline.node_dependencies.items():
-            for parent in parent_nodes:
-                dependencies[parent].append(node)
-
-        template.stream(
-            dag_name=package_name,
-            dependencies=dependencies,
-            env=env,
-            pipeline_name=pipeline_name,
-            package_name=package_name,
-            pipeline=pipeline,
-            kwargs_template=kwargs_template,
-            **dag_config,
-        ).dump(str(target_path))
-
-        secho("")
-        secho("An Airflow DAG has been generated in:", fg="green")
-        secho(str(target_path))
-        secho("This file should be copied to your Airflow DAG folder.", fg="yellow")
-        secho(
-            "The Airflow configuration can be customized by editing this file.",
-            fg="green",
-        )
-        secho("")
-        secho(
-            "This file also contains the path to the config directory, this directory will need to "
-            "be available to Airflow and any workers.",
-            fg="yellow",
-        )
-        secho("")
-        secho(
-            "Additionally all data sets must have an entry in the data catalog.",
-            fg="yellow",
-        )
-        secho(
-            "And all local paths in both the data catalog and log config must be absolute paths.",
-            fg="yellow",
-        )
+    jinja_file = Path(jinja_file).resolve()
+    loader = jinja2.FileSystemLoader(jinja_file.parent)
+    jinja_env = jinja2.Environment(
+        autoescape=True, loader=loader, lstrip_blocks=True
+    )
+    jinja_env.filters["slugify"] = slugify
+    template = jinja_env.get_template(jinja_file.name)
+
+    package_name = metadata.package_name
+    dag_filename = f"{package_name}_{pipeline_name}_dag.py"
+
+    target_path = Path(target_path)
+    target_path = target_path / dag_filename
+
+    target_path.parent.mkdir(parents=True, exist_ok=True)
+
+    pipeline = pipelines.get(pipeline_name)
+    if pipeline is None:
+        raise KedroCliError(f"Pipeline {pipeline_name} not found.")
+
+    dependencies = defaultdict(list)
+    for node, parent_nodes in pipeline.node_dependencies.items():
+        for parent in parent_nodes:
+            dependencies[parent].append(node)
+
+    template.stream(
+        dag_name=package_name,
+        dependencies=dependencies,
+        env=env,
+        pipeline_name=pipeline_name,
+        package_name=package_name,
+        pipeline=pipeline,
+        **dag_config,
+    ).dump(str(target_path))
+
+    secho("")
+    secho("An Airflow DAG has been generated in:", fg="green")
+    secho(str(target_path))
+    secho("This file should be copied to your Airflow DAG folder.", fg="yellow")
+    secho(
+        "The Airflow configuration can be customized by editing this file.",
+        fg="green",
+    )
+    secho("")
+    secho(
+        "This file also contains the path to the config directory, this directory will need to "
+        "be available to Airflow and any workers.",
+        fg="yellow",
+    )
+    secho("")
+    secho(
+        "Additionally all data sets must have an entry in the data catalog.",
+        fg="yellow",
+    )
+    secho(
+        "And all local paths in both the data catalog and log config must be absolute paths.",
+        fg="yellow",
+    )
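With the `--kwargs-template` option gone, a typical invocation only needs the pipeline and, optionally, a custom DAG template. A sketch, using the `--pipeline` and `--jinja-file` options defined above (the pipeline name and template path are illustrative):

```bash
kedro airflow create --pipeline data_science --jinja-file ./custom/template.j2
```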
5 changes: 1 addition & 4 deletions kedro-airflow/pyproject.toml
@@ -31,10 +31,7 @@ packages = ["kedro_airflow"]
 zip-safe = false
 
 [tool.setuptools.package-data]
-kedro_airflow = [
-    "kedro_airflow/airflow_dag_template.j2",
-    "kedro_airflow/airflow_dag_kwargs_template.j2"
-]
+kedro_airflow = ["kedro_airflow/airflow_dag_template.j2"]
 
 [tool.setuptools.dynamic]
 readme = {file = "README.md", content-type = "text/markdown"}
