From 9305955fb0783fcdf1d49cd7f4ca3f9b0050d00e Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Thu, 12 Dec 2024 16:44:52 +0100 Subject: [PATCH 1/5] Temporarily disable schema existence check for object create (#1949) * Temporarily disable schema existance check for object create * skip test --- src/snowflake/cli/api/rest_api.py | 5 +++-- tests_integration/test_object.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/snowflake/cli/api/rest_api.py b/src/snowflake/cli/api/rest_api.py index ae829645b8..0125d51428 100644 --- a/src/snowflake/cli/api/rest_api.py +++ b/src/snowflake/cli/api/rest_api.py @@ -155,8 +155,9 @@ def determine_url_for_create_query( raise SchemaNotDefinedException( "Schema not defined in connection. Please try again with `--schema` flag." ) - if not self._schema_exists(db_name=db, schema_name=schema): - raise SchemaNotExistsException(f"Schema '{schema}' does not exist.") + # temporarily disable this check due to an issue on server side: SNOW-1747450 + # if not self._schema_exists(db_name=db, schema_name=schema): + # raise SchemaNotExistsException(f"Schema '{schema}' does not exist.") if self.get_endpoint_exists( url := f"{SF_REST_API_URL_PREFIX}/databases/{self.conn.database}/schemas/{self.conn.schema}/{plural_object_type}/" ): diff --git a/tests_integration/test_object.py b/tests_integration/test_object.py index 386150438b..9dbe526125 100644 --- a/tests_integration/test_object.py +++ b/tests_integration/test_object.py @@ -313,6 +313,7 @@ def test_create_error_database_not_exist(runner): @pytest.mark.integration +@pytest.mark.skip(reason="Server-side issue: SNOW-1855040") def test_create_error_schema_not_exist(runner, test_database): # schema does not exist result = runner.invoke_with_connection( From 5a3e36ee378c955a6079f37dbcd376698c5942e9 Mon Sep 17 00:00:00 2001 From: Guy Bloom Date: Thu, 12 Dec 2024 10:51:52 -0500 Subject: [PATCH 2/5] POC: Add child entities to application package (#1856) * add child entities * children_artifacts_dir * unit tests * sanitize dir name docstring * error message on child directory collision --- .../cli/_plugins/nativeapp/commands.py | 5 +- .../nativeapp/entities/application_package.py | 181 ++++++++++++++++-- .../application_package_child_interface.py | 43 +++++ .../cli/_plugins/nativeapp/feature_flags.py | 1 + src/snowflake/cli/_plugins/nativeapp/utils.py | 11 ++ .../nativeapp/v2_conversions/compat.py | 6 +- .../_plugins/streamlit/streamlit_entity.py | 64 ++++++- .../cli/_plugins/workspace/manager.py | 12 +- src/snowflake/cli/api/entities/common.py | 4 + .../api/project/schemas/project_definition.py | 33 +++- .../test_application_package_entity.py | 4 +- tests/nativeapp/test_children.py | 152 +++++++++++++++ tests/nativeapp/test_manager.py | 7 +- tests/streamlit/test_streamlit_entity.py | 53 +++++ .../projects/napp_children/app/README.md | 1 + .../projects/napp_children/app/manifest.yml | 7 + .../napp_children/app/setup_script.sql | 3 + .../projects/napp_children/snowflake.yml | 21 ++ .../projects/napp_children/streamlit_app.py | 20 ++ 19 files changed, 598 insertions(+), 30 deletions(-) create mode 100644 src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py create mode 100644 tests/nativeapp/test_children.py create mode 100644 tests/streamlit/test_streamlit_entity.py create mode 100644 tests/test_data/projects/napp_children/app/README.md create mode 100644 tests/test_data/projects/napp_children/app/manifest.yml create mode 100644 tests/test_data/projects/napp_children/app/setup_script.sql create mode 100644 tests/test_data/projects/napp_children/snowflake.yml create mode 100644 tests/test_data/projects/napp_children/streamlit_app.py diff --git a/src/snowflake/cli/_plugins/nativeapp/commands.py b/src/snowflake/cli/_plugins/nativeapp/commands.py index 2d5bcbf901..411067ac5e 100644 --- a/src/snowflake/cli/_plugins/nativeapp/commands.py +++ b/src/snowflake/cli/_plugins/nativeapp/commands.py @@ -362,7 +362,10 @@ def app_validate( if cli_context.output_format == OutputFormat.JSON: return ObjectResult( package.get_validation_result( - use_scratch_stage=True, interactive=False, force=True + action_ctx=ws.action_ctx, + use_scratch_stage=True, + interactive=False, + force=True, ) ) diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index 54c643a628..8f76c7ac4c 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +import os import re from pathlib import Path from textwrap import dedent -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Set, Union import typer from click import BadOptionUsage, ClickException @@ -14,6 +15,7 @@ BundleMap, VersionInfo, build_bundle, + find_setup_script_file, find_version_info_in_manifest_file, ) from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext @@ -30,6 +32,9 @@ PATCH_COL, VERSION_COL, ) +from snowflake.cli._plugins.nativeapp.entities.application_package_child_interface import ( + ApplicationPackageChildInterface, +) from snowflake.cli._plugins.nativeapp.exceptions import ( ApplicationPackageAlreadyExistsError, ApplicationPackageDoesNotExistError, @@ -48,9 +53,16 @@ from snowflake.cli._plugins.nativeapp.sf_facade_exceptions import ( InsufficientPrivilegesError, ) -from snowflake.cli._plugins.nativeapp.utils import needs_confirmation +from snowflake.cli._plugins.nativeapp.utils import needs_confirmation, sanitize_dir_name +from snowflake.cli._plugins.snowpark.snowpark_entity_model import ( + FunctionEntityModel, + ProcedureEntityModel, +) from snowflake.cli._plugins.stage.diff import DiffResult from snowflake.cli._plugins.stage.manager import StageManager +from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( + StreamlitEntityModel, +) from snowflake.cli._plugins.workspace.context import ActionContext from snowflake.cli.api.cli_global_context import span from snowflake.cli.api.entities.common import ( @@ -75,6 +87,7 @@ from snowflake.cli.api.project.schemas.updatable_model import ( DiscriminatorField, IdentifierField, + UpdatableModel, ) from snowflake.cli.api.project.schemas.v1.native_app.package import DistributionOptions from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import PathMapping @@ -94,6 +107,43 @@ from snowflake.connector import DictCursor, ProgrammingError from snowflake.connector.cursor import SnowflakeCursor +ApplicationPackageChildrenTypes = ( + StreamlitEntityModel | FunctionEntityModel | ProcedureEntityModel +) + + +class ApplicationPackageChildIdentifier(UpdatableModel): + schema_: Optional[str] = Field( + title="Child entity schema", alias="schema", default=None + ) + + +class EnsureUsableByField(UpdatableModel): + application_roles: Optional[Union[str, Set[str]]] = Field( + title="One or more application roles to be granted with the required privileges", + default=None, + ) + + @field_validator("application_roles") + @classmethod + def ensure_app_roles_is_a_set( + cls, application_roles: Optional[Union[str, Set[str]]] + ) -> Optional[Union[Set[str]]]: + if isinstance(application_roles, str): + return set([application_roles]) + return application_roles + + +class ApplicationPackageChildField(UpdatableModel): + target: str = Field(title="The key of the entity to include in this package") + ensure_usable_by: Optional[EnsureUsableByField] = Field( + title="Automatically grant the required privileges on the child object and its schema", + default=None, + ) + identifier: ApplicationPackageChildIdentifier = Field( + title="Entity identifier", default=None + ) + class ApplicationPackageEntityModel(EntityModelBase): type: Literal["application package"] = DiscriminatorField() # noqa: A003 @@ -101,23 +151,27 @@ class ApplicationPackageEntityModel(EntityModelBase): title="List of paths or file source/destination pairs to add to the deploy root", ) bundle_root: Optional[str] = Field( - title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored.", + title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored", default="output/bundle/", ) deploy_root: Optional[str] = Field( title="Folder at the root of your project where the build step copies the artifacts", default="output/deploy/", ) + children_artifacts_dir: Optional[str] = Field( + title="Folder under deploy_root where the child artifacts will be stored", + default="_children/", + ) generated_root: Optional[str] = Field( - title="Subdirectory of the deploy root where files generated by the Snowflake CLI will be written.", + title="Subdirectory of the deploy root where files generated by the Snowflake CLI will be written", default="__generated/", ) stage: Optional[str] = IdentifierField( - title="Identifier of the stage that stores the application artifacts.", + title="Identifier of the stage that stores the application artifacts", default="app_src.stage", ) scratch_stage: Optional[str] = IdentifierField( - title="Identifier of the stage that stores temporary scratch data used by the Snowflake CLI.", + title="Identifier of the stage that stores temporary scratch data used by the Snowflake CLI", default="app_src.stage_snowflake_cli_scratch", ) distribution: Optional[DistributionOptions] = Field( @@ -128,6 +182,19 @@ class ApplicationPackageEntityModel(EntityModelBase): title="Path to manifest.yml. Unused and deprecated starting with Snowflake CLI 3.2", default="", ) + children: Optional[List[ApplicationPackageChildField]] = Field( + title="Entities that will be bundled and deployed as part of this application package", + default=[], + ) + + @field_validator("children") + @classmethod + def verify_children_behind_flag( + cls, input_value: Optional[List[ApplicationPackageChildField]] + ) -> Optional[List[ApplicationPackageChildField]]: + if input_value and not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled(): + raise AttributeError("Application package children are not supported yet") + return input_value @field_validator("identifier") @classmethod @@ -183,6 +250,10 @@ def project_root(self) -> Path: def deploy_root(self) -> Path: return self.project_root / self._entity_model.deploy_root + @property + def children_artifacts_deploy_root(self) -> Path: + return self.deploy_root / self._entity_model.children_artifacts_dir + @property def bundle_root(self) -> Path: return self.project_root / self._entity_model.bundle_root @@ -221,7 +292,7 @@ def post_deploy_hooks(self) -> list[PostDeployHook] | None: return model.meta and model.meta.post_deploy def action_bundle(self, action_ctx: ActionContext, *args, **kwargs): - return self._bundle() + return self._bundle(action_ctx) def action_deploy( self, @@ -237,6 +308,7 @@ def action_deploy( **kwargs, ): return self._deploy( + action_ctx=action_ctx, bundle_map=None, prune=prune, recursive=recursive, @@ -336,6 +408,7 @@ def action_validate( **kwargs, ): self.validate_setup_script( + action_ctx=action_ctx, use_scratch_stage=use_scratch_stage, interactive=interactive, force=force, @@ -390,7 +463,7 @@ def action_version_create( else: git_policy = AllowAlwaysPolicy() - bundle_map = self._bundle() + bundle_map = self._bundle(action_ctx) resolved_version, resolved_patch, resolved_label = self.resolve_version_info( version=version, patch=patch, @@ -404,6 +477,7 @@ def action_version_create( self.check_index_changes_in_git_repo(policy=policy, interactive=interactive) self._deploy( + action_ctx=action_ctx, bundle_map=bundle_map, prune=True, recursive=True, @@ -507,7 +581,7 @@ def action_version_drop( """ ) ) - self._bundle() + self._bundle(action_ctx) version_info = find_version_info_in_manifest_file(self.deploy_root) version = version_info.version_name if not version: @@ -692,7 +766,7 @@ def action_release_directive_unset( role=self.role, ) - def _bundle(self): + def _bundle(self, action_ctx: ActionContext = None): model = self._entity_model bundle_map = build_bundle(self.project_root, self.deploy_root, model.artifacts) bundle_context = BundleContext( @@ -705,10 +779,80 @@ def _bundle(self): ) compiler = NativeAppCompiler(bundle_context) compiler.compile_artifacts() + + if self._entity_model.children: + # Bundle children and append their SQL to setup script + # TODO Consider re-writing the logic below as a processor + children_sql = self._bundle_children(action_ctx=action_ctx) + setup_file_path = find_setup_script_file(deploy_root=self.deploy_root) + with open(setup_file_path, "r", encoding="utf-8") as file: + existing_setup_script = file.read() + if setup_file_path.is_symlink(): + setup_file_path.unlink() + with open(setup_file_path, "w", encoding="utf-8") as file: + file.write(existing_setup_script) + file.write("\n-- AUTO GENERATED CHILDREN SECTION\n") + file.write("\n".join(children_sql)) + file.write("\n") + return bundle_map + def _bundle_children(self, action_ctx: ActionContext) -> List[str]: + # Create _children directory + children_artifacts_dir = self.children_artifacts_deploy_root + os.makedirs(children_artifacts_dir) + children_sql = [] + for child in self._entity_model.children: + # Create child sub directory + child_artifacts_dir = children_artifacts_dir / sanitize_dir_name( + child.target + ) + try: + os.makedirs(child_artifacts_dir) + except FileExistsError: + raise ClickException( + f"Could not create sub-directory at {child_artifacts_dir}. Make sure child entity names do not collide with each other." + ) + child_entity: ApplicationPackageChildInterface = action_ctx.get_entity( + child.target + ) + child_entity.bundle(child_artifacts_dir) + app_role = ( + to_identifier( + child.ensure_usable_by.application_roles.pop() # TODO Support more than one application role + ) + if child.ensure_usable_by and child.ensure_usable_by.application_roles + else None + ) + child_schema = ( + to_identifier(child.identifier.schema_) + if child.identifier and child.identifier.schema_ + else None + ) + children_sql.append( + child_entity.get_deploy_sql( + artifacts_dir=child_artifacts_dir.relative_to(self.deploy_root), + schema=child_schema, + ) + ) + if app_role: + children_sql.append( + f"CREATE APPLICATION ROLE IF NOT EXISTS {app_role};" + ) + if child_schema: + children_sql.append( + f"GRANT USAGE ON SCHEMA {child_schema} TO APPLICATION ROLE {app_role};" + ) + children_sql.append( + child_entity.get_usage_grant_sql( + app_role=app_role, schema=child_schema + ) + ) + return children_sql + def _deploy( self, + action_ctx: ActionContext, bundle_map: BundleMap | None, prune: bool, recursive: bool, @@ -733,7 +877,7 @@ def _deploy( stage_fqn = stage_fqn or self.stage_fqn # 1. Create a bundle if one wasn't passed in - bundle_map = bundle_map or self._bundle() + bundle_map = bundle_map or self._bundle(action_ctx) # 2. Create an empty application package, if none exists try: @@ -765,6 +909,7 @@ def _deploy( if validate: self.validate_setup_script( + action_ctx=action_ctx, use_scratch_stage=False, interactive=interactive, force=force, @@ -1054,7 +1199,11 @@ def execute_post_deploy_hooks(self): ) def validate_setup_script( - self, use_scratch_stage: bool, interactive: bool, force: bool + self, + action_ctx: ActionContext, + use_scratch_stage: bool, + interactive: bool, + force: bool, ): workspace_ctx = self._workspace_ctx console = workspace_ctx.console @@ -1062,6 +1211,7 @@ def validate_setup_script( """Validates Native App setup script SQL.""" with console.phase(f"Validating Snowflake Native App setup script."): validation_result = self.get_validation_result( + action_ctx=action_ctx, use_scratch_stage=use_scratch_stage, force=force, interactive=interactive, @@ -1083,13 +1233,18 @@ def validate_setup_script( @span("validate_setup_script") def get_validation_result( - self, use_scratch_stage: bool, interactive: bool, force: bool + self, + action_ctx: ActionContext, + use_scratch_stage: bool, + interactive: bool, + force: bool, ): """Call system$validate_native_app_setup() to validate deployed Native App setup script.""" stage_fqn = self.stage_fqn if use_scratch_stage: stage_fqn = self.scratch_stage_fqn self._deploy( + action_ctx=action_ctx, bundle_map=None, prune=True, recursive=True, diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py new file mode 100644 index 0000000000..c4f13871e4 --- /dev/null +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + + +class ApplicationPackageChildInterface(ABC): + @abstractmethod + def bundle(self, bundle_root=Path, *args, **kwargs) -> None: + """ + Bundles the entity artifacts into the provided root directory. Must not have any side-effects, such as deploying the artifacts into a stage, etc. + @param bundle_root: The directory where the bundle contents should be put. + """ + pass + + @abstractmethod + def get_deploy_sql( + self, + artifacts_dir: Path, + schema: Optional[str], + *args, + **kwargs, + ) -> str: + """ + Returns the SQL that would create the entity object. Must not execute the SQL or have any other side-effects. + @param artifacts_dir: Path to the child entity artifacts directory relative to the deploy root. + @param [Optional] schema: Schema to use when creating the object. + """ + pass + + @abstractmethod + def get_usage_grant_sql( + self, + app_role: str, + schema: Optional[str], + *args, + **kwargs, + ) -> str: + """ + Returns the SQL that would grant the required USAGE privilege to the provided application role on the entity object. Must not execute the SQL or have any other side-effects. + @param app_role: The application role to grant the privileges to. + @param [Optional] schema: The schema where the object was created. + """ + pass diff --git a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py index dbc47e7483..dc7e93bf51 100644 --- a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py +++ b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py @@ -22,4 +22,5 @@ class FeatureFlag(FeatureFlagMixin): ENABLE_NATIVE_APP_PYTHON_SETUP = BooleanFlag( "ENABLE_NATIVE_APP_PYTHON_SETUP", False ) + ENABLE_NATIVE_APP_CHILDREN = BooleanFlag("ENABLE_NATIVE_APP_CHILDREN", False) ENABLE_RELEASE_CHANNELS = BooleanFlag("ENABLE_RELEASE_CHANNELS", None) diff --git a/src/snowflake/cli/_plugins/nativeapp/utils.py b/src/snowflake/cli/_plugins/nativeapp/utils.py index 87fa989d2a..fa2a4cebd5 100644 --- a/src/snowflake/cli/_plugins/nativeapp/utils.py +++ b/src/snowflake/cli/_plugins/nativeapp/utils.py @@ -96,3 +96,14 @@ def verify_no_directories(paths_to_sync: Iterable[Path]): def verify_exists(path: Path): if not path.exists(): raise ClickException(f"The following path does not exist: {path}") + + +def sanitize_dir_name(dir_name: str) -> str: + """ + Returns a string that is safe to use as a directory name. + For simplicity, this function is over restricitive: it strips non alphanumeric characters, + unless listed in the allow list. Additional characters can be allowed in the future, but + we need to be careful to consider both Unix/Windows directory naming rules. + """ + allowed_chars = [" ", "_"] + return "".join(char for char in dir_name if char in allowed_chars or char.isalnum()) diff --git a/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py b/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py index 93d60c2e2b..a72a12f68d 100644 --- a/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py +++ b/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py @@ -217,7 +217,11 @@ def wrapper(*args, **kwargs): entities_to_keep.add(app_definition.entity_id) kwargs["app_entity_id"] = app_definition.entity_id for entity_id in list(original_pdf.entities): - if entity_id not in entities_to_keep: + entity_type = original_pdf.entities[entity_id].type.lower() + if ( + entity_type in ["application", "application package"] + and entity_id not in entities_to_keep + ): # This happens after templates are rendered, # so we can safely remove the entity del original_pdf.entities[entity_id] diff --git a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py index 6def772525..6b187ba54b 100644 --- a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py +++ b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py @@ -1,12 +1,72 @@ +from pathlib import Path +from typing import Optional + +from snowflake.cli._plugins.nativeapp.artifacts import build_bundle +from snowflake.cli._plugins.nativeapp.entities.application_package_child_interface import ( + ApplicationPackageChildInterface, +) +from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, ) from snowflake.cli.api.entities.common import EntityBase +from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import PathMapping -class StreamlitEntity(EntityBase[StreamlitEntityModel]): +# WARNING: This entity is not implemented yet. The logic below is only for demonstrating the +# required interfaces for composability (used by ApplicationPackageEntity behind a feature flag). +class StreamlitEntity( + EntityBase[StreamlitEntityModel], ApplicationPackageChildInterface +): """ A Streamlit app. """ - pass + def __init__(self, *args, **kwargs): + if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled(): + raise NotImplementedError("Streamlit entity is not implemented yet") + super().__init__(*args, **kwargs) + + @property + def project_root(self) -> Path: + return self._workspace_ctx.project_root + + @property + def deploy_root(self) -> Path: + return self.project_root / "output" / "deploy" + + def action_bundle( + self, + *args, + **kwargs, + ): + return self.bundle() + + def bundle(self, bundle_root=None): + return build_bundle( + self.project_root, + bundle_root or self.deploy_root, + [ + PathMapping(src=str(artifact)) + for artifact in self._entity_model.artifacts + ], + ) + + def get_deploy_sql( + self, + artifacts_dir: Optional[Path] = None, + schema: Optional[str] = None, + ): + entity_id = self.entity_id + if artifacts_dir: + streamlit_name = f"{schema}.{entity_id}" if schema else entity_id + return f"CREATE OR REPLACE STREAMLIT {streamlit_name} FROM '{artifacts_dir}' MAIN_FILE='{self._entity_model.main_file}';" + else: + return f"CREATE OR REPLACE STREAMLIT {entity_id} MAIN_FILE='{self._entity_model.main_file}';" + + def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None): + entity_id = self.entity_id + streamlit_name = f"{schema}.{entity_id}" if schema else entity_id + return ( + f"GRANT USAGE ON STREAMLIT {streamlit_name} TO APPLICATION ROLE {app_role};" + ) diff --git a/src/snowflake/cli/_plugins/workspace/manager.py b/src/snowflake/cli/_plugins/workspace/manager.py index 25b56d542f..10d7fef9c7 100644 --- a/src/snowflake/cli/_plugins/workspace/manager.py +++ b/src/snowflake/cli/_plugins/workspace/manager.py @@ -1,3 +1,4 @@ +from functools import cached_property from pathlib import Path from typing import Dict @@ -58,10 +59,7 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs) """ entity = self.get_entity(entity_id) if entity.supports(action): - action_ctx = ActionContext( - get_entity=self.get_entity, - ) - return entity.perform(action, action_ctx, *args, **kwargs) + return entity.perform(action, self.action_ctx, *args, **kwargs) else: raise ValueError(f'This entity type does not support "{action.value}"') @@ -69,6 +67,12 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs) def project_root(self) -> Path: return self._project_root + @cached_property + def action_ctx(self) -> ActionContext: + return ActionContext( + get_entity=self.get_entity, + ) + def _get_default_role() -> str: role = default_role() diff --git a/src/snowflake/cli/api/entities/common.py b/src/snowflake/cli/api/entities/common.py index c7bd6bfb0f..c444dc0897 100644 --- a/src/snowflake/cli/api/entities/common.py +++ b/src/snowflake/cli/api/entities/common.py @@ -63,6 +63,10 @@ def __init__(self, entity_model: T, workspace_ctx: WorkspaceContext): self._entity_model = entity_model self._workspace_ctx = workspace_ctx + @property + def entity_id(self): + return self._entity_model.entity_id + @classmethod def get_entity_model_type(cls) -> Type[T]: """ diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index cda6ecd8eb..2b0f4f5cf0 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -15,12 +15,17 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from types import UnionType +from typing import Any, Dict, List, Optional, Union, get_args, get_origin from packaging.version import Version from pydantic import Field, ValidationError, field_validator, model_validator from pydantic_core.core_schema import ValidationInfo from snowflake.cli._plugins.nativeapp.entities.application import ApplicationEntityModel +from snowflake.cli._plugins.nativeapp.entities.application_package import ( + ApplicationPackageChildrenTypes, + ApplicationPackageEntityModel, +) from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.project.schemas.entities.common import ( TargetField, @@ -159,6 +164,12 @@ def _validate_single_entity( target_object = entity.from_ target_type = target_object.get_type() cls._validate_target_field(target_key, target_type, entities) + elif entity.type == ApplicationPackageEntityModel.get_type(): + for child_entity in entity.children: + target_key = child_entity.target + cls._validate_target_field( + target_key, ApplicationPackageChildrenTypes, entities + ) @classmethod def _validate_target_field( @@ -168,11 +179,20 @@ def _validate_target_field( raise ValueError(f"No such target: {target_key}") # Validate the target type - actual_target_type = entities[target_key].__class__ - if target_type and target_type is not actual_target_type: - raise ValueError( - f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}" - ) + if target_type: + actual_target_type = entities[target_key].__class__ + if get_origin(target_type) in (Union, UnionType): + if actual_target_type not in get_args(target_type): + expected_types_str = ", ".join( + [t.__name__ for t in get_args(target_type)] + ) + raise ValueError( + f"Target type mismatch. Expected one of [{expected_types_str}], got {actual_target_type.__name__}" + ) + elif target_type is not actual_target_type: + raise ValueError( + f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}" + ) @model_validator(mode="before") @classmethod @@ -200,6 +220,7 @@ def apply_mixins(cls, data: Dict, info: ValidationInfo) -> Dict: mixin_defs=data["mixins"], ) entities[entity_name] = merged_values + return data @classmethod diff --git a/tests/nativeapp/test_application_package_entity.py b/tests/nativeapp/test_application_package_entity.py index 2a0e632a6d..0772a5ada0 100644 --- a/tests/nativeapp/test_application_package_entity.py +++ b/tests/nativeapp/test_application_package_entity.py @@ -45,8 +45,8 @@ ) -def _get_app_pkg_entity(project_directory): - with project_directory("workspaces_simple") as project_root: +def _get_app_pkg_entity(project_directory, test_dir="workspaces_simple"): + with project_directory(test_dir) as project_root: with Path(project_root / "snowflake.yml").open() as definition_file_path: project_definition = yaml.safe_load(definition_file_path) model = ApplicationPackageEntityModel( diff --git a/tests/nativeapp/test_children.py b/tests/nativeapp/test_children.py new file mode 100644 index 0000000000..fca85666e3 --- /dev/null +++ b/tests/nativeapp/test_children.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from pathlib import Path +from textwrap import dedent + +import pytest +import yaml +from snowflake.cli._plugins.nativeapp.entities.application_package import ( + ApplicationPackageEntityModel, +) +from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag +from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity +from snowflake.cli._plugins.workspace.context import ActionContext +from snowflake.cli._plugins.workspace.manager import WorkspaceManager +from snowflake.cli.api.project.errors import SchemaValidationError +from snowflake.cli.api.project.schemas.project_definition import ( + DefinitionV20, +) + +from tests.testing_utils.mock_config import mock_config_key + + +def _get_app_pkg_entity(project_directory): + with project_directory("napp_children") as project_root: + with Path(project_root / "snowflake.yml").open() as definition_file_path: + project_definition = DefinitionV20(**yaml.safe_load(definition_file_path)) + wm = WorkspaceManager( + project_definition=project_definition, + project_root=project_root, + ) + pkg_entity = wm.get_entity("pkg") + streamlit_entity = wm.get_entity("my_streamlit") + action_ctx = ActionContext( + get_entity=lambda entity_id: streamlit_entity, + ) + return ( + pkg_entity, + action_ctx, + ) + + +def test_children_feature_flag_is_disabled(): + assert FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled() == False + with pytest.raises(AttributeError) as err: + ApplicationPackageEntityModel( + **{"type": "application package", "children": [{"target": "some_child"}]} + ) + assert str(err.value) == "Application package children are not supported yet" + + +def test_invalid_children_type(): + with mock_config_key("enable_native_app_children", True): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "artifacts": [], + "children": [ + { + # packages cannot contain other packages as children + "target": "pkg2" + } + ], + }, + "pkg2": { + "type": "application package", + "artifacts": [], + }, + }, + } + with pytest.raises(SchemaValidationError) as err: + DefinitionV20(**definition_input) + assert "Target type mismatch" in str(err.value) + + +def test_invalid_children_target(): + with mock_config_key("enable_native_app_children", True): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "artifacts": [], + "children": [ + { + # no such entity + "target": "sl" + } + ], + }, + }, + } + with pytest.raises(SchemaValidationError) as err: + DefinitionV20(**definition_input) + assert "No such target: sl" in str(err.value) + + +def test_valid_children(): + with mock_config_key("enable_native_app_children", True): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "artifacts": [], + "children": [{"target": "sl"}], + }, + "sl": {"type": "streamlit", "identifier": "my_streamlit"}, + }, + } + project_definition = DefinitionV20(**definition_input) + wm = WorkspaceManager( + project_definition=project_definition, + project_root="", + ) + child_entity_id = project_definition.entities["pkg"].children[0] + child_entity = wm.get_entity(child_entity_id.target) + assert child_entity.__class__ == StreamlitEntity + + +def test_children_bundle_with_custom_dir(project_directory): + with mock_config_key("enable_native_app_children", True): + app_pkg, action_ctx = _get_app_pkg_entity(project_directory) + bundle_result = app_pkg.action_bundle(action_ctx) + deploy_root = bundle_result.deploy_root() + + # Application package artifacts + assert (deploy_root / "README.md").exists() + assert (deploy_root / "manifest.yml").exists() + assert (deploy_root / "setup_script.sql").exists() + + # Child artifacts + assert ( + deploy_root / "_entities" / "my_streamlit" / "streamlit_app.py" + ).exists() + + # Generated setup script section + with open(deploy_root / "setup_script.sql", "r") as f: + setup_script_content = f.read() + custom_dir_path = Path("_entities", "my_streamlit") + assert setup_script_content.endswith( + dedent( + f""" + -- AUTO GENERATED CHILDREN SECTION + CREATE OR REPLACE STREAMLIT v_schema.my_streamlit FROM '{custom_dir_path}' MAIN_FILE='streamlit_app.py'; + CREATE APPLICATION ROLE IF NOT EXISTS my_app_role; + GRANT USAGE ON SCHEMA v_schema TO APPLICATION ROLE my_app_role; + GRANT USAGE ON STREAMLIT v_schema.my_streamlit TO APPLICATION ROLE my_app_role; + """ + ) + ) diff --git a/tests/nativeapp/test_manager.py b/tests/nativeapp/test_manager.py index 57cafe7a07..c61467c044 100644 --- a/tests/nativeapp/test_manager.py +++ b/tests/nativeapp/test_manager.py @@ -1376,6 +1376,7 @@ def test_validate_use_scratch_stage(mock_execute, mock_deploy, temp_dir, mock_cu pd = wm._project_definition # noqa: SLF001 pkg_model: ApplicationPackageEntityModel = pd.entities["app_pkg"] mock_deploy.assert_called_with( + action_ctx=wm.action_ctx, bundle_map=None, prune=True, recursive=True, @@ -1452,6 +1453,7 @@ def test_validate_failing_drops_scratch_stage( pd = wm._project_definition # noqa: SLF001 pkg_model: ApplicationPackageEntityModel = pd.entities["app_pkg"] mock_deploy.assert_called_with( + action_ctx=wm.action_ctx, bundle_map=None, prune=True, recursive=True, @@ -1511,7 +1513,10 @@ def test_validate_raw_returns_data(mock_execute, temp_dir, mock_cursor): pkg = wm.get_entity("app_pkg") assert ( pkg.get_validation_result( - use_scratch_stage=False, interactive=False, force=True + action_ctx=wm.action_ctx, + use_scratch_stage=False, + interactive=False, + force=True, ) == failure_data ) diff --git a/tests/streamlit/test_streamlit_entity.py b/tests/streamlit/test_streamlit_entity.py new file mode 100644 index 0000000000..315e34b8e5 --- /dev/null +++ b/tests/streamlit/test_streamlit_entity.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from snowflake.cli._plugins.streamlit.streamlit_entity import ( + StreamlitEntity, +) +from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( + StreamlitEntityModel, +) +from snowflake.cli._plugins.workspace.context import WorkspaceContext +from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.project.definition_manager import DefinitionManager + +from tests.testing_utils.mock_config import mock_config_key + + +def test_cannot_instantiate_without_feature_flag(): + with pytest.raises(NotImplementedError) as err: + StreamlitEntity() + assert str(err.value) == "Streamlit entity is not implemented yet" + + +def test_nativeapp_children_interface(temp_dir): + with mock_config_key("enable_native_app_children", True): + dm = DefinitionManager() + ctx = WorkspaceContext( + console=cc, + project_root=dm.project_root, + get_default_role=lambda: "mock_role", + get_default_warehouse=lambda: "mock_warehouse", + ) + main_file = "main.py" + (Path(temp_dir) / main_file).touch() + model = StreamlitEntityModel( + type="streamlit", + main_file=main_file, + artifacts=[main_file], + ) + sl = StreamlitEntity(model, ctx) + + sl.bundle() + bundle_artifact = Path(temp_dir) / "output" / "deploy" / main_file + deploy_sql_str = sl.get_deploy_sql() + grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") + + assert bundle_artifact.exists() + assert deploy_sql_str == "CREATE OR REPLACE STREAMLIT None MAIN_FILE='main.py';" + assert ( + grant_sql_str + == "GRANT USAGE ON STREAMLIT None TO APPLICATION ROLE app_role;" + ) diff --git a/tests/test_data/projects/napp_children/app/README.md b/tests/test_data/projects/napp_children/app/README.md new file mode 100644 index 0000000000..7e59600739 --- /dev/null +++ b/tests/test_data/projects/napp_children/app/README.md @@ -0,0 +1 @@ +# README diff --git a/tests/test_data/projects/napp_children/app/manifest.yml b/tests/test_data/projects/napp_children/app/manifest.yml new file mode 100644 index 0000000000..0b8b9b892c --- /dev/null +++ b/tests/test_data/projects/napp_children/app/manifest.yml @@ -0,0 +1,7 @@ +# This is the v2 version of the napp_init_v1 project + +manifest_version: 1 + +artifacts: + setup_script: setup_script.sql + readme: README.md diff --git a/tests/test_data/projects/napp_children/app/setup_script.sql b/tests/test_data/projects/napp_children/app/setup_script.sql new file mode 100644 index 0000000000..ade6eccbd6 --- /dev/null +++ b/tests/test_data/projects/napp_children/app/setup_script.sql @@ -0,0 +1,3 @@ +CREATE OR ALTER VERSIONED SCHEMA v_schema; +CREATE APPLICATION ROLE IF NOT EXISTS my_app_role; +GRANT USAGE ON SCHEMA v_schema TO APPLICATION ROLE my_app_role; diff --git a/tests/test_data/projects/napp_children/snowflake.yml b/tests/test_data/projects/napp_children/snowflake.yml new file mode 100644 index 0000000000..52667820df --- /dev/null +++ b/tests/test_data/projects/napp_children/snowflake.yml @@ -0,0 +1,21 @@ +definition_version: 2 +entities: + pkg: + type: application package + identifier: my_pkg + artifacts: + - src: app/* + dest: ./ + children_artifacts_dir: _entities + children: + - target: my_streamlit + identifier: + schema: v_schema + ensure_usable_by: + application_roles: ["my_app_role"] + + my_streamlit: + type: streamlit + main_file: streamlit_app.py + artifacts: + - streamlit_app.py diff --git a/tests/test_data/projects/napp_children/streamlit_app.py b/tests/test_data/projects/napp_children/streamlit_app.py new file mode 100644 index 0000000000..45c8ad3822 --- /dev/null +++ b/tests/test_data/projects/napp_children/streamlit_app.py @@ -0,0 +1,20 @@ +from http.client import HTTPSConnection + +import _snowflake +import streamlit as st + + +def get_secret_value(): + return _snowflake.get_generic_secret_string("generic_secret") + + +def send_request(): + host = "docs.snowflake.com" + conn = HTTPSConnection(host) + conn.request("GET", "/") + response = conn.getresponse() + st.success(f"Response status: {response.status}") + + +st.title(f"Example streamlit app.") +st.button("Send request", on_click=send_request) From b22b5b4c3f53f32ba6ccf540bfd67cc8910a02cf Mon Sep 17 00:00:00 2001 From: Abby Shen Date: Thu, 12 Dec 2024 15:56:12 -0800 Subject: [PATCH 3/5] SNOW-1846319: Add Feature Flag for log streaming (#1925) * Add Feature Flag for log streaming * New error message * update release notes * fix comment and use new error message --- RELEASE-NOTES.md | 1 + .../cli/_plugins/spcs/services/commands.py | 11 +++- src/snowflake/cli/api/exceptions.py | 10 ++++ src/snowflake/cli/api/feature_flags.py | 1 + tests/spcs/test_services.py | 50 ++++++++++++++++++- 5 files changed, 70 insertions(+), 3 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index d2b0da52b7..1f002a8c12 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -30,6 +30,7 @@ * Fixed crashes with older x86_64 Intel CPUs. * Fixed inability to add patches to lowercase quoted versions * Fixes label being set to blank instead of None when not provided. +* Added a feature flag `ENABLE_SPCS_LOG_STREAMING` to control the rollout of the log streaming feature # v3.2.0 diff --git a/src/snowflake/cli/_plugins/spcs/services/commands.py b/src/snowflake/cli/_plugins/spcs/services/commands.py index 063a2e3e11..a5d9fdcba0 100644 --- a/src/snowflake/cli/_plugins/spcs/services/commands.py +++ b/src/snowflake/cli/_plugins/spcs/services/commands.py @@ -37,7 +37,11 @@ ) from snowflake.cli.api.commands.snow_typer import SnowTyperFactory from snowflake.cli.api.constants import ObjectType -from snowflake.cli.api.exceptions import IncompatibleParametersError +from snowflake.cli.api.exceptions import ( + FeatureNotEnabledError, + IncompatibleParametersError, +) +from snowflake.cli.api.feature_flags import FeatureFlag from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.output.types import ( CommandResult, @@ -250,6 +254,11 @@ def logs( Retrieves local logs from a service container. """ if follow: + if FeatureFlag.ENABLE_SPCS_LOG_STREAMING.is_disabled(): + raise FeatureNotEnabledError( + "ENABLE_SPCS_LOG_STREAMING", + "Streaming logs from spcs containers is disabled.", + ) if num_lines != DEFAULT_NUM_LINES: raise IncompatibleParametersError(["--follow", "--num-lines"]) if previous_logs: diff --git a/src/snowflake/cli/api/exceptions.py b/src/snowflake/cli/api/exceptions.py index fad1e97c10..2aac4e9608 100644 --- a/src/snowflake/cli/api/exceptions.py +++ b/src/snowflake/cli/api/exceptions.py @@ -229,3 +229,13 @@ def __init__(self, show_obj_query: str): super().__init__( f"Received multiple rows from result of SQL statement: {show_obj_query}. Usage of 'show_specific_object' may not be properly scoped." ) + + +class FeatureNotEnabledError(ClickException): + def __init__(self, feature_name: str, custom_message: Optional[str] = None): + base_message = f"To enable it, add '{feature_name} = true' to '[cli.features]' section of your configuration file." + if custom_message: + message = f"{custom_message} {base_message}" + else: + message = base_message + super().__init__(message) diff --git a/src/snowflake/cli/api/feature_flags.py b/src/snowflake/cli/api/feature_flags.py index d504056e02..2a56458083 100644 --- a/src/snowflake/cli/api/feature_flags.py +++ b/src/snowflake/cli/api/feature_flags.py @@ -63,3 +63,4 @@ class FeatureFlag(FeatureFlagMixin): ENABLE_STREAMLIT_VERSIONED_STAGE = BooleanFlag( "ENABLE_STREAMLIT_VERSIONED_STAGE", False ) + ENABLE_SPCS_LOG_STREAMING = BooleanFlag("ENABLE_SPCS_LOG_STREAMING", False) diff --git a/tests/spcs/test_services.py b/tests/spcs/test_services.py index 3c62648020..b48dedd1ce 100644 --- a/tests/spcs/test_services.py +++ b/tests/spcs/test_services.py @@ -605,7 +605,11 @@ def test_stream_logs_with_include_timestamps_true(mock_sleep, mock_logs): @patch("snowflake.cli._plugins.spcs.services.manager.ServiceManager.execute_query") -def test_logs_incompatible_flags(mock_execute_query, runner): +@patch( + "snowflake.cli.api.feature_flags.FeatureFlag.ENABLE_SPCS_LOG_STREAMING.is_disabled" +) +def test_logs_incompatible_flags(mock_is_disabled, mock_execute_query, runner): + mock_is_disabled.return_value = False result = runner.invoke( [ "spcs", @@ -628,7 +632,13 @@ def test_logs_incompatible_flags(mock_execute_query, runner): @patch("snowflake.cli._plugins.spcs.services.manager.ServiceManager.execute_query") -def test_logs_incompatible_flags_follow_previous_logs(mock_execute_query, runner): +@patch( + "snowflake.cli.api.feature_flags.FeatureFlag.ENABLE_SPCS_LOG_STREAMING.is_disabled" +) +def test_logs_incompatible_flags_follow_previous_logs( + mock_is_disabled, mock_execute_query, runner +): + mock_is_disabled.return_value = False result = runner.invoke( [ "spcs", @@ -653,6 +663,42 @@ def test_logs_incompatible_flags_follow_previous_logs(mock_execute_query, runner ) +@patch( + "snowflake.cli.api.feature_flags.FeatureFlag.ENABLE_SPCS_LOG_STREAMING.is_disabled" +) +def test_logs_streaming_disabled(mock_is_disabled, runner): + mock_is_disabled.return_value = True + result = runner.invoke( + [ + "spcs", + "service", + "logs", + "test_service", + "--container-name", + "test_container", + "--instance-id", + "0", + "--follow", + "--num-lines", + "100", + ] + ) + assert ( + result.exit_code != 0 + ), "Expected a non-zero exit code due to feature flag disabled" + + expected_output = ( + "+- Error ----------------------------------------------------------------------+\n" + "| Streaming logs from spcs containers is disabled. To enable it, add |\n" + "| 'ENABLE_SPCS_LOG_STREAMING = true' to '[cli.features]' section of your |\n" + "| configuration file. |\n" + "+------------------------------------------------------------------------------+\n" + ) + assert ( + result.output == expected_output + ), f"Expected formatted output not found: {result.output}" + + def test_read_yaml(other_directory): tmp_dir = Path(other_directory) spec_path = tmp_dir / "spec.yml" From 24809a8ac434e21a4e3f6fe76b8f1c176e4e827b Mon Sep 17 00:00:00 2001 From: Michel El Nacouzi Date: Fri, 13 Dec 2024 08:16:30 -0500 Subject: [PATCH 4/5] Add support for snow app run from release channel (#1951) --- RELEASE-NOTES.md | 1 + .../cli/_plugins/nativeapp/commands.py | 7 + .../nativeapp/entities/application.py | 46 ++ .../nativeapp/release_directive/commands.py | 2 +- .../cli/_plugins/nativeapp/sf_sql_facade.py | 69 ++- tests/__snapshots__/test_help_messages.ambr | 12 + tests/nativeapp/test_event_sharing.py | 18 +- tests/nativeapp/test_run_processor.py | 534 ++++++++++++++++++ tests/nativeapp/test_sf_sql_facade.py | 155 ++++- tests_integration/nativeapp/test_init_run.py | 72 +++ 10 files changed, 890 insertions(+), 26 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 1f002a8c12..f158788b7c 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -25,6 +25,7 @@ * `snow app release-directive unset` * Add support for release channels feature in native app version creation/drop. * `snow app version create` now returns version, patch, and label in JSON format. +* Add ability to specify release channel when creating application instance from release directive: `snow app run --from-release-directive --channel=` ## Fixes and improvements * Fixed crashes with older x86_64 Intel CPUs. diff --git a/src/snowflake/cli/_plugins/nativeapp/commands.py b/src/snowflake/cli/_plugins/nativeapp/commands.py index 411067ac5e..beb62bfee7 100644 --- a/src/snowflake/cli/_plugins/nativeapp/commands.py +++ b/src/snowflake/cli/_plugins/nativeapp/commands.py @@ -151,6 +151,12 @@ def app_run( The command fails if no release directive exists for your Snowflake account for a given application package, which is determined from the project definition file. Default: unset.""", is_flag=True, ), + channel: str = typer.Option( + None, + show_default=False, + help=f"""The name of the release channel to use when creating or upgrading an application instance from a release directive. + Requires the `--from-release-directive` flag to be set. If unset, the default channel will be used.""", + ), interactive: bool = InteractiveOption, force: Optional[bool] = ForceOption, validate: bool = ValidateOption, @@ -179,6 +185,7 @@ def app_run( paths=[], interactive=interactive, force=force, + release_channel=channel, ) app = ws.get_entity(app_id) return MessageResult( diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application.py b/src/snowflake/cli/_plugins/nativeapp/entities/application.py index 5057f86c60..bd2096ba1f 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application.py @@ -26,6 +26,7 @@ from snowflake.cli._plugins.nativeapp.constants import ( ALLOWED_SPECIAL_COMMENTS, COMMENT_COL, + DEFAULT_CHANNEL, OWNER_COL, ) from snowflake.cli._plugins.nativeapp.entities.application_package import ( @@ -85,6 +86,8 @@ append_test_resource_suffix, extract_schema, identifier_for_url, + identifier_in_list, + same_identifiers, to_identifier, unquote_identifier, ) @@ -329,6 +332,7 @@ def action_deploy( prune: bool, recursive: bool, paths: List[Path], + release_channel: Optional[str] = None, validate: bool = ValidateOption, stage_fqn: Optional[str] = None, interactive: bool = InteractiveOption, @@ -356,15 +360,25 @@ def action_deploy( # same-account release directive if from_release_directive: + release_channel = _get_verified_release_channel( + package_entity, release_channel + ) + self.create_or_upgrade_app( package=package_entity, stage_fqn=stage_fqn, install_method=SameAccountInstallMethod.release_directive(), + release_channel=release_channel, policy=policy, interactive=interactive, ) return + if release_channel: + raise UsageError( + f"Release channel is only supported when --from-release-directive is used." + ) + # versioned dev if version: try: @@ -603,6 +617,7 @@ def _upgrade_app( event_sharing: EventSharingHandler, policy: PolicyBase, interactive: bool, + release_channel: Optional[str] = None, ) -> list[tuple[str]] | None: self.console.step(f"Upgrading existing application object {self.name}.") @@ -613,6 +628,7 @@ def _upgrade_app( stage_fqn=stage_fqn, debug_mode=self.debug, should_authorize_event_sharing=event_sharing.should_authorize_event_sharing(), + release_channel=release_channel, role=self.role, warehouse=self.warehouse, ) @@ -627,6 +643,7 @@ def _create_app( install_method: SameAccountInstallMethod, event_sharing: EventSharingHandler, package: ApplicationPackageEntity, + release_channel: Optional[str] = None, ) -> list[tuple[str]]: self.console.step(f"Creating new application object {self.name} in account.") @@ -665,6 +682,7 @@ def _create_app( should_authorize_event_sharing=event_sharing.should_authorize_event_sharing(), role=self.role, warehouse=self.warehouse, + release_channel=release_channel, ) @span("update_app_object") @@ -675,6 +693,7 @@ def create_or_upgrade_app( install_method: SameAccountInstallMethod, policy: PolicyBase, interactive: bool, + release_channel: Optional[str] = None, ): event_sharing = EventSharingHandler( telemetry_definition=self.telemetry, @@ -699,6 +718,7 @@ def create_or_upgrade_app( event_sharing=event_sharing, policy=policy, interactive=interactive, + release_channel=release_channel, ) # 3. If no existing application found, or we performed a drop before the upgrade, we proceed to create @@ -708,6 +728,7 @@ def create_or_upgrade_app( install_method=install_method, event_sharing=event_sharing, package=package, + release_channel=release_channel, ) print_messages(self.console, create_or_upgrade_result) @@ -1004,3 +1025,28 @@ def _application_objects_to_str( def _application_object_to_str(obj: ApplicationOwnedObject) -> str: return f"({obj['type']}) {obj['name']}" + + +def _get_verified_release_channel( + package_entity: ApplicationPackageEntity, + release_channel: Optional[str], +) -> Optional[str]: + release_channel = release_channel or DEFAULT_CHANNEL + available_release_channels = get_snowflake_facade().show_release_channels( + package_entity.name, role=package_entity.role + ) + if available_release_channels: + release_channel_names = [c["name"] for c in available_release_channels] + if not identifier_in_list(release_channel, release_channel_names): + raise UsageError( + f"Release channel '{release_channel}' is not available for application package {package_entity.name}. Available release channels: ({', '.join(release_channel_names)})." + ) + else: + if same_identifiers(release_channel, DEFAULT_CHANNEL): + return None + else: + raise UsageError( + f"Release channels are not enabled for application package {package_entity.name}." + ) + + return release_channel diff --git a/src/snowflake/cli/_plugins/nativeapp/release_directive/commands.py b/src/snowflake/cli/_plugins/nativeapp/release_directive/commands.py index 2573bfdfe9..17cd351f65 100644 --- a/src/snowflake/cli/_plugins/nativeapp/release_directive/commands.py +++ b/src/snowflake/cli/_plugins/nativeapp/release_directive/commands.py @@ -140,7 +140,7 @@ def release_directive_unset( show_default=False, help="Name of the release directive", ), - channel: Optional[str] = typer.Option( + channel: str = typer.Option( DEFAULT_CHANNEL, help="Name of the release channel to use", ), diff --git a/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py b/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py index f3c164f4a8..16a8d43102 100644 --- a/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py +++ b/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py @@ -15,12 +15,15 @@ import logging from contextlib import contextmanager +from functools import cache from textwrap import dedent from typing import Any, Dict, List from snowflake.cli._plugins.connection.util import UIParameter, get_ui_parameter from snowflake.cli._plugins.nativeapp.constants import ( AUTHORIZE_TELEMETRY_COL, + CHANNEL_COL, + DEFAULT_CHANNEL, DEFAULT_DIRECTIVE, NAME_COL, SPECIAL_COMMENT, @@ -637,6 +640,7 @@ def upgrade_application( warehouse: str, debug_mode: bool | None, should_authorize_event_sharing: bool | None, + release_channel: str | None = None, ) -> list[tuple[str]]: """ Upgrades an application object using the provided clauses @@ -648,17 +652,36 @@ def upgrade_application( @param warehouse: Warehouse which is required to create an application object @param debug_mode: Whether to enable debug mode; None means not explicitly enabled or disabled @param should_authorize_event_sharing: Whether to enable event sharing; None means not explicitly enabled or disabled + @param release_channel [Optional]: Release channel to use when upgrading the application """ + + name = to_identifier(name) + release_channel = to_identifier(release_channel) if release_channel else None + install_method.ensure_app_usable( app_name=name, app_role=role, show_app_row=self.get_existing_app_info(name, role), ) + # If all the above checks are in order, proceed to upgrade + @cache # only cache within the scope of this method + def get_app_properties(): + return self.get_app_properties(name, role) + with self._use_role_optional(role), self._use_warehouse_optional(warehouse): try: using_clause = install_method.using_clause(stage_fqn) + if release_channel: + current_release_channel = get_app_properties().get( + CHANNEL_COL, DEFAULT_CHANNEL + ) + if not same_identifiers(release_channel, current_release_channel): + raise UpgradeApplicationRestrictionError( + f"Application {name} is currently on release channel {current_release_channel}. Cannot upgrade to release channel {release_channel}." + ) + upgrade_cursor = self._sql_executor.execute_query( f"alter application {name} upgrade {using_clause}", ) @@ -669,6 +692,9 @@ def upgrade_application( self._sql_executor.execute_query( f"alter application {name} set debug_mode = {debug_mode}" ) + + except UpgradeApplicationRestrictionError as err: + raise err except ProgrammingError as err: if err.errno in UPGRADE_RESTRICTION_CODES: raise UpgradeApplicationRestrictionError(err.msg) from err @@ -687,7 +713,7 @@ def upgrade_application( # Only update event sharing if the current value is different as the one we want to set if should_authorize_event_sharing is not None: current_authorize_event_sharing = ( - self.get_app_properties(name, role) + get_app_properties() .get(AUTHORIZE_TELEMETRY_COL, "false") .lower() == "true" @@ -733,6 +759,7 @@ def create_application( warehouse: str, debug_mode: bool | None, should_authorize_event_sharing: bool | None, + release_channel: str | None = None, ) -> list[tuple[str]]: """ Creates a new application object using an application package, @@ -746,7 +773,11 @@ def create_application( @param warehouse: Warehouse which is required to create an application object @param debug_mode: Whether to enable debug mode; None means not explicitly enabled or disabled @param should_authorize_event_sharing: Whether to enable event sharing; None means not explicitly enabled or disabled + @param release_channel [Optional]: Release channel to use when creating the application """ + package_name = to_identifier(package_name) + name = to_identifier(name) + release_channel = to_identifier(release_channel) if release_channel else None # by default, applications are created in debug mode when possible; # this can be overridden in the project definition @@ -761,18 +792,28 @@ def create_application( "Setting AUTHORIZE_TELEMETRY_EVENT_SHARING to %s", should_authorize_event_sharing, ) - authorize_telemetry_clause = f" AUTHORIZE_TELEMETRY_EVENT_SHARING = {str(should_authorize_event_sharing).upper()}" + authorize_telemetry_clause = f"AUTHORIZE_TELEMETRY_EVENT_SHARING = {str(should_authorize_event_sharing).upper()}" using_clause = install_method.using_clause(stage_fqn) + release_channel_clause = ( + f"using release channel {release_channel}" if release_channel else "" + ) + with self._use_role_optional(role), self._use_warehouse_optional(warehouse): try: create_cursor = self._sql_executor.execute_query( dedent( - f"""\ - create application {name} - from application package {package_name} {using_clause} {debug_mode_clause}{authorize_telemetry_clause} - comment = {SPECIAL_COMMENT} - """ + _strip_empty_lines( + f"""\ + create application {name} + from application package {package_name} + {using_clause} + {release_channel_clause} + {debug_mode_clause} + {authorize_telemetry_clause} + comment = {SPECIAL_COMMENT} + """ + ) ), ) except ProgrammingError as err: @@ -823,10 +864,10 @@ def create_application_package( dedent( _strip_empty_lines( f"""\ - create application package {package_name} - comment = {SPECIAL_COMMENT} - distribution = {distribution} - {enable_release_channels_clause} + create application package {package_name} + comment = {SPECIAL_COMMENT} + distribution = {distribution} + {enable_release_channels_clause} """ ) ) @@ -862,9 +903,9 @@ def alter_application_package_properties( self._sql_executor.execute_query( dedent( f"""\ - alter application package {package_name} - set enable_release_channels = {str(enable_release_channels).lower()} - """ + alter application package {package_name} + set enable_release_channels = {str(enable_release_channels).lower()} + """ ) ) except ProgrammingError as err: diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index f78ec957d5..e367fb918f 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -983,6 +983,18 @@ | determined from the | | project definition | | file. Default: unset. | + | --channel TEXT The name of the | + | release channel to | + | use when creating or | + | upgrading an | + | application instance | + | from a release | + | directive. Requires | + | the | + | --from-release-direc… | + | flag to be set. If | + | unset, the default | + | channel will be used. | | --interactive --no-interactive When enabled, this | | option displays | | prompts even if the | diff --git a/tests/nativeapp/test_event_sharing.py b/tests/nativeapp/test_event_sharing.py index 8da67f6c29..f6b1169a77 100644 --- a/tests/nativeapp/test_event_sharing.py +++ b/tests/nativeapp/test_event_sharing.py @@ -294,14 +294,17 @@ def _setup_mocks_for_create_app( mock.call( name=DEFAULT_APP_ID, package_name=DEFAULT_PKG_ID, - install_method=SameAccountInstallMethod.release_directive() - if is_prod - else SameAccountInstallMethod.unversioned_dev(), + install_method=( + SameAccountInstallMethod.release_directive() + if is_prod + else SameAccountInstallMethod.unversioned_dev() + ), stage_fqn=DEFAULT_STAGE_FQN, debug_mode=None, should_authorize_event_sharing=expected_authorize_telemetry_flag, role="app_role", warehouse="app_warehouse", + release_channel=None, ) ] @@ -397,14 +400,17 @@ def _setup_mocks_for_upgrade_app( mock_sql_facade_upgrade_application_expected = [ mock.call( name=DEFAULT_APP_ID, - install_method=SameAccountInstallMethod.release_directive() - if is_prod - else SameAccountInstallMethod.unversioned_dev(), + install_method=( + SameAccountInstallMethod.release_directive() + if is_prod + else SameAccountInstallMethod.unversioned_dev() + ), stage_fqn=DEFAULT_STAGE_FQN, debug_mode=None, should_authorize_event_sharing=expected_authorize_telemetry_flag, role="app_role", warehouse="app_warehouse", + release_channel=None, ) ] return [*mock_execute_query_expected, *mock_sql_facade_upgrade_application_expected] diff --git a/tests/nativeapp/test_run_processor.py b/tests/nativeapp/test_run_processor.py index 8d2bb4c545..3b478bad2e 100644 --- a/tests/nativeapp/test_run_processor.py +++ b/tests/nativeapp/test_run_processor.py @@ -84,6 +84,7 @@ SQL_FACADE_GET_EVENT_DEFINITIONS, SQL_FACADE_GET_EXISTING_APP_INFO, SQL_FACADE_GRANT_PRIVILEGES_TO_ROLE, + SQL_FACADE_SHOW_RELEASE_CHANNELS, SQL_FACADE_UPGRADE_APPLICATION, TYPER_CONFIRM, mock_execute_helper, @@ -274,6 +275,7 @@ def test_create_dev_app_w_warehouse_access_exception( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) assert mock_sql_facade_grant_privileges_to_role.mock_calls == [ mock.call( @@ -345,6 +347,7 @@ def test_create_dev_app_create_new_w_no_additional_privileges( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] mock_sql_facade_get_event_definitions.assert_called_once_with( @@ -418,6 +421,7 @@ def test_create_or_upgrade_dev_app_with_warning( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] mock_sql_facade_upgrade_application.assert_not_called() @@ -432,6 +436,7 @@ def test_create_or_upgrade_dev_app_with_warning( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -486,6 +491,7 @@ def test_create_dev_app_create_new_with_additional_privileges( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_grant_privileges_to_role.mock_calls == [ @@ -563,6 +569,7 @@ def test_create_dev_app_create_new_w_missing_warehouse_exception( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -674,6 +681,7 @@ def test_create_dev_app_incorrect_owner( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -727,6 +735,7 @@ def test_create_dev_app_no_diff_changes( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] mock_sql_facade_get_event_definitions.assert_called_once_with( @@ -783,6 +792,7 @@ def test_create_dev_app_w_diff_changes( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] mock_sql_facade_get_event_definitions.assert_called_once_with( @@ -907,6 +917,7 @@ def test_create_dev_app_create_new_quoted( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] mock_sql_facade_get_event_definitions.assert_called_once_with( @@ -964,6 +975,7 @@ def test_create_dev_app_create_new_quoted_override( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) mock_sql_facade_get_event_definitions.assert_called_once_with( '"My Application"', DEFAULT_ROLE @@ -1046,6 +1058,7 @@ def test_create_dev_app_recreate_app_when_orphaned( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_create_application.mock_calls == [ @@ -1058,6 +1071,7 @@ def test_create_dev_app_recreate_app_when_orphaned( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_grant_privileges_to_role.mock_calls == [ @@ -1185,6 +1199,7 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -1198,6 +1213,7 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -1322,6 +1338,7 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade_unknown_obje should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_create_application.mock_calls == [ @@ -1334,6 +1351,7 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade_unknown_obje should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_grant_privileges_to_role.mock_calls == [ @@ -1479,6 +1497,7 @@ def test_upgrade_app_incorrect_owner( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -1534,6 +1553,7 @@ def test_upgrade_app_succeeds( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) mock_sql_facade_get_event_definitions.assert_called_once_with( DEFAULT_APP_ID, DEFAULT_ROLE @@ -1593,6 +1613,7 @@ def test_upgrade_app_fails_generic_error( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -1674,6 +1695,7 @@ def test_upgrade_app_fails_upgrade_restriction_error( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_execute.mock_calls == expected @@ -1754,6 +1776,7 @@ def test_versioned_app_upgrade_to_unversioned( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) mock_sql_facade_create_application.assert_called_with( @@ -1765,6 +1788,7 @@ def test_versioned_app_upgrade_to_unversioned( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) assert mock_sql_facade_grant_privileges_to_role.mock_calls == [ @@ -1873,6 +1897,7 @@ def test_upgrade_app_fails_drop_fails( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -1954,6 +1979,7 @@ def test_upgrade_app_recreate_app( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_create_application.mock_calls == [ @@ -1966,6 +1992,7 @@ def test_upgrade_app_recreate_app( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_grant_privileges_to_role.mock_calls == [ @@ -2135,6 +2162,7 @@ def test_upgrade_app_recreate_app_from_version( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] assert mock_sql_facade_create_application.mock_calls == [ @@ -2147,6 +2175,7 @@ def test_upgrade_app_recreate_app_from_version( should_authorize_event_sharing=None, role=DEFAULT_ROLE, warehouse=DEFAULT_WAREHOUSE, + release_channel=None, ) ] @@ -2179,6 +2208,511 @@ def test_upgrade_app_recreate_app_from_version( ) +@mock.patch( + APP_PACKAGE_ENTITY_GET_EXISTING_VERSION_INFO, + return_value={"key": "val"}, +) +@mock.patch(SQL_FACADE_CREATE_APPLICATION) +@mock.patch(SQL_FACADE_UPGRADE_APPLICATION) +@mock.patch(SQL_FACADE_GRANT_PRIVILEGES_TO_ROLE) +@mock.patch(SQL_FACADE_GET_EVENT_DEFINITIONS) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_FACADE_GET_EXISTING_APP_INFO) +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch( + f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True +) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.NA_EVENT_SHARING_V2: False, + UIParameter.NA_ENFORCE_MANDATORY_FILTERS: False, + }, +) +@pytest.mark.parametrize("policy_param", [allow_always_policy, ask_always_policy]) +def test_run_app_from_release_directive_with_channel( + mock_param, + mock_conn, + mock_typer_confirm, + mock_show_release_channels, + mock_get_existing_app_info, + mock_execute, + mock_sql_facade_get_event_definitions, + mock_sql_facade_grant_privileges_to_role, + mock_sql_facade_upgrade_application, + mock_sql_facade_create_application, + mock_existing, + policy_param, + temp_dir, + mock_cursor, + mock_bundle_map, +): + mock_get_existing_app_info.return_value = { + "name": "myapp", + "comment": SPECIAL_COMMENT, + "owner": "app_role", + } + mock_show_release_channels.return_value = [{"name": "my_channel"}] + side_effects, expected = mock_execute_helper( + [ + ( + mock_cursor([("old_role",)], []), + mock.call("select current_role()"), + ), + (None, mock.call("use role app_role")), + (None, mock.call("drop application myapp")), + (None, mock.call("use role old_role")), + ] + ) + mock_conn.return_value = MockConnectionCtx() + mock_execute.side_effect = side_effects + mock_sql_facade_upgrade_application.side_effect = ( + UpgradeApplicationRestrictionError(DEFAULT_USER_INPUT_ERROR_MESSAGE) + ) + mock_sql_facade_create_application.side_effect = mock_cursor( + [[(DEFAULT_CREATE_SUCCESS_MESSAGE,)]], [] + ) + + setup_project_file(os.getcwd()) + + wm = _get_wm() + wm.perform_action( + "app_pkg", + EntityActions.BUNDLE, + ) + wm.perform_action( + "myapp", + EntityActions.DEPLOY, + from_release_directive=True, + prune=True, + recursive=True, + paths=[], + validate=False, + version="v1", + release_channel="my_channel", + ) + assert mock_execute.mock_calls == expected + assert mock_sql_facade_upgrade_application.mock_calls == [ + mock.call( + name=DEFAULT_APP_ID, + install_method=SameAccountInstallMethod.release_directive(), + stage_fqn=DEFAULT_STAGE_FQN, + debug_mode=True, + should_authorize_event_sharing=None, + role=DEFAULT_ROLE, + warehouse=DEFAULT_WAREHOUSE, + release_channel="my_channel", + ) + ] + assert mock_sql_facade_create_application.mock_calls == [ + mock.call( + name=DEFAULT_APP_ID, + package_name=DEFAULT_PKG_ID, + install_method=SameAccountInstallMethod.release_directive(), + stage_fqn=DEFAULT_STAGE_FQN, + debug_mode=True, + should_authorize_event_sharing=None, + role=DEFAULT_ROLE, + warehouse=DEFAULT_WAREHOUSE, + release_channel="my_channel", + ) + ] + + assert mock_sql_facade_grant_privileges_to_role.mock_calls == [ + mock.call( + privileges=["install", "develop"], + object_type=ObjectType.APPLICATION_PACKAGE, + object_identifier="app_pkg", + role_to_grant="app_role", + role_to_use="package_role", + ), + mock.call( + privileges=["usage"], + object_type=ObjectType.SCHEMA, + object_identifier="app_pkg.app_src", + role_to_grant="app_role", + role_to_use="package_role", + ), + mock.call( + privileges=["read"], + object_type=ObjectType.STAGE, + object_identifier="app_pkg.app_src.stage", + role_to_grant="app_role", + role_to_use="package_role", + ), + ] + + mock_sql_facade_get_event_definitions.assert_called_once_with( + DEFAULT_APP_ID, DEFAULT_ROLE + ) + + +@mock.patch( + APP_PACKAGE_ENTITY_GET_EXISTING_VERSION_INFO, + return_value={"key": "val"}, +) +@mock.patch(SQL_FACADE_CREATE_APPLICATION) +@mock.patch(SQL_FACADE_UPGRADE_APPLICATION) +@mock.patch(SQL_FACADE_GRANT_PRIVILEGES_TO_ROLE) +@mock.patch(SQL_FACADE_GET_EVENT_DEFINITIONS) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_FACADE_GET_EXISTING_APP_INFO) +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch( + f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True +) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.NA_EVENT_SHARING_V2: False, + UIParameter.NA_ENFORCE_MANDATORY_FILTERS: False, + }, +) +def test_run_app_from_release_directive_with_channel_but_not_from_release_directive( + mock_param, + mock_conn, + mock_typer_confirm, + mock_show_release_channels, + mock_get_existing_app_info, + mock_execute, + mock_sql_facade_get_event_definitions, + mock_sql_facade_grant_privileges_to_role, + mock_sql_facade_upgrade_application, + mock_sql_facade_create_application, + mock_existing, + temp_dir, + mock_cursor, + mock_bundle_map, +): + mock_get_existing_app_info.return_value = { + "name": "myapp", + "comment": SPECIAL_COMMENT, + "owner": "app_role", + } + mock_show_release_channels.return_value = [] + mock_conn.return_value = MockConnectionCtx() + mock_sql_facade_upgrade_application.side_effect = ( + UpgradeApplicationRestrictionError(DEFAULT_USER_INPUT_ERROR_MESSAGE) + ) + mock_sql_facade_create_application.side_effect = mock_cursor( + [[(DEFAULT_CREATE_SUCCESS_MESSAGE,)]], [] + ) + + setup_project_file(os.getcwd()) + + wm = _get_wm() + wm.perform_action( + "app_pkg", + EntityActions.BUNDLE, + ) + with pytest.raises(UsageError) as err: + wm.perform_action( + "myapp", + EntityActions.DEPLOY, + from_release_directive=False, + prune=True, + recursive=True, + paths=[], + validate=False, + version="v1", + release_channel="my_channel", + ) + + assert ( + str(err.value) + == "Release channel is only supported when --from-release-directive is used." + ) + mock_sql_facade_upgrade_application.assert_not_called() + mock_sql_facade_create_application.assert_not_called() + + +# Provide a release channel that is not in the list of available release channels -> error: +@mock.patch( + APP_PACKAGE_ENTITY_GET_EXISTING_VERSION_INFO, + return_value={"key": "val"}, +) +@mock.patch(SQL_FACADE_CREATE_APPLICATION) +@mock.patch(SQL_FACADE_UPGRADE_APPLICATION) +@mock.patch(SQL_FACADE_GRANT_PRIVILEGES_TO_ROLE) +@mock.patch(SQL_FACADE_GET_EVENT_DEFINITIONS) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_FACADE_GET_EXISTING_APP_INFO) +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch( + f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True +) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.NA_EVENT_SHARING_V2: False, + UIParameter.NA_ENFORCE_MANDATORY_FILTERS: False, + }, +) +def test_run_app_from_release_directive_with_channel_not_in_list( + mock_param, + mock_conn, + mock_typer_confirm, + mock_show_release_channels, + mock_get_existing_app_info, + mock_execute, + mock_sql_facade_get_event_definitions, + mock_sql_facade_grant_privileges_to_role, + mock_sql_facade_upgrade_application, + mock_sql_facade_create_application, + mock_existing, + temp_dir, + mock_cursor, + mock_bundle_map, +): + mock_get_existing_app_info.return_value = { + "name": "myapp", + "comment": SPECIAL_COMMENT, + "owner": "app_role", + } + mock_show_release_channels.return_value = [ + {"name": "channel1"}, + {"name": "channel2"}, + ] + mock_conn.return_value = MockConnectionCtx() + mock_sql_facade_upgrade_application.side_effect = ( + UpgradeApplicationRestrictionError(DEFAULT_USER_INPUT_ERROR_MESSAGE) + ) + mock_sql_facade_create_application.side_effect = mock_cursor( + [[(DEFAULT_CREATE_SUCCESS_MESSAGE,)]], [] + ) + + setup_project_file(os.getcwd()) + + wm = _get_wm() + wm.perform_action( + "app_pkg", + EntityActions.BUNDLE, + ) + with pytest.raises(UsageError) as err: + wm.perform_action( + "myapp", + EntityActions.DEPLOY, + from_release_directive=True, + prune=True, + recursive=True, + paths=[], + validate=False, + version="v1", + release_channel="unknown_channel", + ) + + assert ( + str(err.value) + == "Release channel 'unknown_channel' is not available for application package app_pkg. Available release channels: (channel1, channel2)." + ) + mock_sql_facade_upgrade_application.assert_not_called() + mock_sql_facade_create_application.assert_not_called() + + +@mock.patch( + APP_PACKAGE_ENTITY_GET_EXISTING_VERSION_INFO, + return_value={"key": "val"}, +) +@mock.patch(SQL_FACADE_CREATE_APPLICATION) +@mock.patch(SQL_FACADE_UPGRADE_APPLICATION) +@mock.patch(SQL_FACADE_GRANT_PRIVILEGES_TO_ROLE) +@mock.patch(SQL_FACADE_GET_EVENT_DEFINITIONS) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_FACADE_GET_EXISTING_APP_INFO) +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch( + f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True +) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.NA_EVENT_SHARING_V2: False, + UIParameter.NA_ENFORCE_MANDATORY_FILTERS: False, + }, +) +def test_run_app_from_release_directive_with_non_default_channel_but_release_channels_not_enabled( + mock_param, + mock_conn, + mock_typer_confirm, + mock_show_release_channels, + mock_get_existing_app_info, + mock_execute, + mock_sql_facade_get_event_definitions, + mock_sql_facade_grant_privileges_to_role, + mock_sql_facade_upgrade_application, + mock_sql_facade_create_application, + mock_existing, + temp_dir, + mock_cursor, + mock_bundle_map, +): + mock_get_existing_app_info.return_value = { + "name": "myapp", + "comment": SPECIAL_COMMENT, + "owner": "app_role", + } + mock_show_release_channels.return_value = [] + mock_conn.return_value = MockConnectionCtx() + mock_sql_facade_upgrade_application.side_effect = ( + UpgradeApplicationRestrictionError(DEFAULT_USER_INPUT_ERROR_MESSAGE) + ) + mock_sql_facade_create_application.side_effect = mock_cursor( + [[(DEFAULT_CREATE_SUCCESS_MESSAGE,)]], [] + ) + + setup_project_file(os.getcwd()) + + wm = _get_wm() + wm.perform_action( + "app_pkg", + EntityActions.BUNDLE, + ) + with pytest.raises(UsageError) as err: + wm.perform_action( + "myapp", + EntityActions.DEPLOY, + from_release_directive=True, + prune=True, + recursive=True, + paths=[], + validate=False, + version="v1", + release_channel="my_channel", + ) + + assert ( + str(err.value) + == "Release channels are not enabled for application package app_pkg." + ) + mock_sql_facade_upgrade_application.assert_not_called() + mock_sql_facade_create_application.assert_not_called() + + +# test with default release channel when release channels not enabled -> success: +@mock.patch( + APP_PACKAGE_ENTITY_GET_EXISTING_VERSION_INFO, + return_value={"key": "val"}, +) +@mock.patch(SQL_FACADE_CREATE_APPLICATION) +@mock.patch(SQL_FACADE_UPGRADE_APPLICATION) +@mock.patch(SQL_FACADE_GRANT_PRIVILEGES_TO_ROLE) +@mock.patch(SQL_FACADE_GET_EVENT_DEFINITIONS) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_FACADE_GET_EXISTING_APP_INFO) +@mock.patch(SQL_FACADE_SHOW_RELEASE_CHANNELS) +@mock.patch( + f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True +) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.NA_EVENT_SHARING_V2: False, + UIParameter.NA_ENFORCE_MANDATORY_FILTERS: False, + }, +) +def test_run_app_from_release_directive_with_default_channel_when_release_channels_not_enabled( + mock_param, + mock_conn, + mock_typer_confirm, + mock_show_release_channels, + mock_get_existing_app_info, + mock_execute, + mock_sql_facade_get_event_definitions, + mock_sql_facade_grant_privileges_to_role, + mock_sql_facade_upgrade_application, + mock_sql_facade_create_application, + mock_existing, + temp_dir, + mock_cursor, + mock_bundle_map, +): + mock_get_existing_app_info.return_value = { + "name": "myapp", + "comment": SPECIAL_COMMENT, + "owner": "app_role", + } + mock_show_release_channels.return_value = [] + mock_conn.return_value = MockConnectionCtx() + mock_sql_facade_upgrade_application.side_effect = ( + UpgradeApplicationRestrictionError(DEFAULT_USER_INPUT_ERROR_MESSAGE) + ) + mock_sql_facade_create_application.side_effect = mock_cursor( + [[(DEFAULT_CREATE_SUCCESS_MESSAGE,)]], [] + ) + + setup_project_file(os.getcwd()) + + wm = _get_wm() + wm.perform_action( + "app_pkg", + EntityActions.BUNDLE, + ) + wm.perform_action( + "myapp", + EntityActions.DEPLOY, + from_release_directive=True, + prune=True, + recursive=True, + paths=[], + validate=False, + version="v1", + release_channel="default", + ) + + mock_sql_facade_upgrade_application.assert_called_once_with( + name=DEFAULT_APP_ID, + install_method=SameAccountInstallMethod.release_directive(), + stage_fqn=DEFAULT_STAGE_FQN, + debug_mode=True, + should_authorize_event_sharing=None, + role=DEFAULT_ROLE, + warehouse=DEFAULT_WAREHOUSE, + release_channel=None, + ) + mock_sql_facade_create_application.assert_called_once_with( + name=DEFAULT_APP_ID, + package_name=DEFAULT_PKG_ID, + install_method=SameAccountInstallMethod.release_directive(), + stage_fqn=DEFAULT_STAGE_FQN, + debug_mode=True, + should_authorize_event_sharing=None, + role=DEFAULT_ROLE, + warehouse=DEFAULT_WAREHOUSE, + release_channel=None, + ) + + mock_sql_facade_grant_privileges_to_role.assert_has_calls( + [ + mock.call( + privileges=["install", "develop"], + object_type=ObjectType.APPLICATION_PACKAGE, + object_identifier="app_pkg", + role_to_grant="app_role", + role_to_use="package_role", + ), + mock.call( + privileges=["usage"], + object_type=ObjectType.SCHEMA, + object_identifier="app_pkg.app_src", + role_to_grant="app_role", + role_to_use="package_role", + ), + mock.call( + privileges=["read"], + object_type=ObjectType.STAGE, + object_identifier="app_pkg.app_src.stage", + role_to_grant="app_role", + role_to_use="package_role", + ), + ] + ) + + # Test get_existing_version_info returns version info correctly @mock.patch(SQL_EXECUTOR_EXECUTE) def test_get_existing_version_info( diff --git a/tests/nativeapp/test_sf_sql_facade.py b/tests/nativeapp/test_sf_sql_facade.py index 8d8107249c..bcdf4cd0e8 100644 --- a/tests/nativeapp/test_sf_sql_facade.py +++ b/tests/nativeapp/test_sf_sql_facade.py @@ -20,6 +20,7 @@ from snowflake.cli._plugins.connection.util import UIParameter from snowflake.cli._plugins.nativeapp.constants import ( AUTHORIZE_TELEMETRY_COL, + CHANNEL_COL, COMMENT_COL, NAME_COL, SPECIAL_COMMENT, @@ -34,6 +35,7 @@ InvalidSQLError, UnknownConnectorError, UnknownSQLError, + UpgradeApplicationRestrictionError, UserInputError, UserScriptError, ) @@ -2063,6 +2065,91 @@ def test_upgrade_application_converts_unexpected_programmingerrors_to_unclassifi assert_programmingerror_cause_with_errno(err, SQL_COMPILATION_ERROR) +def test_upgrade_application_with_release_channel_same_as_app_properties( + mock_get_app_properties, + mock_get_existing_app_info, + mock_use_warehouse, + mock_use_role, + mock_execute_query, + mock_cursor, +): + app_name = "test_app" + stage_fqn = "app_pkg.app_src.stage" + role = "test_role" + warehouse = "test_warehouse" + release_channel = "test_channel" + mock_get_app_properties.return_value = { + COMMENT_COL: SPECIAL_COMMENT, + AUTHORIZE_TELEMETRY_COL: "true", + CHANNEL_COL: release_channel, + } + + side_effects, expected = mock_execute_helper( + [ + ( + mock_cursor([], []), + mock.call(f"alter application {app_name} upgrade "), + ) + ] + ) + mock_execute_query.side_effect = side_effects + + expected_use_objects = [ + (mock_use_role, mock.call(role)), + (mock_use_warehouse, mock.call(warehouse)), + ] + expected_execute_query = [(mock_execute_query, call) for call in expected] + + with assert_in_context(expected_use_objects, expected_execute_query): + sql_facade.upgrade_application( + name=app_name, + install_method=SameAccountInstallMethod.release_directive(), + stage_fqn=stage_fqn, + debug_mode=True, + should_authorize_event_sharing=True, + role=role, + warehouse=warehouse, + release_channel=release_channel, + ) + + +def test_upgrade_application_with_release_channel_not_same_as_app_properties_then_upgrade_error( + mock_get_app_properties, + mock_get_existing_app_info, + mock_use_warehouse, + mock_use_role, + mock_execute_query, + mock_cursor, +): + app_name = "test_app" + stage_fqn = "app_pkg.app_src.stage" + role = "test_role" + warehouse = "test_warehouse" + release_channel = "test_channel" + mock_get_app_properties.return_value = { + COMMENT_COL: SPECIAL_COMMENT, + AUTHORIZE_TELEMETRY_COL: "true", + CHANNEL_COL: "different_channel", + } + + with pytest.raises(UpgradeApplicationRestrictionError) as err: + sql_facade.upgrade_application( + name=app_name, + install_method=SameAccountInstallMethod.release_directive(), + stage_fqn=stage_fqn, + debug_mode=True, + should_authorize_event_sharing=True, + role=role, + warehouse=warehouse, + release_channel=release_channel, + ) + + assert ( + str(err.value) + == f"Application {app_name} is currently on release channel different_channel. Cannot upgrade to release channel {release_channel}." + ) + + def test_create_application_with_minimal_clauses( mock_use_warehouse, mock_use_role, @@ -2083,7 +2170,7 @@ def test_create_application_with_minimal_clauses( dedent( f"""\ create application {app_name} - from application package {pkg_name} + from application package {pkg_name} comment = {SPECIAL_COMMENT} """ ) @@ -2132,7 +2219,10 @@ def test_create_application_with_all_clauses( dedent( f"""\ create application {app_name} - from application package {pkg_name} using @{stage_fqn} debug_mode = True AUTHORIZE_TELEMETRY_EVENT_SHARING = TRUE + from application package {pkg_name} + using @{stage_fqn} + debug_mode = True + AUTHORIZE_TELEMETRY_EVENT_SHARING = TRUE comment = {SPECIAL_COMMENT} """ ) @@ -2182,7 +2272,7 @@ def test_create_application_converts_expected_programmingerrors_to_user_errors( dedent( f"""\ create application {app_name} - from application package {pkg_name} + from application package {pkg_name} comment = {SPECIAL_COMMENT} """ ) @@ -2241,7 +2331,10 @@ def test_create_application_special_message_for_event_sharing_error( dedent( f"""\ create application {app_name} - from application package {pkg_name} using version "3" patch 1 debug_mode = False AUTHORIZE_TELEMETRY_EVENT_SHARING = FALSE + from application package {pkg_name} + using version "3" patch 1 + debug_mode = False + AUTHORIZE_TELEMETRY_EVENT_SHARING = FALSE comment = {SPECIAL_COMMENT} """ ) @@ -2299,7 +2392,7 @@ def test_create_application_converts_unexpected_programmingerrors_to_unclassifie dedent( f"""\ create application {app_name} - from application package {pkg_name} + from application package {pkg_name} comment = {SPECIAL_COMMENT} """ ) @@ -2333,6 +2426,58 @@ def test_create_application_converts_unexpected_programmingerrors_to_unclassifie assert_programmingerror_cause_with_errno(err, SQL_COMPILATION_ERROR) +def test_create_application_with_release_channel( + mock_use_warehouse, + mock_use_role, + mock_execute_query, + mock_cursor, +): + app_name = "test_app" + pkg_name = "test_pkg" + stage_fqn = "app_pkg.app_src.stage" + role = "test_role" + warehouse = "test_warehouse" + release_channel = "test_channel" + + side_effects, expected = mock_execute_helper( + [ + ( + mock_cursor([], []), + mock.call( + dedent( + f"""\ + create application {app_name} + from application package {pkg_name} + using release channel {release_channel} + comment = {SPECIAL_COMMENT} + """ + ) + ), + ) + ] + ) + mock_execute_query.side_effect = side_effects + + expected_use_objects = [ + (mock_use_role, mock.call(role)), + (mock_use_warehouse, mock.call(warehouse)), + ] + expected_execute_query = [(mock_execute_query, call) for call in expected] + + with assert_in_context(expected_use_objects, expected_execute_query): + sql_facade.create_application( + name=app_name, + package_name=pkg_name, + install_method=SameAccountInstallMethod.release_directive(), + stage_fqn=stage_fqn, + debug_mode=None, + should_authorize_event_sharing=None, + role=role, + warehouse=warehouse, + release_channel=release_channel, + ) + + @pytest.mark.parametrize( "pkg_name, sanitized_pkg_name", [("test_pkg", "test_pkg"), ("test.pkg", '"test.pkg"')], diff --git a/tests_integration/nativeapp/test_init_run.py b/tests_integration/nativeapp/test_init_run.py index 57c9488d6d..9a96656964 100644 --- a/tests_integration/nativeapp/test_init_run.py +++ b/tests_integration/nativeapp/test_init_run.py @@ -494,3 +494,75 @@ def test_nativeapp_force_cross_upgrade( assert result.exit_code == 0 if is_cross_upgrade: assert f"Dropping application object {app_name}." in result.output + + +@pytest.mark.integration +@pytest.mark.parametrize( + "test_project", + [ + "napp_init_v2", + ], +) +def test_nativeapp_upgrade_from_release_directive_and_default_channel( + test_project, + nativeapp_project_directory, + runner, +): + + with nativeapp_project_directory(test_project): + # Create version + result = runner.invoke_with_connection(["app", "version", "create", "v1"]) + assert result.exit_code == 0 + + # Set default release directive + result = runner.invoke_with_connection( + ["app", "release-directive", "set", "default", "--version=v1", "--patch=0"] + ) + assert result.exit_code == 0 + + # Initial create + result = runner.invoke_with_connection(["app", "run"]) + assert result.exit_code == 0 + + # (Cross-)upgrade + result = runner.invoke_with_connection( + [ + "app", + "run", + "--from-release-directive", + "--channel", + "default", + "--force", + ] + ) + assert result.exit_code == 0 + + +@pytest.mark.integration +@pytest.mark.parametrize( + "test_project", + [ + "napp_init_v2", + ], +) +def test_nativeapp_create_from_release_directive_and_default_channel( + test_project, + nativeapp_project_directory, + runner, +): + with nativeapp_project_directory(test_project): + # Create version + result = runner.invoke_with_connection(["app", "version", "create", "v1"]) + assert result.exit_code == 0 + + # Set default release directive + result = runner.invoke_with_connection( + ["app", "release-directive", "set", "default", "--version=v1", "--patch=0"] + ) + assert result.exit_code == 0 + + # Initial create + result = runner.invoke_with_connection( + ["app", "run", "--from-release-directive", "--channel", "default"] + ) + assert result.exit_code == 0 From dae4e10f3eb0e896dbb4eacef0d0774fea410719 Mon Sep 17 00:00:00 2001 From: Adam Stus Date: Mon, 16 Dec 2024 16:04:11 +0100 Subject: [PATCH 5/5] Moved default streamlit warehouse from model (#1952) --- RELEASE-NOTES.md | 1 + .../cli/_plugins/streamlit/manager.py | 9 ++++- .../project/schemas/v1/streamlit/streamlit.py | 2 +- tests/api/utils/test_definition_rendering.py | 1 - tests/streamlit/test_streamlit_manager.py | 38 +++++++++++++++++++ 5 files changed, 48 insertions(+), 3 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index f158788b7c..81bd32d667 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -17,6 +17,7 @@ ## Backward incompatibility ## Deprecations +* Added deprecation message for default Streamlit warehouse ## New additions * Add Release Directives support by introducing the following commands: diff --git a/src/snowflake/cli/_plugins/streamlit/manager.py b/src/snowflake/cli/_plugins/streamlit/manager.py index 3eabd528d9..8d51c9f240 100644 --- a/src/snowflake/cli/_plugins/streamlit/manager.py +++ b/src/snowflake/cli/_plugins/streamlit/manager.py @@ -104,8 +104,15 @@ def _create_streamlit( query.append(f"MAIN_FILE = '{streamlit.main_file}'") if streamlit.imports: query.append(streamlit.get_imports_sql()) - if streamlit.query_warehouse: + + if not streamlit.query_warehouse: + cli_console.warning( + "[Deprecation] In next major version we will remove default query_warehouse='streamlit'." + ) + query.append(f"QUERY_WAREHOUSE = 'streamlit'") + else: query.append(f"QUERY_WAREHOUSE = {streamlit.query_warehouse}") + if streamlit.title: query.append(f"TITLE = '{streamlit.title}'") diff --git a/src/snowflake/cli/api/project/schemas/v1/streamlit/streamlit.py b/src/snowflake/cli/api/project/schemas/v1/streamlit/streamlit.py index c7a454d077..c77ce9e00d 100644 --- a/src/snowflake/cli/api/project/schemas/v1/streamlit/streamlit.py +++ b/src/snowflake/cli/api/project/schemas/v1/streamlit/streamlit.py @@ -27,7 +27,7 @@ class Streamlit(UpdatableModel, ObjectIdentifierModel(object_name="Streamlit")): title="Stage in which the app’s artifacts will be stored", default="streamlit" ) query_warehouse: str = Field( - title="Snowflake warehouse to host the app", default="streamlit" + title="Snowflake warehouse to host the app", default=None ) main_file: Optional[Path] = Field( title="Entrypoint file of the Streamlit app", default="streamlit_app.py" diff --git a/tests/api/utils/test_definition_rendering.py b/tests/api/utils/test_definition_rendering.py index f25f47de7a..4cb3225b80 100644 --- a/tests/api/utils/test_definition_rendering.py +++ b/tests/api/utils/test_definition_rendering.py @@ -196,7 +196,6 @@ def test_resolve_variables_in_project_cross_project_dependencies(): "streamlit": { "name": "my_app", "main_file": "streamlit_app.py", - "query_warehouse": "streamlit", "stage": "streamlit", }, "env": ProjectEnvironment( diff --git a/tests/streamlit/test_streamlit_manager.py b/tests/streamlit/test_streamlit_manager.py index 4f04fe854c..0fcfad4ec0 100644 --- a/tests/streamlit/test_streamlit_manager.py +++ b/tests/streamlit/test_streamlit_manager.py @@ -134,6 +134,44 @@ def test_deploy_streamlit_with_comment( ) +@mock.patch("snowflake.cli._plugins.streamlit.manager.StageManager") +@mock.patch("snowflake.cli._plugins.streamlit.manager.StreamlitManager.get_url") +@mock.patch("snowflake.cli._plugins.streamlit.manager.StreamlitManager.execute_query") +@mock_streamlit_exists +def test_deploy_streamlit_with_default_warehouse( + mock_execute_query, _, mock_stage_manager, temp_dir +): + mock_stage_manager().get_standard_stage_prefix.return_value = "stage_root" + + main_file = Path(temp_dir) / "main.py" + main_file.touch() + + st = StreamlitEntityModel( + type="streamlit", + identifier="my_streamlit_app", + title="MyStreamlit", + main_file=str(main_file), + artifacts=[main_file], + comment="This is a test comment", + ) + + StreamlitManager(MagicMock(database="DB", schema="SH")).deploy( + streamlit=st, replace=False + ) + + mock_execute_query.assert_called_once_with( + dedent( + f"""\ + CREATE STREAMLIT IDENTIFIER('DB.SH.my_streamlit_app') + ROOT_LOCATION = 'stage_root' + MAIN_FILE = '{main_file}' + QUERY_WAREHOUSE = 'streamlit' + TITLE = 'MyStreamlit' + COMMENT = 'This is a test comment'""" + ) + ) + + @mock.patch("snowflake.cli._plugins.streamlit.manager.StreamlitManager.execute_query") @mock_streamlit_exists def test_execute_streamlit(mock_execute_query):