From d65286a701f0449479f0f27127017ae2a4b09bfe Mon Sep 17 00:00:00 2001 From: Francois Campbell Date: Thu, 31 Oct 2024 15:15:35 -0400 Subject: [PATCH] NA+SPCS PoC --- .../nativeapp/entities/application_package.py | 96 ++++++++++++++++++- .../nativeapp/v2_conversions/compat.py | 16 +++- .../cli/_plugins/spcs/entities/__init__.py | 0 .../cli/_plugins/spcs/entities/service.py | 35 +++++++ .../api/project/schemas/entities/entities.py | 6 ++ 5 files changed, 144 insertions(+), 9 deletions(-) create mode 100644 src/snowflake/cli/_plugins/spcs/entities/__init__.py create mode 100644 src/snowflake/cli/_plugins/spcs/entities/service.py diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index 10c6e9550c..373f15208f 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -6,11 +6,14 @@ from typing import List, Literal, Optional, Union import typer +import yaml from click import BadOptionUsage, ClickException from pydantic import Field, field_validator from snowflake.cli._plugins.nativeapp.artifacts import ( BundleMap, build_bundle, + find_manifest_file, + find_setup_script_file, find_version_info_in_manifest_file, ) from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext @@ -40,6 +43,10 @@ PolicyBase, ) from snowflake.cli._plugins.nativeapp.utils import needs_confirmation +from snowflake.cli._plugins.spcs.entities.service import ( + ServiceEntity, + ServiceEntityModel, +) from snowflake.cli._plugins.stage.diff import DiffResult from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli._plugins.workspace.context import ActionContext @@ -109,6 +116,12 @@ class ApplicationPackageEntityModel(EntityModelBase): title="Path to manifest.yml", ) + ### SPCS PoC + services: list[str] = Field( + title="List of Snowpark Container Service entity IDs to integrate into this application package", + default=[], + ) + @field_validator("identifier") @classmethod def append_test_resource_suffix_to_identifier( @@ -191,7 +204,8 @@ def post_deploy_hooks(self) -> list[PostDeployHook] | None: return model.meta and model.meta.post_deploy def action_bundle(self, action_ctx: ActionContext, *args, **kwargs): - return self._bundle() + spcs_services = [action_ctx.get_entity(s) for s in self._entity_model.services] + return self._bundle(spcs_services=spcs_services) def action_deploy( self, @@ -206,6 +220,7 @@ def action_deploy( *args, **kwargs, ): + spcs_services = [action_ctx.get_entity(s) for s in self._entity_model.services] return self._deploy( bundle_map=None, prune=prune, @@ -216,6 +231,7 @@ def action_deploy( stage_fqn=stage_fqn or self.stage_fqn, interactive=interactive, force=force, + spcs_services=spcs_services, ) def action_drop(self, action_ctx: ActionContext, force_drop: bool, *args, **kwargs): @@ -357,6 +373,8 @@ def action_version_create( else: git_policy = AllowAlwaysPolicy() + spcs_services = [action_ctx.get_entity(s) for s in self._entity_model.services] + # Make sure version is not None before proceeding any further. # This will raise an exception if version information is not found. Patch can be None. bundle_map = None @@ -369,7 +387,7 @@ def action_version_create( """ ) ) - bundle_map = self._bundle() + bundle_map = self._bundle(spcs_services=spcs_services) version, patch = find_version_info_in_manifest_file(self.deploy_root) if not version: raise ClickException( @@ -403,6 +421,7 @@ def action_version_create( stage_fqn=self.stage_fqn, interactive=interactive, force=force, + spcs_services=spcs_services, ) # Warn if the version exists in a release directive(s) @@ -489,7 +508,7 @@ def action_version_drop( """ ) ) - self._bundle() + self._bundle(spcs_services=[]) version, _ = find_version_info_in_manifest_file(self.deploy_root) if not version: raise ClickException( @@ -533,7 +552,7 @@ def action_version_drop( f"Version {version} in application package {self.name} dropped successfully." ) - def _bundle(self): + def _bundle(self, spcs_services: list[ServiceEntity]): model = self._entity_model bundle_map = build_bundle(self.project_root, self.deploy_root, model.artifacts) bundle_context = BundleContext( @@ -546,8 +565,73 @@ def _bundle(self): ) compiler = NativeAppCompiler(bundle_context) compiler.compile_artifacts() + + # TODO should this merged into NativeAppCompiler? + self._inject_spcs(spcs_services) + return bundle_map + def _inject_spcs(self, spcs_services: list[ServiceEntity]): + manifest_path = find_manifest_file(self.deploy_root) + manifest = yaml.safe_load(manifest_path.read_text()) + if "configuration" not in manifest: + manifest["configuration"] = {} + existing_grant_callback = manifest["configuration"].get("grant_callback") + wrapper_grant_callback = "_spcs_generation.grant_callback" + manifest["configuration"]["grant_callback"] = wrapper_grant_callback + # TODO set default_web_endpoint in manifest? + if manifest_path.is_symlink(): + manifest_path.unlink() + manifest_path.write_text(yaml.safe_dump(manifest, sort_keys=False)) + + generated_setup_script = self._spcs_grant_callback( + name=wrapper_grant_callback, + service=spcs_services[0]._entity_model, # noqa SLF001 + existing_grant_callback=existing_grant_callback, + ) + + setup_script_path = find_setup_script_file(self.deploy_root) + setup_script = setup_script_path.read_text() + if setup_script_path.is_symlink(): + setup_script_path.unlink() + setup_script_path.write_text(setup_script + generated_setup_script) + + @staticmethod + def _spcs_grant_callback( + name: str, service: ServiceEntityModel, existing_grant_callback: str + ): + return dedent( + f"""\ + -- Begin generated SPCS services, this section is managed by the Snowflake CLI + create schema if not exists _spcs_generation; + + create or replace procedure {name}(privileges array) + returns string + as $$ + begin + {f'call {existing_grant_callback}(privileges);' if existing_grant_callback else ''} + if (array_contains('CREATE COMPUTE POOL'::variant, privileges)) then + create compute pool if not exists {service.compute_pool} + min_nodes = {service.min_nodes} + max_nodes = {service.max_nodes} + instance_family = {service.instance_family}; + end if; + if (array_contains('BIND SERVICE ENDPOINT'::variant, privileges)) then + create service if not exists _spcs_generation.{service.fqn.name} + in compute pool {service.compute_pool} + from specification_file = '{service.specification_file}'; + end if; + return 'done'; + end; + $$; + + create application role if not exists _spcs_generation_role; + grant usage on procedure {name}(array) to application role _spcs_generation_role; + + -- End generated SPCS services + """ + ) + def _deploy( self, bundle_map: BundleMap | None, @@ -559,6 +643,7 @@ def _deploy( stage_fqn: str, interactive: bool, force: bool, + spcs_services: list[ServiceEntity], run_post_deploy_hooks: bool = True, ) -> DiffResult: model = self._entity_model @@ -574,7 +659,7 @@ def _deploy( stage_fqn = stage_fqn or self.stage_fqn # 1. Create a bundle if one wasn't passed in - bundle_map = bundle_map or self._bundle() + bundle_map = bundle_map or self._bundle(spcs_services=spcs_services) # 2. Create an empty application package, if none exists try: @@ -932,6 +1017,7 @@ def get_validation_result( stage_fqn=self.scratch_stage_fqn, interactive=interactive, force=force, + spcs_services=[], # TODO this affects the setup script, but it's under our control run_post_deploy_hooks=False, ) prefixed_stage_fqn = StageManager.get_standard_stage_prefix(stage_fqn) diff --git a/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py b/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py index 93d60c2e2b..449d222937 100644 --- a/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py +++ b/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py @@ -211,13 +211,21 @@ def wrapper(*args, **kwargs): app_definition, app_package_definition = _find_app_and_package_entities( original_pdf, package_entity_id, app_entity_id, app_required ) - entities_to_keep = {app_package_definition.entity_id} + native_app_entities_to_keep = {app_package_definition.entity_id} kwargs["package_entity_id"] = app_package_definition.entity_id if app_definition: - entities_to_keep.add(app_definition.entity_id) + native_app_entities_to_keep.add(app_definition.entity_id) kwargs["app_entity_id"] = app_definition.entity_id - for entity_id in list(original_pdf.entities): - if entity_id not in entities_to_keep: + + native_app_entity_classes = ( + ApplicationEntityModel, + ApplicationPackageEntityModel, + ) + for entity_id, entity_model in list(original_pdf.entities.items()): + if ( + isinstance(entity_model, native_app_entity_classes) + and entity_id not in native_app_entities_to_keep + ): # This happens after templates are rendered, # so we can safely remove the entity del original_pdf.entities[entity_id] diff --git a/src/snowflake/cli/_plugins/spcs/entities/__init__.py b/src/snowflake/cli/_plugins/spcs/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowflake/cli/_plugins/spcs/entities/service.py b/src/snowflake/cli/_plugins/spcs/entities/service.py new file mode 100644 index 0000000000..6ed3b6294b --- /dev/null +++ b/src/snowflake/cli/_plugins/spcs/entities/service.py @@ -0,0 +1,35 @@ +from typing import Literal + +from pydantic import Field +from snowflake.cli.api.entities.common import EntityBase +from snowflake.cli.api.project.schemas.entities.common import EntityModelBase +from snowflake.cli.api.project.schemas.updatable_model import DiscriminatorField + + +class ServiceEntityModel(EntityModelBase): + type: Literal["snowpark container service"] = DiscriminatorField() # noqa: A003 + specification_file: str = Field( + title="Path to the specification file for the SPCS service, relative to the deploy root", + ) + # TODO is a compute pool a separate entity? + compute_pool: str = Field( + title="Name of the compute pool to use for the SPCS service", + ) + min_nodes: int = Field( + title="Minimum number of nodes in the compute pool", + default=1, + ) + max_nodes: int = Field( + title="Maximum number of nodes in the compute pool", + default=1, + ) + instance_family: str = Field( + title="Instance family to use for the compute pool", + default="CPU_X64_XS", + ) + + +class ServiceEntity(EntityBase[ServiceEntityModel]): + # Local deploy of SPSC service not yet implemented + # We only use the model to deploy SPCS services in native apps + pass diff --git a/src/snowflake/cli/api/project/schemas/entities/entities.py b/src/snowflake/cli/api/project/schemas/entities/entities.py index 008cbc0db2..66f0072e27 100644 --- a/src/snowflake/cli/api/project/schemas/entities/entities.py +++ b/src/snowflake/cli/api/project/schemas/entities/entities.py @@ -32,6 +32,10 @@ FunctionEntityModel, ProcedureEntityModel, ) +from snowflake.cli._plugins.spcs.entities.service import ( + ServiceEntity, + ServiceEntityModel, +) from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, @@ -43,6 +47,7 @@ StreamlitEntity, ProcedureEntity, FunctionEntity, + ServiceEntity, ] EntityModel = Union[ ApplicationEntityModel, @@ -50,6 +55,7 @@ StreamlitEntityModel, FunctionEntityModel, ProcedureEntityModel, + ServiceEntityModel, ] ALL_ENTITIES: List[Entity] = [*get_args(Entity)]