From dda078c737aca30f2798f563999b1d16f786a1db Mon Sep 17 00:00:00 2001 From: paullongtan Date: Wed, 1 Jan 2025 16:50:52 -0800 Subject: [PATCH] feat: add options to output registered entity summary --- flytekit/clis/sdk_in_container/register.py | 23 ++++++ flytekit/tools/repo.py | 50 +++++++++++- .../unit/cli/pyflyte/test_register.py | 78 +++++++++++++++++++ 3 files changed, 147 insertions(+), 4 deletions(-) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 57a9b58448..1683b628e8 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -142,6 +142,21 @@ help="Skip errors during registration. This is useful when registering multiple packages and you want to skip " "errors for some packages.", ) +@click.option( + "--summary-format", + "-f", + required=False, + type=click.Choice(["json", "yaml"], case_sensitive=False), + default=None, + help="Set output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.", +) +@click.option( + "--summary-dir", + required=False, + type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True), + default=None, + help="Directory to save registration summary. Uses current working directory if not specified.", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -162,12 +177,15 @@ def register( activate_launchplans: bool, env: typing.Optional[typing.Dict[str, str]], skip_errors: bool, + summary_format: typing.Optional[str], + summary_dir: typing.Optional[str], ): """ see help """ # Set the relevant copy option if non_fast is set, this enables the individual file listing behavior # that the copy flag uses. + if non_fast: click.secho("The --non-fast flag is deprecated, please use --copy none instead", fg="yellow") if "--copy" in sys.argv: @@ -195,6 +213,9 @@ def register( "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", ) + if summary_dir is not None and summary_format is None: + raise click.UsageError("--summary-format is a required parameter when --summary-dir is specified") + # Use extra images in the config file if that file exists config_file = ctx.obj.get(constants.CTX_CONFIG_FILE) if config_file: @@ -225,6 +246,8 @@ def register( package_or_module=package_or_module, remote=remote, env=env, + summary_format=summary_format, + summary_dir=summary_dir, dry_run=dry_run, activate_launchplans=activate_launchplans, skip_errors=skip_errors, diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index e2e46f49d3..9e6755aecd 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -1,5 +1,6 @@ import asyncio import functools +import json import os import tarfile import tempfile @@ -7,6 +8,7 @@ from pathlib import Path import click +import yaml from rich import print as rprint from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings @@ -251,6 +253,8 @@ def register( remote: FlyteRemote, copy_style: CopyFileDetection, env: typing.Optional[typing.Dict[str, str]], + summary_format: typing.Optional[str], + summary_dir: typing.Optional[str], dry_run: bool = False, activate_launchplans: bool = False, skip_errors: bool = False, @@ -333,6 +337,14 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity): is_lp = True else: og_id = cp_entity.template.id + + result = { + "id": og_id.name, + "type": og_id.resource_type_name(), + "version": og_id.version, + "status": "skipped", # default status + } + try: if not dry_run: try: @@ -350,30 +362,60 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity): print_registration_status( i, console_url=console_url, verbosity=verbosity, activation=print_activation_message ) + result["status"] = "success" except Exception as e: if not skip_errors: raise e print_registration_status(og_id, success=False) + result["status"] = "failed" else: print_registration_status(og_id, dry_run=True) except RegistrationSkipped: print_registration_status(og_id, success=False) + result["status"] = "skipped" + + return result async def _register(entities: typing.List[task.TaskSpec]): loop = asyncio.get_running_loop() tasks = [] for entity in entities: tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) - await asyncio.gather(*tasks) - return + results = await asyncio.gather(*tasks) + return results # concurrent register cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities)) - asyncio.run(_register(cp_task_entities)) + task_results = asyncio.run(_register(cp_task_entities)) # serial register cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities)) + other_results = [] for entity in cp_other_entities: - _raw_register(entity) + other_results.append(_raw_register(entity)) + + all_results = task_results + other_results click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") + + if summary_format is not None: + if summary_dir is not None: + # Directory path is already absolute and resolved via click.Path + os.makedirs(summary_dir, exist_ok=True) + else: + # Default to current working directory if not specified + summary_dir = os.getcwd() + + summary_file = f"registration_summary.{summary_format}" + summary_path = os.path.join(summary_dir, summary_file) + + if summary_format == "json": + with open(summary_path, "w") as f: + json.dump(all_results, f) + elif summary_format == "yaml": + with open(summary_path, "w") as f: + yaml.dump(all_results, f) + else: + raise ValueError(f"Unsupported file format: {summary_format}") + + click.secho(f"Registration summary written to: {summary_path}", fg="green") diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index ec14aa8227..97e09058e3 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -1,6 +1,7 @@ import os import shutil import subprocess +import json import mock import pytest @@ -163,3 +164,80 @@ def test_non_fast_register_require_version(mock_client, mock_remote): result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"]) assert result.exit_code == 1 shutil.rmtree("core3") + + +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_summary_dir_without_format(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core4", exist_ok=True) + with open(os.path.join("core4", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["register", "--summary-dir", "summaries", "core4"]) + assert result.exit_code == 2 + print(result.output) + assert "--summary-format is a required parameter when --summary-dir is specified" in result.output + + shutil.rmtree("core4") + + +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_registrated_summary_json(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core5", exist_ok=True) + os.makedirs("summaries", exist_ok=True) + with open(os.path.join("core5", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + + # Run registration command + # result = runner.invoke( + # pyflyte.main, + # ["register", "--summary-format", "json", "--summary-dir", "summaries", "core5"] + # ) + result = runner.invoke( + pyflyte.main, + ["register", "--summary-format", "json", "core5"] + ) + + assert result.exit_code == 0 + + summary_path = os.path.join("summaries", "registration_summary.json") + assert os.path.exists(summary_path) + + with open(summary_path) as f: + summary_data = json.load(f) + + assert isinstance(summary_data, list) + assert len(summary_data) > 0 + for entry in summary_data: + assert "id" in entry + assert "type" in entry + assert "version" in entry + assert "status" in entry + + # Ensure cleanup happens even if test fails + if os.path.exists("core5"): + shutil.rmtree("core5") + if os.path.exists("summaries"): + shutil.rmtree("summaries")