Skip to content

Commit

Permalink
feat: add options to output registered entity summary
Browse files Browse the repository at this point in the history
  • Loading branch information
paullongtan committed Jan 2, 2025
1 parent 60fa417 commit dda078c
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 4 deletions.
23 changes: 23 additions & 0 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,21 @@
help="Skip errors during registration. This is useful when registering multiple packages and you want to skip "
"errors for some packages.",
)
@click.option(
"--summary-format",
"-f",
required=False,
type=click.Choice(["json", "yaml"], case_sensitive=False),
default=None,
help="Set output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.",
)
@click.option(
"--summary-dir",
required=False,
type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
default=None,
help="Directory to save registration summary. Uses current working directory if not specified.",
)
@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
@click.pass_context
def register(
Expand All @@ -162,12 +177,15 @@ def register(
activate_launchplans: bool,
env: typing.Optional[typing.Dict[str, str]],
skip_errors: bool,
summary_format: typing.Optional[str],
summary_dir: typing.Optional[str],
):
"""
see help
"""
# Set the relevant copy option if non_fast is set, this enables the individual file listing behavior
# that the copy flag uses.

if non_fast:
click.secho("The --non-fast flag is deprecated, please use --copy none instead", fg="yellow")
if "--copy" in sys.argv:
Expand Down Expand Up @@ -195,6 +213,9 @@ def register(
"Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
)

if summary_dir is not None and summary_format is None:
raise click.UsageError("--summary-format is a required parameter when --summary-dir is specified")

# Use extra images in the config file if that file exists
config_file = ctx.obj.get(constants.CTX_CONFIG_FILE)
if config_file:
Expand Down Expand Up @@ -225,6 +246,8 @@ def register(
package_or_module=package_or_module,
remote=remote,
env=env,
summary_format=summary_format,
summary_dir=summary_dir,
dry_run=dry_run,
activate_launchplans=activate_launchplans,
skip_errors=skip_errors,
Expand Down
50 changes: 46 additions & 4 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import functools
import json
import os
import tarfile
import tempfile
import typing
from pathlib import Path

import click
import yaml
from rich import print as rprint

from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
Expand Down Expand Up @@ -251,6 +253,8 @@ def register(
remote: FlyteRemote,
copy_style: CopyFileDetection,
env: typing.Optional[typing.Dict[str, str]],
summary_format: typing.Optional[str],
summary_dir: typing.Optional[str],
dry_run: bool = False,
activate_launchplans: bool = False,
skip_errors: bool = False,
Expand Down Expand Up @@ -333,6 +337,14 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity):
is_lp = True
else:
og_id = cp_entity.template.id

result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
}

try:
if not dry_run:
try:
Expand All @@ -350,30 +362,60 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity):
print_registration_status(
i, console_url=console_url, verbosity=verbosity, activation=print_activation_message
)
result["status"] = "success"

except Exception as e:
if not skip_errors:
raise e
print_registration_status(og_id, success=False)
result["status"] = "failed"
else:
print_registration_status(og_id, dry_run=True)
except RegistrationSkipped:
print_registration_status(og_id, success=False)
result["status"] = "skipped"

return result

async def _register(entities: typing.List[task.TaskSpec]) -> typing.List[typing.Dict[str, str]]:
    """Register the given entities concurrently and collect their results.

    Each entity is handed to ``_raw_register`` on the running loop's default
    executor so the blocking registration calls can overlap.

    :param entities: Entities to register.
    :return: The value returned by ``_raw_register`` for every entity, in the
        same order as ``entities``. An empty input yields an empty list.
    """
    loop = asyncio.get_running_loop()
    pending = [loop.run_in_executor(None, functools.partial(_raw_register, entity)) for entity in entities]
    # Gather exactly once and hand the results back: returning before this point
    # would give callers ``None`` and break concatenation with the serial results.
    return await asyncio.gather(*pending)

# concurrent register
cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities))
asyncio.run(_register(cp_task_entities))
task_results = asyncio.run(_register(cp_task_entities))
# serial register
cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities))
other_results = []
for entity in cp_other_entities:
_raw_register(entity)
other_results.append(_raw_register(entity))

all_results = task_results + other_results

click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green")

if summary_format is not None:
if summary_dir is not None:
# Directory path is already absolute and resolved via click.Path
os.makedirs(summary_dir, exist_ok=True)
else:
# Default to current working directory if not specified
summary_dir = os.getcwd()

summary_file = f"registration_summary.{summary_format}"
summary_path = os.path.join(summary_dir, summary_file)

if summary_format == "json":
with open(summary_path, "w") as f:
json.dump(all_results, f)
elif summary_format == "yaml":
with open(summary_path, "w") as f:
yaml.dump(all_results, f)
else:
raise ValueError(f"Unsupported file format: {summary_format}")

click.secho(f"Registration summary written to: {summary_path}", fg="green")
78 changes: 78 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import shutil
import subprocess
import json

import mock
import pytest
Expand Down Expand Up @@ -163,3 +164,80 @@ def test_non_fast_register_require_version(mock_client, mock_remote):
result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"])
assert result.exit_code == 1
shutil.rmtree("core3")


@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_summary_dir_without_format(mock_client, mock_remote):
    """``--summary-dir`` without ``--summary-format`` must be rejected as a usage error."""
    mock_remote._client = mock_client
    mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
    mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"

    runner = CliRunner()
    context_manager.FlyteEntities.entities.clear()

    with runner.isolated_filesystem():
        out = subprocess.run(["git", "init"], capture_output=True)
        assert out.returncode == 0
        os.makedirs("core4", exist_ok=True)
        # The ``with`` block closes the file; no explicit close() is needed.
        with open(os.path.join("core4", "sample.py"), "w") as f:
            f.write(sample_file_contents)

        result = runner.invoke(pyflyte.main, ["register", "--summary-dir", "summaries", "core4"])

        # The test expects exit code 2 for the usage error raised by the command.
        assert result.exit_code == 2
        assert "--summary-format is a required parameter when --summary-dir is specified" in result.output

        shutil.rmtree("core4")


@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_registrated_summary_json(mock_client, mock_remote):
    """A JSON registration summary is written into the directory given by ``--summary-dir``."""
    mock_remote._client = mock_client
    mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
    mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"

    runner = CliRunner()
    context_manager.FlyteEntities.entities.clear()

    with runner.isolated_filesystem():
        try:
            out = subprocess.run(["git", "init"], capture_output=True)
            assert out.returncode == 0
            os.makedirs("core5", exist_ok=True)
            os.makedirs("summaries", exist_ok=True)
            # The ``with`` block closes the file; no explicit close() is needed.
            with open(os.path.join("core5", "sample.py"), "w") as f:
                f.write(sample_file_contents)

            # --summary-dir must be passed explicitly: without it the summary is
            # written to the current working directory, not to "summaries".
            result = runner.invoke(
                pyflyte.main,
                ["register", "--summary-format", "json", "--summary-dir", "summaries", "core5"],
            )

            assert result.exit_code == 0

            summary_path = os.path.join("summaries", "registration_summary.json")
            assert os.path.exists(summary_path)

            with open(summary_path) as f:
                summary_data = json.load(f)

            # The summary is a non-empty list with one record per registered entity.
            assert isinstance(summary_data, list)
            assert len(summary_data) > 0
            for entry in summary_data:
                assert "id" in entry
                assert "type" in entry
                assert "version" in entry
                assert "status" in entry
        finally:
            # Clean up even when an assertion above fails.
            if os.path.exists("core5"):
                shutil.rmtree("core5")
            if os.path.exists("summaries"):
                shutil.rmtree("summaries")

0 comments on commit dda078c

Please sign in to comment.