diff --git a/src/snowflake/cli/_plugins/nativeapp/manager.py b/src/snowflake/cli/_plugins/nativeapp/manager.py index 0bc7ccf604..ad68a6b767 100644 --- a/src/snowflake/cli/_plugins/nativeapp/manager.py +++ b/src/snowflake/cli/_plugins/nativeapp/manager.py @@ -15,7 +15,6 @@ from __future__ import annotations import json -import os import time from abc import ABC, abstractmethod from contextlib import contextmanager @@ -23,146 +22,58 @@ from functools import cached_property from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Dict, Generator, List, NoReturn, Optional, TypedDict +from typing import Generator, List, Optional, TypedDict -import jinja2 from click import ClickException from snowflake.cli._plugins.connection.util import make_snowsight_url from snowflake.cli._plugins.nativeapp.artifacts import ( BundleMap, build_bundle, - resolve_without_follow, ) from snowflake.cli._plugins.nativeapp.codegen.compiler import ( NativeAppCompiler, ) from snowflake.cli._plugins.nativeapp.constants import ( - ALLOWED_SPECIAL_COMMENTS, - COMMENT_COL, - INTERNAL_DISTRIBUTION, NAME_COL, - OWNER_COL, - SPECIAL_COMMENT, ) from snowflake.cli._plugins.nativeapp.exceptions import ( - ApplicationPackageAlreadyExistsError, ApplicationPackageDoesNotExistError, - InvalidScriptError, - MissingScriptError, NoEventTableForAccount, SetupScriptFailedValidation, - UnexpectedOwnerError, ) from snowflake.cli._plugins.nativeapp.project_model import ( NativeAppProjectModel, ) -from snowflake.cli._plugins.nativeapp.utils import verify_exists, verify_no_directories from snowflake.cli._plugins.stage.diff import ( DiffResult, - StagePath, - compute_stage_diff, - preserve_from_diff, - sync_local_diff_with_stage, - to_stage_path, ) from snowflake.cli._plugins.stage.manager import StageManager -from snowflake.cli._plugins.stage.utils import print_diff_to_console from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.entities.application_package_entity import ( + ApplicationPackageEntity, +) +from snowflake.cli.api.entities.utils import ( + execute_post_deploy_hooks, + generic_sql_error_handler, + sync_deploy_root_with_stage, +) from snowflake.cli.api.errno import ( - DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, DOES_NOT_EXIST_OR_NOT_AUTHORIZED, - NO_WAREHOUSE_SELECTED_IN_SESSION, ) from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError -from snowflake.cli.api.project.schemas.native_app.application import ( - PostDeployHook, -) +from snowflake.cli.api.project.schemas.entities.common import PostDeployHook from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.project.util import ( identifier_for_url, unquote_identifier, ) -from snowflake.cli.api.rendering.jinja import ( - jinja_render_from_str, -) -from snowflake.cli.api.rendering.sql_templates import ( - snowflake_sql_jinja_render, -) -from snowflake.cli.api.secure_path import UNLIMITED, SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.connector import DictCursor, ProgrammingError ApplicationOwnedObject = TypedDict("ApplicationOwnedObject", {"name": str, "type": str}) -def generic_sql_error_handler( - err: ProgrammingError, role: Optional[str] = None, warehouse: Optional[str] = None -) -> NoReturn: - # Potential refactor: If moving away from Python 3.8 and 3.9 to >= 3.10, use match ... case - if err.errno == DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED: - raise ProgrammingError( - msg=dedent( - f"""\ - Received error message '{err.msg}' while executing SQL statement. - '{role}' may not have access to warehouse '{warehouse}'. - Please grant usage privilege on warehouse to this role. - """ - ), - errno=err.errno, - ) - elif err.errno == NO_WAREHOUSE_SELECTED_IN_SESSION: - raise ProgrammingError( - msg=dedent( - f"""\ - Received error message '{err.msg}' while executing SQL statement. - Please provide a warehouse for the active session role in your project definition file, config.toml file, or via command line. - """ - ), - errno=err.errno, - ) - elif "does not exist or not authorized" in err.msg: - raise ProgrammingError( - msg=dedent( - f"""\ - Received error message '{err.msg}' while executing SQL statement. - Please check the name of the resource you are trying to query or the permissions of the role you are using to run the query. - """ - ) - ) - raise err - - -def ensure_correct_owner(row: dict, role: str, obj_name: str) -> None: - """ - Check if an object has the right owner role - """ - actual_owner = row[ - OWNER_COL - ].upper() # Because unquote_identifier() always returns uppercase str - if actual_owner != unquote_identifier(role): - raise UnexpectedOwnerError(obj_name, role, actual_owner) - - -def _get_stage_paths_to_sync( - local_paths_to_sync: List[Path], deploy_root: Path -) -> List[StagePath]: - """ - Takes a list of paths (files and directories), returning a list of all files recursively relative to the deploy root. - """ - - stage_paths = [] - for path in local_paths_to_sync: - if path.is_dir(): - for current_dir, _dirs, files in os.walk(path): - for file in files: - deploy_path = Path(current_dir, file).relative_to(deploy_root) - stage_paths.append(to_stage_path(deploy_path)) - else: - stage_paths.append(to_stage_path(path.relative_to(deploy_root))) - return stage_paths - - class NativeAppCommandProcessor(ABC): @abstractmethod def process(self, *args, **kwargs): @@ -229,20 +140,10 @@ def stage_schema(self) -> Optional[str]: def package_warehouse(self) -> Optional[str]: return self.na_project.package_warehouse - @contextmanager def use_package_warehouse(self): - if self.package_warehouse: - with self.use_warehouse(self.package_warehouse): - yield - else: - raise ClickException( - dedent( - f"""\ - Application package warehouse cannot be empty. - Please provide a value for it in your connection information or your project definition file. - """ - ) - ) + return ApplicationPackageEntity.use_package_warehouse( + self.package_warehouse, + ) @property def application_warehouse(self) -> Optional[str]: @@ -301,30 +202,8 @@ def debug_mode(self) -> bool: @cached_property def get_app_pkg_distribution_in_snowflake(self) -> str: - """ - Returns the 'distribution' attribute of a 'describe application package' SQL query, in lowercase. - """ - with self.use_role(self.package_role): - try: - desc_cursor = self._execute_query( - f"describe application package {self.package_name}" - ) - except ProgrammingError as err: - generic_sql_error_handler(err) - - if desc_cursor.rowcount is None or desc_cursor.rowcount == 0: - raise SnowflakeSQLExecutionError() - else: - for row in desc_cursor: - if row[0].lower() == "distribution": - return row[1].lower() - raise ProgrammingError( - msg=dedent( - f"""\ - Could not find the 'distribution' attribute for application package {self.package_name} in the output of SQL query: - 'describe application package {self.package_name}' - """ - ) + return ApplicationPackageEntity.get_app_pkg_distribution_in_snowflake( + self.package_name, self.package_role ) @cached_property @@ -336,27 +215,13 @@ def account_event_table(self) -> str: def verify_project_distribution( self, expected_distribution: Optional[str] = None ) -> bool: - """ - Returns true if the 'distribution' attribute of an existing application package in snowflake - is the same as the the attribute specified in project definition file. - """ - actual_distribution = ( - expected_distribution - if expected_distribution - else self.get_app_pkg_distribution_in_snowflake + return ApplicationPackageEntity.verify_project_distribution( + console=cc, + package_name=self.package_name, + package_role=self.package_role, + package_distribution=self.package_distribution, + expected_distribution=expected_distribution, ) - project_def_distribution = self.package_distribution.lower() - if actual_distribution != project_def_distribution: - cc.warning( - dedent( - f"""\ - Application package {self.package_name} in your Snowflake account has distribution property {actual_distribution}, - which does not match the value specified in project definition file: {project_def_distribution}. - """ - ) - ) - return False - return True def build_bundle(self) -> BundleMap: """ @@ -377,110 +242,19 @@ def sync_deploy_root_with_stage( local_paths_to_sync: List[Path] | None = None, print_diff: bool = True, ) -> DiffResult: - """ - Ensures that the files on our remote stage match the artifacts we have in - the local filesystem. - - Args: - bundle_map (BundleMap): The artifact mapping computed by the `build_bundle` function. - role (str): The name of the role to use for queries and commands. - prune (bool): Whether to prune artifacts from the stage that don't exist locally. - recursive (bool): Whether to traverse directories recursively. - stage_fqn (str): The name of the stage to diff against and upload to. - local_paths_to_sync (List[Path], optional): List of local paths to sync. Defaults to None to sync all - local paths. Note that providing an empty list here is equivalent to None. - print_diff (bool): Whether to print the diff between the local files and the remote stage. Defaults to True - - Returns: - A `DiffResult` instance describing the changes that were performed. - """ - - # Does a stage already exist within the application package, or we need to create one? - # Using "if not exists" should take care of either case. - cc.step( - f"Checking if stage {stage_fqn} exists, or creating a new one if none exists." + return sync_deploy_root_with_stage( + console=cc, + deploy_root=self.deploy_root, + package_name=self.package_name, + stage_schema=self.stage_schema, + bundle_map=bundle_map, + role=role, + prune=prune, + recursive=recursive, + stage_fqn=stage_fqn, + local_paths_to_sync=local_paths_to_sync, + print_diff=print_diff, ) - with self.use_role(role): - self._execute_query( - f"create schema if not exists {self.package_name}.{self.stage_schema}" - ) - self._execute_query( - f""" - create stage if not exists {stage_fqn} - encryption = (TYPE = 'SNOWFLAKE_SSE') - DIRECTORY = (ENABLE = TRUE)""" - ) - - # Perform a diff operation and display results to the user for informational purposes - if print_diff: - cc.step( - "Performing a diff between the Snowflake stage and your local deploy_root ('%s') directory." - % self.deploy_root.resolve() - ) - diff: DiffResult = compute_stage_diff(self.deploy_root, stage_fqn) - - if local_paths_to_sync: - # Deploying specific files/directories - resolved_paths_to_sync = [ - resolve_without_follow(p) for p in local_paths_to_sync - ] - if not recursive: - verify_no_directories(resolved_paths_to_sync) - - deploy_paths_to_sync = [] - for resolved_path in resolved_paths_to_sync: - verify_exists(resolved_path) - deploy_paths = bundle_map.to_deploy_paths(resolved_path) - if not deploy_paths: - if resolved_path.is_dir() and recursive: - # No direct artifact mapping found for this path. Check to see - # if there are subpaths of this directory that are matches. We - # loop over sources because it's likely a much smaller list - # than the project directory. - for src in bundle_map.all_sources(absolute=True): - if resolved_path in src.parents: - # There is a source that contains this path, get its dest path(s) - deploy_paths.extend(bundle_map.to_deploy_paths(src)) - - if not deploy_paths: - raise ClickException(f"No artifact found for {resolved_path}") - deploy_paths_to_sync.extend(deploy_paths) - - stage_paths_to_sync = _get_stage_paths_to_sync( - deploy_paths_to_sync, resolve_without_follow(self.deploy_root) - ) - diff = preserve_from_diff(diff, stage_paths_to_sync) - else: - # Full deploy - if not recursive: - verify_no_directories(self.deploy_root.resolve().iterdir()) - - if not prune: - files_not_removed = [str(path) for path in diff.only_on_stage] - diff.only_on_stage = [] - - if len(files_not_removed) > 0: - files_not_removed_str = "\n".join(files_not_removed) - cc.warning( - f"The following files exist only on the stage:\n{files_not_removed_str}\n\nUse the --prune flag to delete them from the stage." - ) - - if print_diff: - print_diff_to_console(diff, bundle_map) - - # Upload diff-ed files to application package stage - if diff.has_changes(): - cc.step( - "Updating the Snowflake stage from your local %s directory." - % self.deploy_root.resolve(), - ) - sync_local_diff_with_stage( - role=role, - deploy_root_path=self.deploy_root, - diff_result=diff, - stage_fqn=stage_fqn, - ) - return diff def get_existing_app_info(self) -> Optional[dict]: """ @@ -493,15 +267,10 @@ def get_existing_app_info(self) -> Optional[dict]: ) def get_existing_app_pkg_info(self) -> Optional[dict]: - """ - Check for an existing application package by the same name as in project definition, in account. - It executes a 'show application packages like' query and returns the result as single row, if one exists. - """ - - with self.use_role(self.package_role): - return self.show_specific_object( - "application packages", self.package_name, name_col=NAME_COL - ) + return ApplicationPackageEntity.get_existing_app_pkg_info( + package_name=self.package_name, + package_role=self.package_role, + ) def get_objects_owned_by_application(self) -> List[ApplicationOwnedObject]: """ @@ -537,170 +306,39 @@ def get_snowsight_url(self) -> str: return make_snowsight_url(self._conn, f"/#/apps/application/{name}") def create_app_package(self) -> None: - """ - Creates the application package with our up-to-date stage if none exists. - """ - - # 1. Check for existing existing application package - show_obj_row = self.get_existing_app_pkg_info() - - if show_obj_row: - # 1. Check for the right owner role - ensure_correct_owner( - row=show_obj_row, role=self.package_role, obj_name=self.package_name - ) - - # 2. Check distribution of the existing application package - actual_distribution = self.get_app_pkg_distribution_in_snowflake - if not self.verify_project_distribution(actual_distribution): - cc.warning( - f"Continuing to execute `snow app run` on application package {self.package_name} with distribution '{actual_distribution}'." - ) - - # 3. If actual_distribution is external, skip comment check - if actual_distribution == INTERNAL_DISTRIBUTION: - row_comment = show_obj_row[COMMENT_COL] - - if row_comment not in ALLOWED_SPECIAL_COMMENTS: - raise ApplicationPackageAlreadyExistsError(self.package_name) - - return - - # If no application package pre-exists, create an application package, with the specified distribution in the project definition file. - with self.use_role(self.package_role): - cc.step(f"Creating new application package {self.package_name} in account.") - self._execute_query( - dedent( - f"""\ - create application package {self.package_name} - comment = {SPECIAL_COMMENT} - distribution = {self.package_distribution} - """ - ) - ) - - def _render_script_templates( - self, - render_from_str: Callable[[str, Dict[str, Any]], str], - jinja_context: dict[str, Any], - scripts: List[str], - ) -> List[str]: - """ - Input: - - render_from_str: function which renders a jinja template from a string and jinja context - - jinja_context: a dictionary with the jinja context - - scripts: list of script paths relative to the project root - Returns: - - List of rendered scripts content - Size of the return list is the same as the size of the input scripts list. - """ - scripts_contents = [] - for relpath in scripts: - script_full_path = SecurePath(self.project_root) / relpath - try: - template_content = script_full_path.read_text( - file_size_limit_mb=UNLIMITED - ) - result = render_from_str(template_content, jinja_context) - scripts_contents.append(result) - - except FileNotFoundError as e: - raise MissingScriptError(relpath) from e - - except jinja2.TemplateSyntaxError as e: - raise InvalidScriptError(relpath, e, e.lineno) from e - - except jinja2.UndefinedError as e: - raise InvalidScriptError(relpath, e) from e - - return scripts_contents + return ApplicationPackageEntity.create_app_package( + console=cc, + package_name=self.package_name, + package_role=self.package_role, + package_distribution=self.package_distribution, + ) def _apply_package_scripts(self) -> None: - """ - Assuming the application package exists and we are using the correct role, - applies all package scripts in-order to the application package. - """ - - if self.package_scripts: - cc.warning( - "WARNING: native_app.package.scripts is deprecated. Please migrate to using native_app.package.post_deploy." - ) - - queued_queries = self._render_script_templates( - jinja_render_from_str, - dict(package_name=self.package_name), - self.package_scripts, + return ApplicationPackageEntity.apply_package_scripts( + console=cc, + package_scripts=self.package_scripts, + package_warehouse=self.package_warehouse, + project_root=self.project_root, + package_role=self.package_role, + package_name=self.package_name, ) - # once we're sure all the templates expanded correctly, execute all of them - with self.use_package_warehouse(): - try: - for i, queries in enumerate(queued_queries): - cc.step(f"Applying package script: {self.package_scripts[i]}") - self._execute_queries(queries) - except ProgrammingError as err: - generic_sql_error_handler( - err, role=self.package_role, warehouse=self.package_warehouse - ) - - def _execute_sql_script( - self, script_content: str, database_name: Optional[str] = None - ) -> None: - """ - Executing the provided SQL script content. - This assumes that a relevant warehouse is already active. - If database_name is passed in, it will be used first. - """ - try: - if database_name is not None: - self._execute_query(f"use database {database_name}") - - self._execute_queries(script_content) - except ProgrammingError as err: - generic_sql_error_handler(err) - - def _execute_post_deploy_hooks( - self, - post_deploy_hooks: Optional[List[PostDeployHook]], - deployed_object_type: str, - database_name: str, - ) -> None: - """ - Executes post-deploy hooks for the given object type. - While executing SQL post deploy hooks, it first switches to the database provided in the input. - All post deploy scripts templates will first be expanded using the global template context. - """ - if not post_deploy_hooks: - return - - with cc.phase(f"Executing {deployed_object_type} post-deploy actions"): - sql_scripts_paths = [] - for hook in post_deploy_hooks: - if hook.sql_script: - sql_scripts_paths.append(hook.sql_script) - else: - raise ValueError( - f"Unsupported {deployed_object_type} post-deploy hook type: {hook}" - ) - - scripts_content_list = self._render_script_templates( - snowflake_sql_jinja_render, - {}, - sql_scripts_paths, - ) - - for index, sql_script_path in enumerate(sql_scripts_paths): - cc.step(f"Executing SQL script: {sql_script_path}") - self._execute_sql_script(scripts_content_list[index], database_name) - def execute_package_post_deploy_hooks(self) -> None: - self._execute_post_deploy_hooks( - self.package_post_deploy_hooks, "application package", self.package_name + execute_post_deploy_hooks( + console=cc, + project_root=self.project_root, + post_deploy_hooks=self.package_post_deploy_hooks, + deployed_object_type="application package", + database_name=self.package_name, ) def execute_app_post_deploy_hooks(self) -> None: - self._execute_post_deploy_hooks( - self.app_post_deploy_hooks, "application", self.app_name + execute_post_deploy_hooks( + console=cc, + project_root=self.project_root, + post_deploy_hooks=self.app_post_deploy_hooks, + deployed_object_type="application", + database_name=self.app_name, ) def deploy( @@ -797,7 +435,7 @@ def get_validation_result(self, use_scratch_stage: bool): f"drop stage if exists {self.scratch_stage_fqn}" ) - def get_events( + def get_events( # type: ignore [return] self, since: str | datetime | None = None, until: str | datetime | None = None, diff --git a/src/snowflake/cli/_plugins/nativeapp/project_model.py b/src/snowflake/cli/_plugins/nativeapp/project_model.py index 4f8a5e76e7..dfeac0516e 100644 --- a/src/snowflake/cli/_plugins/nativeapp/project_model.py +++ b/src/snowflake/cli/_plugins/nativeapp/project_model.py @@ -26,9 +26,7 @@ default_application, default_role, ) -from snowflake.cli.api.project.schemas.native_app.application import ( - PostDeployHook, -) +from snowflake.cli.api.project.schemas.entities.common import PostDeployHook from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.project.util import ( diff --git a/src/snowflake/cli/_plugins/nativeapp/run_processor.py b/src/snowflake/cli/_plugins/nativeapp/run_processor.py index 3666584192..9590c739ee 100644 --- a/src/snowflake/cli/_plugins/nativeapp/run_processor.py +++ b/src/snowflake/cli/_plugins/nativeapp/run_processor.py @@ -35,8 +35,6 @@ from snowflake.cli._plugins.nativeapp.manager import ( NativeAppCommandProcessor, NativeAppManager, - ensure_correct_owner, - generic_sql_error_handler, ) from snowflake.cli._plugins.nativeapp.policy import PolicyBase from snowflake.cli._plugins.nativeapp.project_model import ( @@ -44,6 +42,10 @@ ) from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.entities.utils import ( + ensure_correct_owner, + generic_sql_error_handler, +) from snowflake.cli.api.errno import ( APPLICATION_NO_LONGER_AVAILABLE, APPLICATION_OWNS_EXTERNAL_OBJECTS, diff --git a/src/snowflake/cli/_plugins/nativeapp/teardown_processor.py b/src/snowflake/cli/_plugins/nativeapp/teardown_processor.py index 31ac8795dd..7ad76d98b1 100644 --- a/src/snowflake/cli/_plugins/nativeapp/teardown_processor.py +++ b/src/snowflake/cli/_plugins/nativeapp/teardown_processor.py @@ -32,12 +32,12 @@ from snowflake.cli._plugins.nativeapp.manager import ( NativeAppCommandProcessor, NativeAppManager, - ensure_correct_owner, ) from snowflake.cli._plugins.nativeapp.utils import ( needs_confirmation, ) from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.entities.utils import ensure_correct_owner from snowflake.cli.api.errno import APPLICATION_NO_LONGER_AVAILABLE from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError from snowflake.connector import ProgrammingError diff --git a/src/snowflake/cli/_plugins/nativeapp/version/version_processor.py b/src/snowflake/cli/_plugins/nativeapp/version/version_processor.py index 159d2b6303..14ba06d832 100644 --- a/src/snowflake/cli/_plugins/nativeapp/version/version_processor.py +++ b/src/snowflake/cli/_plugins/nativeapp/version/version_processor.py @@ -32,11 +32,11 @@ from snowflake.cli._plugins.nativeapp.manager import ( NativeAppCommandProcessor, NativeAppManager, - ensure_correct_owner, ) from snowflake.cli._plugins.nativeapp.policy import PolicyBase from snowflake.cli._plugins.nativeapp.run_processor import NativeAppRunProcessor from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.entities.utils import ensure_correct_owner from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.util import to_identifier, unquote_identifier diff --git a/src/snowflake/cli/api/entities/application_package_entity.py b/src/snowflake/cli/api/entities/application_package_entity.py index e096fe567b..c1593f487a 100644 --- a/src/snowflake/cli/api/entities/application_package_entity.py +++ b/src/snowflake/cli/api/entities/application_package_entity.py @@ -1,13 +1,38 @@ +from contextlib import contextmanager from pathlib import Path +from textwrap import dedent +from typing import List, Optional +from click import ClickException from snowflake.cli._plugins.nativeapp.artifacts import build_bundle from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext from snowflake.cli._plugins.nativeapp.codegen.compiler import NativeAppCompiler +from snowflake.cli._plugins.nativeapp.constants import ( + ALLOWED_SPECIAL_COMMENTS, + COMMENT_COL, + INTERNAL_DISTRIBUTION, + NAME_COL, + SPECIAL_COMMENT, +) +from snowflake.cli._plugins.nativeapp.exceptions import ( + ApplicationPackageAlreadyExistsError, +) from snowflake.cli._plugins.workspace.action_context import ActionContext -from snowflake.cli.api.entities.common import EntityBase +from snowflake.cli.api.console.abc import AbstractConsole +from snowflake.cli.api.entities.common import EntityBase, get_sql_executor +from snowflake.cli.api.entities.utils import ( + ensure_correct_owner, + generic_sql_error_handler, + render_script_templates, +) +from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError from snowflake.cli.api.project.schemas.entities.application_package_entity_model import ( ApplicationPackageEntityModel, ) +from snowflake.cli.api.rendering.jinja import ( + jinja_render_from_str, +) +from snowflake.connector import ProgrammingError class ApplicationPackageEntity(EntityBase[ApplicationPackageEntityModel]): @@ -31,3 +56,205 @@ def action_bundle(self, ctx: ActionContext): compiler = NativeAppCompiler(bundle_context) compiler.compile_artifacts() return bundle_map + + @staticmethod + def get_existing_app_pkg_info( + package_name: str, + package_role: str, + ) -> Optional[dict]: + """ + Check for an existing application package by the same name as in project definition, in account. + It executes a 'show application packages like' query and returns the result as single row, if one exists. + """ + sql_executor = get_sql_executor() + with sql_executor.use_role(package_role): + return sql_executor.show_specific_object( + "application packages", package_name, name_col=NAME_COL + ) + + @staticmethod + def get_app_pkg_distribution_in_snowflake( + package_name: str, + package_role: str, + ) -> str: + """ + Returns the 'distribution' attribute of a 'describe application package' SQL query, in lowercase. + """ + sql_executor = get_sql_executor() + with sql_executor.use_role(package_role): + try: + desc_cursor = sql_executor.execute_query( + f"describe application package {package_name}" + ) + except ProgrammingError as err: + generic_sql_error_handler(err) + + if desc_cursor.rowcount is None or desc_cursor.rowcount == 0: + raise SnowflakeSQLExecutionError() + else: + for row in desc_cursor: + if row[0].lower() == "distribution": + return row[1].lower() + raise ProgrammingError( + msg=dedent( + f"""\ + Could not find the 'distribution' attribute for application package {package_name} in the output of SQL query: + 'describe application package {package_name}' + """ + ) + ) + + @classmethod + def verify_project_distribution( + cls, + console: AbstractConsole, + package_name: str, + package_role: str, + package_distribution: str, + expected_distribution: Optional[str] = None, + ) -> bool: + """ + Returns true if the 'distribution' attribute of an existing application package in snowflake + is the same as the the attribute specified in project definition file. + """ + actual_distribution = ( + expected_distribution + if expected_distribution + else cls.get_app_pkg_distribution_in_snowflake( + package_name=package_name, + package_role=package_role, + ) + ) + project_def_distribution = package_distribution.lower() + if actual_distribution != project_def_distribution: + console.warning( + dedent( + f"""\ + Application package {package_name} in your Snowflake account has distribution property {actual_distribution}, + which does not match the value specified in project definition file: {project_def_distribution}. + """ + ) + ) + return False + return True + + @staticmethod + @contextmanager + def use_package_warehouse( + package_warehouse: Optional[str], + ): + if package_warehouse: + with get_sql_executor().use_warehouse(package_warehouse): + yield + else: + raise ClickException( + dedent( + f"""\ + Application package warehouse cannot be empty. + Please provide a value for it in your connection information or your project definition file. + """ + ) + ) + + @classmethod + def apply_package_scripts( + cls, + console: AbstractConsole, + package_scripts: List[str], + package_warehouse: Optional[str], + project_root: Path, + package_role: str, + package_name: str, + ) -> None: + """ + Assuming the application package exists and we are using the correct role, + applies all package scripts in-order to the application package. + """ + + if package_scripts: + console.warning( + "WARNING: native_app.package.scripts is deprecated. Please migrate to using native_app.package.post_deploy." + ) + + queued_queries = render_script_templates( + project_root, + jinja_render_from_str, + dict(package_name=package_name), + package_scripts, + ) + + # once we're sure all the templates expanded correctly, execute all of them + with cls.use_package_warehouse( + package_warehouse=package_warehouse, + ): + try: + for i, queries in enumerate(queued_queries): + console.step(f"Applying package script: {package_scripts[i]}") + get_sql_executor().execute_queries(queries) + except ProgrammingError as err: + generic_sql_error_handler( + err, role=package_role, warehouse=package_warehouse + ) + + @classmethod + def create_app_package( + cls, + console: AbstractConsole, + package_name: str, + package_role: str, + package_distribution: str, + ) -> None: + """ + Creates the application package with our up-to-date stage if none exists. + """ + + # 1. Check for existing existing application package + show_obj_row = cls.get_existing_app_pkg_info( + package_name=package_name, + package_role=package_role, + ) + + if show_obj_row: + # 1. Check for the right owner role + ensure_correct_owner( + row=show_obj_row, role=package_role, obj_name=package_name + ) + + # 2. Check distribution of the existing application package + actual_distribution = cls.get_app_pkg_distribution_in_snowflake( + package_name=package_name, + package_role=package_role, + ) + if not cls.verify_project_distribution( + console=console, + package_name=package_name, + package_role=package_role, + package_distribution=package_distribution, + expected_distribution=actual_distribution, + ): + console.warning( + f"Continuing to execute `snow app run` on application package {package_name} with distribution '{actual_distribution}'." + ) + + # 3. If actual_distribution is external, skip comment check + if actual_distribution == INTERNAL_DISTRIBUTION: + row_comment = show_obj_row[COMMENT_COL] + + if row_comment not in ALLOWED_SPECIAL_COMMENTS: + raise ApplicationPackageAlreadyExistsError(package_name) + + return + + # If no application package pre-exists, create an application package, with the specified distribution in the project definition file. + sql_executor = get_sql_executor() + with sql_executor.use_role(package_role): + console.step(f"Creating new application package {package_name} in account.") + sql_executor.execute_query( + dedent( + f"""\ + create application package {package_name} + comment = {SPECIAL_COMMENT} + distribution = {package_distribution} + """ + ) + ) diff --git a/src/snowflake/cli/api/entities/common.py b/src/snowflake/cli/api/entities/common.py index f248c0951f..48c90387af 100644 --- a/src/snowflake/cli/api/entities/common.py +++ b/src/snowflake/cli/api/entities/common.py @@ -2,6 +2,7 @@ from typing import Generic, Type, TypeVar, get_args from snowflake.cli._plugins.workspace.action_context import ActionContext +from snowflake.cli.api.sql_execution import SqlExecutor class EntityActions(str, Enum): @@ -39,3 +40,8 @@ def perform(self, action: EntityActions, action_ctx: ActionContext): Performs the requested action. """ return getattr(self, action)(action_ctx) + + +def get_sql_executor() -> SqlExecutor: + """Returns an SQL Executor that uses the connection from the current CLI context""" + return SqlExecutor() diff --git a/src/snowflake/cli/api/entities/utils.py b/src/snowflake/cli/api/entities/utils.py new file mode 100644 index 0000000000..19e8ca7b17 --- /dev/null +++ b/src/snowflake/cli/api/entities/utils.py @@ -0,0 +1,321 @@ +import os +from pathlib import Path +from textwrap import dedent +from typing import Any, Callable, Dict, List, NoReturn, Optional + +import jinja2 +from click import ClickException +from snowflake.cli._plugins.nativeapp.artifacts import ( + BundleMap, + resolve_without_follow, +) +from snowflake.cli._plugins.nativeapp.constants import OWNER_COL +from snowflake.cli._plugins.nativeapp.exceptions import ( + InvalidScriptError, + MissingScriptError, + UnexpectedOwnerError, +) +from snowflake.cli._plugins.nativeapp.utils import verify_exists, verify_no_directories +from snowflake.cli._plugins.stage.diff import ( + DiffResult, + StagePath, + compute_stage_diff, + preserve_from_diff, + sync_local_diff_with_stage, + to_stage_path, +) +from snowflake.cli._plugins.stage.utils import print_diff_to_console +from snowflake.cli.api.console.abc import AbstractConsole +from snowflake.cli.api.entities.common import get_sql_executor +from snowflake.cli.api.errno import ( + DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, + NO_WAREHOUSE_SELECTED_IN_SESSION, +) +from snowflake.cli.api.project.schemas.entities.common import PostDeployHook +from snowflake.cli.api.project.util import unquote_identifier +from snowflake.cli.api.rendering.sql_templates import ( + snowflake_sql_jinja_render, +) +from snowflake.cli.api.secure_path import UNLIMITED, SecurePath +from snowflake.connector import ProgrammingError + + +def generic_sql_error_handler( + err: ProgrammingError, role: Optional[str] = None, warehouse: Optional[str] = None +) -> NoReturn: + # Potential refactor: If moving away from Python 3.8 and 3.9 to >= 3.10, use match ... case + if err.errno == DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED: + raise ProgrammingError( + msg=dedent( + f"""\ + Received error message '{err.msg}' while executing SQL statement. + '{role}' may not have access to warehouse '{warehouse}'. + Please grant usage privilege on warehouse to this role. + """ + ), + errno=err.errno, + ) + elif err.errno == NO_WAREHOUSE_SELECTED_IN_SESSION: + raise ProgrammingError( + msg=dedent( + f"""\ + Received error message '{err.msg}' while executing SQL statement. + Please provide a warehouse for the active session role in your project definition file, config.toml file, or via command line. + """ + ), + errno=err.errno, + ) + elif "does not exist or not authorized" in err.msg: + raise ProgrammingError( + msg=dedent( + f"""\ + Received error message '{err.msg}' while executing SQL statement. + Please check the name of the resource you are trying to query or the permissions of the role you are using to run the query. + """ + ) + ) + raise err + + +def ensure_correct_owner(row: dict, role: str, obj_name: str) -> None: + """ + Check if an object has the right owner role + """ + actual_owner = row[ + OWNER_COL + ].upper() # Because unquote_identifier() always returns uppercase str + if actual_owner != unquote_identifier(role): + raise UnexpectedOwnerError(obj_name, role, actual_owner) + + +def _get_stage_paths_to_sync( + local_paths_to_sync: List[Path], deploy_root: Path +) -> List[StagePath]: + """ + Takes a list of paths (files and directories), returning a list of all files recursively relative to the deploy root. + """ + + stage_paths = [] + for path in local_paths_to_sync: + if path.is_dir(): + for current_dir, _dirs, files in os.walk(path): + for file in files: + deploy_path = Path(current_dir, file).relative_to(deploy_root) + stage_paths.append(to_stage_path(deploy_path)) + else: + stage_paths.append(to_stage_path(path.relative_to(deploy_root))) + return stage_paths + + +def sync_deploy_root_with_stage( + console: AbstractConsole, + deploy_root: Path, + package_name: str, + stage_schema: str, + bundle_map: BundleMap, + role: str, + prune: bool, + recursive: bool, + stage_fqn: str, + local_paths_to_sync: List[Path] | None = None, + print_diff: bool = True, +) -> DiffResult: + """ + Ensures that the files on our remote stage match the artifacts we have in + the local filesystem. + + Args: + bundle_map (BundleMap): The artifact mapping computed by the `build_bundle` function. + role (str): The name of the role to use for queries and commands. + prune (bool): Whether to prune artifacts from the stage that don't exist locally. + recursive (bool): Whether to traverse directories recursively. + stage_fqn (str): The name of the stage to diff against and upload to. + local_paths_to_sync (List[Path], optional): List of local paths to sync. Defaults to None to sync all + local paths. Note that providing an empty list here is equivalent to None. + print_diff (bool): Whether to print the diff between the local files and the remote stage. Defaults to True + + Returns: + A `DiffResult` instance describing the changes that were performed. + """ + + sql_executor = get_sql_executor() + # Does a stage already exist within the application package, or we need to create one? + # Using "if not exists" should take care of either case. + console.step( + f"Checking if stage {stage_fqn} exists, or creating a new one if none exists." + ) + with sql_executor.use_role(role): + sql_executor.execute_query( + f"create schema if not exists {package_name}.{stage_schema}" + ) + sql_executor.execute_query( + f""" + create stage if not exists {stage_fqn} + encryption = (TYPE = 'SNOWFLAKE_SSE') + DIRECTORY = (ENABLE = TRUE)""" + ) + + # Perform a diff operation and display results to the user for informational purposes + if print_diff: + console.step( + "Performing a diff between the Snowflake stage and your local deploy_root ('%s') directory." + % deploy_root.resolve() + ) + diff: DiffResult = compute_stage_diff(deploy_root, stage_fqn) + + if local_paths_to_sync: + # Deploying specific files/directories + resolved_paths_to_sync = [ + resolve_without_follow(p) for p in local_paths_to_sync + ] + if not recursive: + verify_no_directories(resolved_paths_to_sync) + + deploy_paths_to_sync = [] + for resolved_path in resolved_paths_to_sync: + verify_exists(resolved_path) + deploy_paths = bundle_map.to_deploy_paths(resolved_path) + if not deploy_paths: + if resolved_path.is_dir() and recursive: + # No direct artifact mapping found for this path. Check to see + # if there are subpaths of this directory that are matches. We + # loop over sources because it's likely a much smaller list + # than the project directory. + for src in bundle_map.all_sources(absolute=True): + if resolved_path in src.parents: + # There is a source that contains this path, get its dest path(s) + deploy_paths.extend(bundle_map.to_deploy_paths(src)) + + if not deploy_paths: + raise ClickException(f"No artifact found for {resolved_path}") + deploy_paths_to_sync.extend(deploy_paths) + + stage_paths_to_sync = _get_stage_paths_to_sync( + deploy_paths_to_sync, resolve_without_follow(deploy_root) + ) + diff = preserve_from_diff(diff, stage_paths_to_sync) + else: + # Full deploy + if not recursive: + verify_no_directories(deploy_root.resolve().iterdir()) + + if not prune: + files_not_removed = [str(path) for path in diff.only_on_stage] + diff.only_on_stage = [] + + if len(files_not_removed) > 0: + files_not_removed_str = "\n".join(files_not_removed) + console.warning( + f"The following files exist only on the stage:\n{files_not_removed_str}\n\nUse the --prune flag to delete them from the stage." + ) + + if print_diff: + print_diff_to_console(diff, bundle_map) + + # Upload diff-ed files to application package stage + if diff.has_changes(): + console.step( + "Updating the Snowflake stage from your local %s directory." + % deploy_root.resolve(), + ) + sync_local_diff_with_stage( + role=role, + deploy_root_path=deploy_root, + diff_result=diff, + stage_fqn=stage_fqn, + ) + return diff + + +def _execute_sql_script( + script_content: str, + database_name: Optional[str] = None, +) -> None: + """ + Executing the provided SQL script content. + This assumes that a relevant warehouse is already active. + If database_name is passed in, it will be used first. + """ + try: + sql_executor = get_sql_executor() + if database_name is not None: + sql_executor.execute_query(f"use database {database_name}") + sql_executor.execute_queries(script_content) + except ProgrammingError as err: + generic_sql_error_handler(err) + + +def execute_post_deploy_hooks( + console: AbstractConsole, + project_root: Path, + post_deploy_hooks: Optional[List[PostDeployHook]], + deployed_object_type: str, + database_name: str, +) -> None: + """ + Executes post-deploy hooks for the given object type. + While executing SQL post deploy hooks, it first switches to the database provided in the input. + All post deploy scripts templates will first be expanded using the global template context. + """ + if not post_deploy_hooks: + return + + with console.phase(f"Executing {deployed_object_type} post-deploy actions"): + sql_scripts_paths = [] + for hook in post_deploy_hooks: + if hook.sql_script: + sql_scripts_paths.append(hook.sql_script) + else: + raise ValueError( + f"Unsupported {deployed_object_type} post-deploy hook type: {hook}" + ) + + scripts_content_list = render_script_templates( + project_root, + snowflake_sql_jinja_render, + {}, + sql_scripts_paths, + ) + + for index, sql_script_path in enumerate(sql_scripts_paths): + console.step(f"Executing SQL script: {sql_script_path}") + _execute_sql_script( + script_content=scripts_content_list[index], + database_name=database_name, + ) + + +def render_script_templates( + project_root: Path, + render_from_str: Callable[[str, Dict[str, Any]], str], + jinja_context: dict[str, Any], + scripts: List[str], +) -> List[str]: + """ + Input: + - project_root: path to project root + - render_from_str: function which renders a jinja template from a string and jinja context + - jinja_context: a dictionary with the jinja context + - scripts: list of script paths relative to the project root + Returns: + - List of rendered scripts content + Size of the return list is the same as the size of the input scripts list. + """ + scripts_contents = [] + for relpath in scripts: + script_full_path = SecurePath(project_root) / relpath + try: + template_content = script_full_path.read_text(file_size_limit_mb=UNLIMITED) + result = render_from_str(template_content, jinja_context) + scripts_contents.append(result) + + except FileNotFoundError as e: + raise MissingScriptError(relpath) from e + + except jinja2.TemplateSyntaxError as e: + raise InvalidScriptError(relpath, e, e.lineno) from e + + except jinja2.UndefinedError as e: + raise InvalidScriptError(relpath, e) from e + + return scripts_contents diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index 1ee7b9564a..e9307a462f 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -20,15 +20,20 @@ from pydantic import Field, PrivateAttr, field_validator from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.schemas.identifier_model import Identifier -from snowflake.cli.api.project.schemas.native_app.application import ( - PostDeployHook, -) from snowflake.cli.api.project.schemas.updatable_model import ( IdentifierField, UpdatableModel, ) +class SqlScriptHookType(UpdatableModel): + sql_script: str = Field(title="SQL file path relative to the project root") + + +# Currently sql_script is the only supported hook type. Change to a Union once other hook types are added +PostDeployHook = SqlScriptHookType + + class MetaField(UpdatableModel): warehouse: Optional[str] = IdentifierField( title="Warehouse used to run the scripts", default=None diff --git a/src/snowflake/cli/api/project/schemas/native_app/application.py b/src/snowflake/cli/api/project/schemas/native_app/application.py index 7c392b9d50..08ee24f17f 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/application.py +++ b/src/snowflake/cli/api/project/schemas/native_app/application.py @@ -17,6 +17,7 @@ from typing import List, Optional from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.entities.common import PostDeployHook from snowflake.cli.api.project.schemas.updatable_model import ( IdentifierField, UpdatableModel, @@ -24,14 +25,6 @@ from snowflake.cli.api.project.util import append_test_resource_suffix -class SqlScriptHookType(UpdatableModel): - sql_script: str = Field(title="SQL file path relative to the project root") - - -# Currently sql_script is the only supported hook type. Change to a Union once other hook types are added -PostDeployHook = SqlScriptHookType - - class Application(UpdatableModel): role: Optional[str] = Field( title="Role to use when creating the application object and consumer-side objects", diff --git a/src/snowflake/cli/api/project/schemas/native_app/package.py b/src/snowflake/cli/api/project/schemas/native_app/package.py index da8cd8c355..63a17b695a 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/package.py +++ b/src/snowflake/cli/api/project/schemas/native_app/package.py @@ -17,7 +17,7 @@ from typing import List, Literal, Optional from pydantic import Field, field_validator, model_validator -from snowflake.cli.api.project.schemas.native_app.application import PostDeployHook +from snowflake.cli.api.project.schemas.entities.common import PostDeployHook from snowflake.cli.api.project.schemas.updatable_model import ( IdentifierField, UpdatableModel, diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index ef9a00f4f1..9a79c8c4fd 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -40,7 +40,7 @@ from snowflake.connector.errors import ProgrammingError -class SqlExecutionMixin: +class SqlExecutor: def __init__(self, connection: SnowflakeConnection | None = None): self._snowpark_session = None self._connection = connection @@ -51,16 +51,6 @@ def _conn(self) -> SnowflakeConnection: return self._connection return get_cli_context().connection - @property - def snowpark_session(self): - if not self._snowpark_session: - from snowflake.snowpark.session import Session - - self._snowpark_session = Session.builder.configs( - {"connection": self._conn} - ).create() - return self._snowpark_session - @cached_property def _log(self): return logging.getLogger(__name__) @@ -92,6 +82,12 @@ def _execute_query(self, query: str, **kwargs): def _execute_queries(self, queries: str, **kwargs): return list(self._execute_string(dedent(queries), **kwargs)) + def execute_query(self, query: str, **kwargs): + return self._execute_query(query, **kwargs) + + def execute_queries(self, queries: str, **kwargs): + return self._execute_queries(queries, **kwargs) + def use(self, object_type: ObjectType, name: str): try: self._execute_query(f"use {object_type.value.sf_name} {name}") @@ -258,6 +254,22 @@ def show_specific_object( return show_obj_row +class SqlExecutionMixin(SqlExecutor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._snowpark_session = None + + @property + def snowpark_session(self): + if not self._snowpark_session: + from snowflake.snowpark.session import Session + + self._snowpark_session = Session.builder.configs( + {"connection": self._conn} + ).create() + return self._snowpark_session + + class VerboseCursor(SnowflakeCursor): def execute(self, command: str, *args, **kwargs): cli_console.message(command) diff --git a/tests/nativeapp/patch_utils.py b/tests/nativeapp/patch_utils.py index 1953037108..bfb0557022 100644 --- a/tests/nativeapp/patch_utils.py +++ b/tests/nativeapp/patch_utils.py @@ -15,7 +15,7 @@ from unittest import mock from unittest.mock import PropertyMock -from tests.nativeapp.utils import NATIVEAPP_MANAGER_APP_PKG_DISTRIBUTION_IN_SF +from tests.nativeapp.utils import APP_PACKAGE_ENTITY_DISTRIBUTION_IN_SF def mock_connection(): @@ -27,8 +27,7 @@ def mock_connection(): def mock_get_app_pkg_distribution_in_sf(): return mock.patch( - NATIVEAPP_MANAGER_APP_PKG_DISTRIBUTION_IN_SF, - new_callable=PropertyMock, + APP_PACKAGE_ENTITY_DISTRIBUTION_IN_SF, ) diff --git a/tests/nativeapp/test_manager.py b/tests/nativeapp/test_manager.py index a631bef82f..0ce1369eb5 100644 --- a/tests/nativeapp/test_manager.py +++ b/tests/nativeapp/test_manager.py @@ -41,13 +41,15 @@ from snowflake.cli._plugins.nativeapp.manager import ( NativeAppManager, SnowflakeSQLExecutionError, - _get_stage_paths_to_sync, - ensure_correct_owner, ) from snowflake.cli._plugins.stage.diff import ( DiffResult, StagePath, ) +from snowflake.cli.api.entities.utils import ( + _get_stage_paths_to_sync, + ensure_correct_owner, +) from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.connector import ProgrammingError @@ -58,13 +60,15 @@ mock_get_app_pkg_distribution_in_sf, ) from tests.nativeapp.utils import ( + APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO, + APP_PACKAGE_ENTITY_IS_DISTRIBUTION_SAME, + ENTITIES_UTILS_MODULE, NATIVEAPP_MANAGER_ACCOUNT_EVENT_TABLE, NATIVEAPP_MANAGER_BUILD_BUNDLE, NATIVEAPP_MANAGER_DEPLOY, NATIVEAPP_MANAGER_EXECUTE, - NATIVEAPP_MANAGER_GET_EXISTING_APP_PKG_INFO, - NATIVEAPP_MANAGER_IS_APP_PKG_DISTRIBUTION_SAME, NATIVEAPP_MODULE, + SQL_EXECUTOR_EXECUTE, mock_execute_helper, mock_snowflake_yml_file, quoted_override_yml_file, @@ -95,9 +99,9 @@ def _get_na_manager(working_dir: Optional[str] = None): ) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) -@mock.patch(f"{NATIVEAPP_MODULE}.compute_stage_diff") -@mock.patch(f"{NATIVEAPP_MODULE}.sync_local_diff_with_stage") +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(f"{ENTITIES_UTILS_MODULE}.compute_stage_diff") +@mock.patch(f"{ENTITIES_UTILS_MODULE}.sync_local_diff_with_stage") def test_sync_deploy_root_with_stage( mock_local_diff_with_stage, mock_compute_stage_diff, @@ -151,9 +155,9 @@ def test_sync_deploy_root_with_stage( ) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) -@mock.patch(f"{NATIVEAPP_MODULE}.sync_local_diff_with_stage") -@mock.patch(f"{NATIVEAPP_MODULE}.compute_stage_diff") +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(f"{ENTITIES_UTILS_MODULE}.sync_local_diff_with_stage") +@mock.patch(f"{ENTITIES_UTILS_MODULE}.compute_stage_diff") @mock.patch(f"{NATIVEAPP_MODULE}.cc.warning") @pytest.mark.parametrize( "prune,only_on_stage_files,expected_warn", @@ -208,7 +212,7 @@ def test_sync_deploy_root_with_stage_prune( mock_warning.assert_not_called() -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) def test_get_app_pkg_distribution_in_snowflake(mock_execute, temp_dir, mock_cursor): side_effects, expected = mock_execute_helper( @@ -247,7 +251,7 @@ def test_get_app_pkg_distribution_in_snowflake(mock_execute, temp_dir, mock_curs assert mock_execute.mock_calls == expected -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) def test_get_app_pkg_distribution_in_snowflake_throws_programming_error( mock_execute, temp_dir, mock_cursor ): @@ -285,7 +289,7 @@ def test_get_app_pkg_distribution_in_snowflake_throws_programming_error( assert mock_execute.mock_calls == expected -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) def test_get_app_pkg_distribution_in_snowflake_throws_execution_error( mock_execute, temp_dir, mock_cursor ): @@ -317,7 +321,7 @@ def test_get_app_pkg_distribution_in_snowflake_throws_execution_error( assert mock_execute.mock_calls == expected -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) def test_get_app_pkg_distribution_in_snowflake_throws_distribution_error( mock_execute, temp_dir, mock_cursor ): @@ -506,7 +510,7 @@ def test_get_existing_app_info_app_does_not_exist(mock_execute, temp_dir, mock_c assert mock_execute.mock_calls == expected -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) def test_get_existing_app_pkg_info_app_pkg_exists(mock_execute, temp_dir, mock_cursor): side_effects, expected = mock_execute_helper( [ @@ -551,7 +555,7 @@ def test_get_existing_app_pkg_info_app_pkg_exists(mock_execute, temp_dir, mock_c assert mock_execute.mock_calls == expected -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) def test_get_existing_app_pkg_info_app_pkg_does_not_exist( mock_execute, temp_dir, mock_cursor ): @@ -729,8 +733,8 @@ def test_is_correct_owner_bad_owner(): # Test create_app_package() with no existing package available -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) -@mock.patch(NATIVEAPP_MANAGER_GET_EXISTING_APP_PKG_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO, return_value=None) def test_create_app_pkg_no_existing_package( mock_get_existing_app_pkg_info, mock_execute, temp_dir, mock_cursor ): @@ -772,7 +776,7 @@ def test_create_app_pkg_no_existing_package( # Test create_app_package() with incorrect owner -@mock.patch(NATIVEAPP_MANAGER_GET_EXISTING_APP_PKG_INFO) +@mock.patch(APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO) def test_create_app_pkg_incorrect_owner(mock_get_existing_app_pkg_info, temp_dir): mock_get_existing_app_pkg_info.return_value = { "name": "APP_PKG", @@ -794,9 +798,9 @@ def test_create_app_pkg_incorrect_owner(mock_get_existing_app_pkg_info, temp_dir # Test create_app_package() with distribution external AND variable mismatch -@mock.patch(NATIVEAPP_MANAGER_GET_EXISTING_APP_PKG_INFO) +@mock.patch(APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO) @mock_get_app_pkg_distribution_in_sf() -@mock.patch(NATIVEAPP_MANAGER_IS_APP_PKG_DISTRIBUTION_SAME) +@mock.patch(APP_PACKAGE_ENTITY_IS_DISTRIBUTION_SAME) @mock.patch(f"{NATIVEAPP_MODULE}.cc.warning") @pytest.mark.parametrize( "is_pkg_distribution_same", @@ -835,9 +839,9 @@ def test_create_app_pkg_external_distribution( # Test create_app_package() with distribution internal AND variable mismatch AND special comment is True -@mock.patch(NATIVEAPP_MANAGER_GET_EXISTING_APP_PKG_INFO) +@mock.patch(APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO) @mock_get_app_pkg_distribution_in_sf() -@mock.patch(NATIVEAPP_MANAGER_IS_APP_PKG_DISTRIBUTION_SAME) +@mock.patch(APP_PACKAGE_ENTITY_IS_DISTRIBUTION_SAME) @mock.patch(f"{NATIVEAPP_MODULE}.cc.warning") @pytest.mark.parametrize( "is_pkg_distribution_same, special_comment", @@ -882,9 +886,9 @@ def test_create_app_pkg_internal_distribution_special_comment( # Test create_app_package() with distribution internal AND variable mismatch AND special comment is False -@mock.patch(NATIVEAPP_MANAGER_GET_EXISTING_APP_PKG_INFO) +@mock.patch(APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO) @mock_get_app_pkg_distribution_in_sf() -@mock.patch(NATIVEAPP_MANAGER_IS_APP_PKG_DISTRIBUTION_SAME) +@mock.patch(APP_PACKAGE_ENTITY_IS_DISTRIBUTION_SAME) @mock.patch(f"{NATIVEAPP_MODULE}.cc.warning") @pytest.mark.parametrize( "is_pkg_distribution_same", diff --git a/tests/nativeapp/test_package_scripts.py b/tests/nativeapp/test_package_scripts.py index 2809c3914c..ed18f0708e 100644 --- a/tests/nativeapp/test_package_scripts.py +++ b/tests/nativeapp/test_package_scripts.py @@ -35,6 +35,8 @@ from tests.nativeapp.utils import ( NATIVEAPP_MANAGER_EXECUTE, NATIVEAPP_MANAGER_EXECUTE_QUERIES, + SQL_EXECUTOR_EXECUTE, + SQL_EXECUTOR_EXECUTE_QUERIES, ) from tests.testing_utils.fixtures import MockConnectionCtx @@ -47,8 +49,8 @@ def _get_na_manager(working_dir): ) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE_QUERIES) +@mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() @pytest.mark.parametrize( "project_definition_files, expected_calls", @@ -141,8 +143,8 @@ def test_package_scripts_without_conn_info_throws_error( # Without connection warehouse, with PDF warehouse -@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE_QUERIES) +@mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() @pytest.mark.parametrize( "project_definition_files", ["napp_project_with_pkg_warehouse"], indirect=True @@ -194,8 +196,10 @@ def test_package_scripts_without_conn_info_succeeds( @mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock_connection() @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) -def test_missing_package_script(mock_execute, project_definition_files): +def test_missing_package_script(mock_conn, mock_execute, project_definition_files): + mock_conn.return_value = MockConnectionCtx() working_dir: Path = project_definition_files[0].parent native_app_manager = _get_na_manager(str(working_dir)) with pytest.raises(MissingScriptError): @@ -207,8 +211,10 @@ def test_missing_package_script(mock_execute, project_definition_files): @mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock_connection() @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) -def test_invalid_package_script(mock_execute, project_definition_files): +def test_invalid_package_script(mock_conn, mock_execute, project_definition_files): + mock_conn.return_value = MockConnectionCtx() working_dir: Path = project_definition_files[0].parent native_app_manager = _get_na_manager(str(working_dir)) with pytest.raises(InvalidScriptError): @@ -222,8 +228,12 @@ def test_invalid_package_script(mock_execute, project_definition_files): @mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock_connection() @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) -def test_undefined_var_package_script(mock_execute, project_definition_files): +def test_undefined_var_package_script( + mock_conn, mock_execute, project_definition_files +): + mock_conn.return_value = MockConnectionCtx() working_dir: Path = project_definition_files[0].parent native_app_manager = _get_na_manager(str(working_dir)) with pytest.raises(InvalidScriptError): @@ -235,8 +245,8 @@ def test_undefined_var_package_script(mock_execute, project_definition_files): assert mock_execute.mock_calls == [] -@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE_QUERIES) +@mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_package_scripts_w_missing_warehouse_exception( @@ -267,7 +277,7 @@ def test_package_scripts_w_missing_warehouse_exception( assert "Please provide a warehouse for the active session role" in err.value.msg -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_package_scripts_w_warehouse_access_exception( diff --git a/tests/nativeapp/test_post_deploy_for_app.py b/tests/nativeapp/test_post_deploy_for_app.py index 341d444662..36acb76918 100644 --- a/tests/nativeapp/test_post_deploy_for_app.py +++ b/tests/nativeapp/test_post_deploy_for_app.py @@ -22,16 +22,14 @@ from snowflake.cli._plugins.nativeapp.run_processor import NativeAppRunProcessor from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.cli.api.project.errors import SchemaValidationError -from snowflake.cli.api.project.schemas.native_app.application import ( - PostDeployHook, -) +from snowflake.cli.api.project.schemas.entities.common import PostDeployHook from tests.nativeapp.patch_utils import mock_connection from tests.nativeapp.utils import ( CLI_GLOBAL_TEMPLATE_CONTEXT, - NATIVEAPP_MANAGER_EXECUTE, - NATIVEAPP_MANAGER_EXECUTE_QUERIES, RUN_PROCESSOR_APP_POST_DEPLOY_HOOKS, + SQL_EXECUTOR_EXECUTE, + SQL_EXECUTOR_EXECUTE_QUERIES, ) from tests.testing_utils.fixtures import MockConnectionCtx @@ -47,8 +45,8 @@ def _get_run_processor(working_dir): ) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE_QUERIES) @mock.patch(CLI_GLOBAL_TEMPLATE_CONTEXT, new_callable=mock.PropertyMock) @mock.patch.dict(os.environ, {"USER": "test_user"}) @mock_connection() @@ -88,8 +86,8 @@ def test_sql_scripts( ] -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE_QUERIES) @mock.patch(CLI_GLOBAL_TEMPLATE_CONTEXT, new_callable=mock.PropertyMock) @mock_connection() @mock.patch(MOCK_CONNECTION_DB, new_callable=mock.PropertyMock) diff --git a/tests/nativeapp/test_post_deploy_for_package.py b/tests/nativeapp/test_post_deploy_for_package.py index c69f9b1c30..1436426838 100644 --- a/tests/nativeapp/test_post_deploy_for_package.py +++ b/tests/nativeapp/test_post_deploy_for_package.py @@ -27,14 +27,14 @@ from tests.nativeapp.patch_utils import mock_connection from tests.nativeapp.utils import ( CLI_GLOBAL_TEMPLATE_CONTEXT, - NATIVEAPP_MANAGER_EXECUTE, - NATIVEAPP_MANAGER_EXECUTE_QUERIES, + SQL_EXECUTOR_EXECUTE, + SQL_EXECUTOR_EXECUTE_QUERIES, ) from tests.testing_utils.fixtures import MockConnectionCtx -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE_QUERIES) @mock.patch(CLI_GLOBAL_TEMPLATE_CONTEXT, new_callable=mock.PropertyMock) @mock.patch.dict(os.environ, {"USER": "test_user"}) @mock_connection() @@ -76,8 +76,8 @@ def test_package_post_deploy_scripts( ] -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE_QUERIES) @mock.patch(CLI_GLOBAL_TEMPLATE_CONTEXT, new_callable=mock.PropertyMock) @mock.patch.dict(os.environ, {"USER": "test_user"}) @mock_connection() @@ -129,7 +129,7 @@ def test_package_post_deploy_scripts_with_non_existing_scripts( ) -@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(CLI_GLOBAL_TEMPLATE_CONTEXT, new_callable=mock.PropertyMock) @mock.patch.dict(os.environ, {"USER": "test_user"}) @mock_connection() diff --git a/tests/nativeapp/test_project_model.py b/tests/nativeapp/test_project_model.py index 6136a96ac4..67c921fe94 100644 --- a/tests/nativeapp/test_project_model.py +++ b/tests/nativeapp/test_project_model.py @@ -27,7 +27,7 @@ NativeAppProjectModel, ) from snowflake.cli.api.project.definition import load_project -from snowflake.cli.api.project.schemas.native_app.application import SqlScriptHookType +from snowflake.cli.api.project.schemas.entities.common import SqlScriptHookType from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.project.schemas.project_definition import ( build_project_definition, diff --git a/tests/nativeapp/test_teardown_processor.py b/tests/nativeapp/test_teardown_processor.py index 6d79a179a8..47f0da7bb1 100644 --- a/tests/nativeapp/test_teardown_processor.py +++ b/tests/nativeapp/test_teardown_processor.py @@ -43,6 +43,7 @@ NATIVEAPP_MANAGER_EXECUTE, NATIVEAPP_MANAGER_GET_OBJECTS_OWNED_BY_APPLICATION, NATIVEAPP_MANAGER_IS_APP_PKG_DISTRIBUTION_SAME, + SQL_EXECUTOR_EXECUTE, TEARDOWN_MODULE, TEARDOWN_PROCESSOR_DROP_GENERIC_OBJECT, TEARDOWN_PROCESSOR_GET_EXISTING_APP_INFO, @@ -747,6 +748,7 @@ def test_drop_package_variable_mistmatch_w_special_comment_auto_drop( # Test drop_package when there is no distribution mismatch AND distribution = internal AND special comment is True AND name is quoted @mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock.patch(SQL_EXECUTOR_EXECUTE) @mock_get_app_pkg_distribution_in_sf() @pytest.mark.parametrize( "auto_yes_param, special_comment", # auto_yes should have no effect on the test @@ -759,7 +761,8 @@ def test_drop_package_variable_mistmatch_w_special_comment_auto_drop( ) def test_drop_package_variable_mistmatch_w_special_comment_quoted_name_auto_drop( mock_get_distribution, - mock_execute, + mock_execute_sql, + mock_execute_na, auto_yes_param, special_comment, temp_dir, @@ -767,6 +770,12 @@ def test_drop_package_variable_mistmatch_w_special_comment_quoted_name_auto_drop ): mock_get_distribution.return_value = "internal" + # We are mocking _execute_query on both NativeAppManager and SqlExecutor. Attaching both to a single mock to verify the order of calls. + # The first 4 calls are expected to be on SqlExecutor, and the rest on NativeAppManager. + mock_execute = mock.MagicMock() + mock_execute.attach_mock(mock_execute_na, "execute_na") + mock_execute.attach_mock(mock_execute_sql, "execute_sql") + side_effects, expected = mock_execute_helper( [ # Show app pkg @@ -816,7 +825,9 @@ def test_drop_package_variable_mistmatch_w_special_comment_quoted_name_auto_drop (None, mock.call("use role old_role")), ] ) - mock_execute.side_effect = side_effects + + mock_execute_sql.side_effect = side_effects[::3] + mock_execute_na.side_effect = side_effects[4::] current_working_directory = os.getcwd() create_named_file( diff --git a/tests/nativeapp/utils.py b/tests/nativeapp/utils.py index 099c8932d4..72acc17018 100644 --- a/tests/nativeapp/utils.py +++ b/tests/nativeapp/utils.py @@ -30,6 +30,8 @@ TYPER_PROMPT = "typer.prompt" RUN_MODULE = "snowflake.cli._plugins.nativeapp.run_processor" VERSION_MODULE = "snowflake.cli._plugins.nativeapp.version.version_processor" +ENTITIES_COMMON_MODULE = "snowflake.cli.api.entities.common" +ENTITIES_UTILS_MODULE = "snowflake.cli.api.entities.utils" CLI_GLOBAL_TEMPLATE_CONTEXT = ( "snowflake.cli.api.cli_global_context._CliGlobalContextAccess.template_context" @@ -70,6 +72,22 @@ FIND_VERSION_FROM_MANIFEST = f"{VERSION_MODULE}.find_version_info_in_manifest_file" +APP_PACKAGE_ENTITY = ( + "snowflake.cli.api.entities.application_package_entity.ApplicationPackageEntity" +) +APP_PACKAGE_ENTITY_DISTRIBUTION_IN_SF = ( + f"{APP_PACKAGE_ENTITY}.get_app_pkg_distribution_in_snowflake" +) +APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO = ( + f"{APP_PACKAGE_ENTITY}.get_existing_app_pkg_info" +) +APP_PACKAGE_ENTITY_IS_DISTRIBUTION_SAME = ( + f"{APP_PACKAGE_ENTITY}.verify_project_distribution" +) + +SQL_EXECUTOR_EXECUTE = f"{ENTITIES_COMMON_MODULE}.SqlExecutor._execute_query" +SQL_EXECUTOR_EXECUTE_QUERIES = f"{ENTITIES_COMMON_MODULE}.SqlExecutor._execute_queries" + mock_snowflake_yml_file = dedent( """\ definition_version: 1