diff --git a/.github/workflows/test_integration.yaml b/.github/workflows/test_integration.yaml index c038051615..95952f6236 100644 --- a/.github/workflows/test_integration.yaml +++ b/.github/workflows/test_integration.yaml @@ -35,7 +35,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip hatch diff --git a/.gitignore b/.gitignore index 8157fa62fa..0e4e3f615f 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ gen_docs/ /venv/ .env .vscode +tmp/ ^app.zip ^snowflake.yml diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 7ba28b355f..925bc91b07 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -22,6 +22,7 @@ ## New additions * Added connection option `--token-file-path` allowing passing OAuth token using a file. The function is also supported by setting `token_file_path` in connection definition. +* Support for Python remote execution via `snow stage execute` and `snow git execute` similar to existing EXECUTE IMMEDIATE support. ## Fixes and improvements * The `snow app run` command now allows upgrading to unversioned mode from a versioned or release mode application installation @@ -29,7 +30,7 @@ * The `snow app version create` command now allows operating on application packages created outside the CLI * Added support for user stages in stage execute command * Added support for user stages in stage and git copy commands - +* Improved support for quoted identifiers in snowpark commands. # v2.6.0 ## Backward incompatibility diff --git a/pyproject.toml b/pyproject.toml index 254b9a3836..e0a9ba654b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "setuptools==70.2.0", 'snowflake.core==0.8.0; python_version < "3.12"', "snowflake-connector-python[secure-local-storage]==3.11.0", + 'snowflake-snowpark-python>=1.15.0;python_version < "3.12"', "tomlkit==0.12.5", "typer==0.12.3", "urllib3>=1.24.3,<2.3", diff --git a/snyk/requirements.txt b/snyk/requirements.txt index 74f5981dd5..84a73b7e62 100644 --- a/snyk/requirements.txt +++ b/snyk/requirements.txt @@ -8,6 +8,7 @@ requirements-parser==0.9.0 setuptools==70.2.0 snowflake.core==0.8.0; python_version < "3.12" snowflake-connector-python[secure-local-storage]==3.11.0 +snowflake-snowpark-python>=1.15.0;python_version < "3.12" tomlkit==0.12.5 typer==0.12.3 urllib3>=1.24.3,<2.3 diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index f7c8eb34c8..0e20ecaf97 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -449,7 +449,10 @@ def _password_callback(value: str): None, "--variable", "-D", - help="Variables for the template. For example: `-D \"=\"`, string values must be in `''`.", + help='Variables for the execution context. For example: `-D "="`. ' + "For SQL files variables are use to expand the template and any unknown variable will cause an error. " + "For Python files variables are used to update os.environ dictionary. Provided keys are capitalized to adhere to best practices." + "In case of SQL files string values must be quoted in `''` (consider embedding quoting in the file).", show_default=False, ) @@ -612,10 +615,11 @@ def __init__(self, key: str, value: str): def parse_key_value_variables(variables: Optional[List[str]]) -> List[Variable]: """Util for parsing key=value input. Useful for commands accepting multiple input options.""" + if not variables: + return [] result: List[Variable] = [] if not variables: return result - for p in variables: if "=" not in p: raise ClickException(f"Invalid variable: '{p}'") diff --git a/src/snowflake/cli/api/identifiers.py b/src/snowflake/cli/api/identifiers.py index b9f5c37c95..886cbc72e4 100644 --- a/src/snowflake/cli/api/identifiers.py +++ b/src/snowflake/cli/api/identifiers.py @@ -53,11 +53,17 @@ def name(self) -> str: return self._name @property - def identifier(self) -> str: + def prefix(self) -> str: if self.database: - return f"{self.database}.{self.schema if self.schema else 'PUBLIC'}.{self.name}" + return f"{self.database}.{self.schema if self.schema else 'PUBLIC'}" if self.schema: - return f"{self.schema}.{self.name}" + return f"{self.schema}" + return "" + + @property + def identifier(self) -> str: + if self.prefix: + return f"{self.prefix}.{self.name}" return self.name @property @@ -96,6 +102,13 @@ def from_string(cls, identifier: str) -> "FQN": unqualified_name = unqualified_name + signature return cls(name=unqualified_name, schema=schema, database=database) + @classmethod + def from_stage(cls, stage: str) -> "FQN": + name = stage + if stage.startswith("@"): + name = stage[1:] + return cls.from_string(name) + @classmethod def from_identifier_model(cls, model: ObjectIdentifierBaseModel) -> "FQN": """Create an instance from object model.""" diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index cb7b89a3fb..cb81cc3865 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -41,12 +41,22 @@ class SqlExecutionMixin: def __init__(self): - pass + self._snowpark_session = None @property def _conn(self): return cli_context.connection + @property + def snowpark_session(self): + if not self._snowpark_session: + from snowflake.snowpark.session import Session + + self._snowpark_session = Session.builder.configs( + {"connection": self._conn} + ).create() + return self._snowpark_session + @cached_property def _log(self): return logging.getLogger(__name__) diff --git a/src/snowflake/cli/app/snow_connector.py b/src/snowflake/cli/app/snow_connector.py index d2806b7cd5..b7b5432307 100644 --- a/src/snowflake/cli/app/snow_connector.py +++ b/src/snowflake/cli/app/snow_connector.py @@ -21,7 +21,12 @@ import snowflake.connector from click.exceptions import ClickException -from snowflake.cli.api.config import get_connection_dict, get_default_connection_dict +from snowflake.cli.api.cli_global_context import cli_context +from snowflake.cli.api.config import ( + get_connection_dict, + get_default_connection_dict, + get_default_connection_name, +) from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB from snowflake.cli.api.exceptions import ( InvalidConnectionConfiguration, @@ -70,6 +75,9 @@ def connect_to_snowflake( connection_parameters = {} # we will apply overrides in next step else: connection_parameters = get_default_connection_dict() + cli_context.connection_context.set_connection_name( + get_default_connection_name() + ) # Apply overrides to connection details for key, value in overrides.items(): diff --git a/src/snowflake/cli/plugins/cortex/commands.py b/src/snowflake/cli/plugins/cortex/commands.py index 2b36ae0416..0f866786b4 100644 --- a/src/snowflake/cli/plugins/cortex/commands.py +++ b/src/snowflake/cli/plugins/cortex/commands.py @@ -24,6 +24,7 @@ from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.commands.flags import readable_file_option from snowflake.cli.api.commands.snow_typer import SnowTyperFactory +from snowflake.cli.api.constants import PYTHON_3_12 from snowflake.cli.api.output.types import ( CollectionResult, CommandResult, @@ -45,7 +46,7 @@ help="Provides access to Snowflake Cortex.", ) -SEARCH_COMMAND_ENABLED = sys.version_info < (3, 12) +SEARCH_COMMAND_ENABLED = sys.version_info < PYTHON_3_12 @app.command( diff --git a/src/snowflake/cli/plugins/snowpark/commands.py b/src/snowflake/cli/plugins/snowpark/commands.py index ee2eef7097..e4d84e3636 100644 --- a/src/snowflake/cli/plugins/snowpark/commands.py +++ b/src/snowflake/cli/plugins/snowpark/commands.py @@ -67,7 +67,8 @@ from snowflake.cli.plugins.object.manager import ObjectManager from snowflake.cli.plugins.snowpark import package_utils from snowflake.cli.plugins.snowpark.common import ( - build_udf_sproc_identifier, + FunctionOrProcedure, + UdfSprocIdentifier, check_if_replace_is_required, ) from snowflake.cli.plugins.snowpark.manager import FunctionManager, ProcedureManager @@ -220,7 +221,7 @@ def deploy( def _assert_object_definitions_are_correct( - object_type, object_definitions: List[FunctionSchema | ProcedureSchema] + object_type, object_definitions: List[FunctionOrProcedure] ): for definition in object_definitions: database = definition.database @@ -239,14 +240,14 @@ def _assert_object_definitions_are_correct( def _find_existing_objects( object_type: ObjectType, - objects: List[Dict], + objects: List[FunctionOrProcedure], om: ObjectManager, ): existing_objects = {} for object_definition in objects: - identifier = build_udf_sproc_identifier( - object_definition, om, include_parameter_names=False - ) + identifier = UdfSprocIdentifier.from_definition( + object_definition + ).identifier_with_arg_types try: current_state = om.describe( object_type=object_type.value.sf_name, @@ -295,21 +296,17 @@ def get_app_stage_path(stage_name: Optional[str], project_name: str) -> str: def _deploy_single_object( manager: FunctionManager | ProcedureManager, object_type: ObjectType, - object_definition: FunctionSchema | ProcedureSchema, + object_definition: FunctionOrProcedure, existing_objects: Dict[str, Dict], snowflake_dependencies: List[str], stage_artifact_path: str, ): - identifier = build_udf_sproc_identifier( - object_definition, manager, include_parameter_names=False - ) - identifier_with_default_values = build_udf_sproc_identifier( - object_definition, - manager, - include_parameter_names=True, - include_default_values=True, + + identifiers = UdfSprocIdentifier.from_definition(object_definition) + + log.info( + "Deploying %s: %s", object_type, identifiers.identifier_with_arg_names_types ) - log.info("Deploying %s: %s", object_type, identifier_with_default_values) handler = object_definition.handler returns = object_definition.returns @@ -317,11 +314,11 @@ def _deploy_single_object( external_access_integrations = object_definition.external_access_integrations replace_object = False - object_exists = identifier in existing_objects + object_exists = identifiers.identifier_with_arg_types in existing_objects if object_exists: replace_object = check_if_replace_is_required( object_type=object_type, - current_state=existing_objects[identifier], + current_state=existing_objects[identifiers.identifier_with_arg_types], handler=handler, return_type=returns, snowflake_dependencies=snowflake_dependencies, @@ -332,13 +329,13 @@ def _deploy_single_object( if object_exists and not replace_object: return { - "object": identifier_with_default_values, + "object": identifiers.identifier_with_arg_names_types_defaults, "type": str(object_type), "status": "packages updated", } create_or_replace_kwargs = { - "identifier": identifier_with_default_values, + "identifier": identifiers, "handler": handler, "return_type": returns, "artifact_file": stage_artifact_path, @@ -356,7 +353,7 @@ def _deploy_single_object( status = "created" if not object_exists else "definition updated" return { - "object": identifier_with_default_values, + "object": identifiers.identifier_with_arg_names_types_defaults, "type": str(object_type), "status": status, } diff --git a/src/snowflake/cli/plugins/snowpark/common.py b/src/snowflake/cli/plugins/snowpark/common.py index 1b8fb1e81e..bb80ee060c 100644 --- a/src/snowflake/cli/plugins/snowpark/common.py +++ b/src/snowflake/cli/plugins/snowpark/common.py @@ -15,11 +15,14 @@ from __future__ import annotations import re -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.identifiers import FQN -from snowflake.cli.api.project.schemas.snowpark.argument import Argument +from snowflake.cli.api.project.schemas.snowpark.callable import ( + FunctionSchema, + ProcedureSchema, +) from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.plugins.snowpark.models import Requirement from snowflake.cli.plugins.snowpark.package_utils import ( @@ -28,6 +31,7 @@ from snowflake.connector.cursor import SnowflakeCursor DEFAULT_RUNTIME = "3.8" +FunctionOrProcedure = Union[FunctionSchema, ProcedureSchema] def check_if_replace_is_required( @@ -150,7 +154,7 @@ def artifact_stage_path(identifier: str): def create_query( self, - identifier: str, + identifier: UdfSprocIdentifier, return_type: str, handler: str, artifact_file: str, @@ -166,7 +170,7 @@ def create_query( packages_list = ",".join(f"'{p}'" for p in packages) query = [ - f"create or replace {self._object_type.value.sf_name} {identifier}", + f"create or replace {self._object_type.value.sf_name} {identifier.identifier_for_sql}", f"copy grants", f"returns {return_type}", "language python", @@ -198,30 +202,69 @@ def _is_signature_type_a_string(sig_type: str) -> bool: return sig_type.lower() in ["string", "varchar"] -def build_udf_sproc_identifier( - udf_sproc, - slq_exec_mixin, - include_parameter_names, - include_default_values=False, -): - def format_arg(arg: Argument): - result = f"{arg.arg_type}" - if include_parameter_names: - result = f"{arg.name} {result}" - if include_default_values and arg.default: - val = f"{arg.default}" - if _is_signature_type_a_string(arg.arg_type): - val = f"'{val}'" - result += f" default {val}" - return result - - if udf_sproc.signature and udf_sproc.signature != "null": - arguments = ", ".join(format_arg(arg) for arg in udf_sproc.signature) - else: - arguments = "" +class UdfSprocIdentifier: + def __init__(self, identifier: FQN, arg_names, arg_types, arg_defaults): + self._identifier = identifier + self._arg_names = arg_names + self._arg_types = arg_types + self._arg_defaults = arg_defaults + + def _identifier_from_signature(self, sig: List[str], for_sql: bool = False): + signature = self._comma_join(sig) + id_ = self._identifier.sql_identifier if for_sql else self._identifier + return f"{id_}({signature})" + + @staticmethod + def _comma_join(*args): + return ", ".join(*args) + + @property + def identifier_with_arg_names(self): + return self._identifier_from_signature(self._arg_names) - name = FQN.from_identifier_model(udf_sproc).using_context().identifier - return f"{name}({arguments})" + @property + def identifier_with_arg_types(self): + return self._identifier_from_signature(self._arg_types) + + @property + def identifier_with_arg_names_types(self): + sig = [f"{n} {t}" for n, t in zip(self._arg_names, self._arg_types)] + return self._identifier_from_signature(sig) + + @property + def identifier_with_arg_names_types_defaults(self): + return self._identifier_from_signature(self._full_signature()) + + def _full_signature(self): + sig = [] + for name, _type, _default in zip( + self._arg_names, self._arg_types, self._arg_defaults + ): + s = f"{name} {_type}" + if _default: + if _is_signature_type_a_string(_type): + _default = f"'{_default}'" + s += f" default {_default}" + sig.append(s) + return sig + + @property + def identifier_for_sql(self): + return self._identifier_from_signature(self._full_signature(), for_sql=True) + + @classmethod + def from_definition(cls, udf_sproc: FunctionOrProcedure): + names = [] + types = [] + defaults = [] + if udf_sproc.signature and udf_sproc.signature != "null": + for arg in udf_sproc.signature: + names.append(arg.name) + types.append(arg.arg_type) + defaults.append(arg.default) + + identifier = FQN.from_identifier_model(udf_sproc).using_context() + return cls(identifier, names, types, defaults) def _compare_imports( diff --git a/src/snowflake/cli/plugins/snowpark/manager.py b/src/snowflake/cli/plugins/snowpark/manager.py index 06464fee9f..8929db14e5 100644 --- a/src/snowflake/cli/plugins/snowpark/manager.py +++ b/src/snowflake/cli/plugins/snowpark/manager.py @@ -18,7 +18,10 @@ from typing import Dict, List, Optional from snowflake.cli.api.constants import ObjectType -from snowflake.cli.plugins.snowpark.common import SnowparkObjectManager +from snowflake.cli.plugins.snowpark.common import ( + SnowparkObjectManager, + UdfSprocIdentifier, +) from snowflake.connector.cursor import SnowflakeCursor log = logging.getLogger(__name__) @@ -35,7 +38,7 @@ def _object_execute(self): def create_or_replace( self, - identifier: str, + identifier: UdfSprocIdentifier, return_type: str, handler: str, artifact_file: str, @@ -45,7 +48,11 @@ def create_or_replace( secrets: Optional[Dict[str, str]] = None, runtime: Optional[str] = None, ) -> SnowflakeCursor: - log.debug("Creating function %s using @%s", identifier, artifact_file) + log.debug( + "Creating function %s using @%s", + identifier.identifier_with_arg_names_types_defaults, + artifact_file, + ) query = self.create_query( identifier, return_type, @@ -71,7 +78,7 @@ def _object_execute(self): def create_or_replace( self, - identifier: str, + identifier: UdfSprocIdentifier, return_type: str, handler: str, artifact_file: str, @@ -82,7 +89,11 @@ def create_or_replace( runtime: Optional[str] = None, execute_as_caller: bool = False, ) -> SnowflakeCursor: - log.debug("Creating procedure %s using @%s", identifier, artifact_file) + log.debug( + "Creating procedure %s using @%s", + identifier.identifier_with_arg_names_types_defaults, + artifact_file, + ) query = self.create_query( identifier, return_type, diff --git a/src/snowflake/cli/plugins/snowpark/models.py b/src/snowflake/cli/plugins/snowpark/models.py index 7480196f84..454976e34c 100644 --- a/src/snowflake/cli/plugins/snowpark/models.py +++ b/src/snowflake/cli/plugins/snowpark/models.py @@ -33,6 +33,10 @@ class YesNoAsk(Enum): class Requirement(requirement.Requirement): extra_pattern = re.compile("'([^']*)'") + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.package_name = None + @classmethod def parse_line(cls, line: str) -> Requirement: if len(line_elements := line.split(";")) > 1: @@ -44,6 +48,8 @@ def parse_line(cls, line: str) -> Requirement: if "extra" in element and (extras := cls.extra_pattern.search(element)): result.extras.extend(extras.groups()) + result.package_name = result.name + if result.uri and not result.name: result.name = get_package_name(result.uri) result.name = cls.standardize_name(result.name) diff --git a/src/snowflake/cli/plugins/stage/manager.py b/src/snowflake/cli/plugins/stage/manager.py index f4dd71208c..18c9d94afb 100644 --- a/src/snowflake/cli/plugins/stage/manager.py +++ b/src/snowflake/cli/plugins/stage/manager.py @@ -18,28 +18,44 @@ import glob import logging import re +import sys from contextlib import nullcontext from dataclasses import dataclass from os import path from pathlib import Path +from textwrap import dedent from typing import Dict, List, Optional, Union from click import ClickException -from snowflake.cli.api.commands.flags import OnErrorType, parse_key_value_variables +from snowflake.cli.api.commands.flags import ( + OnErrorType, + Variable, + parse_key_value_variables, +) from snowflake.cli.api.console import cli_console +from snowflake.cli.api.constants import PYTHON_3_12 +from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.util import to_string_literal from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.api.utils.path_utils import path_resolver +from snowflake.cli.plugins.snowpark.package_utils import parse_requirements from snowflake.connector import DictCursor, ProgrammingError from snowflake.connector.cursor import SnowflakeCursor +if sys.version_info < PYTHON_3_12: + # Because Snowpark works only below 3.12 and to use @sproc Session must be imported here. + from snowflake.snowpark import Session + log = logging.getLogger(__name__) UNQUOTED_FILE_URI_REGEX = r"[\w/*?\-.=&{}$#[\]\"\\!@%^+:]+" -EXECUTE_SUPPORTED_FILES_FORMATS = {".sql"} USER_STAGE_PREFIX = "@~" +EXECUTE_SUPPORTED_FILES_FORMATS = ( + ".sql", + ".py", +) # tuple to preserve order but it's a set @dataclass @@ -62,6 +78,11 @@ def add_stage_prefix(self, file_path: str) -> str: def get_directory_from_file_path(self, file_path: str) -> List[str]: raise NotImplementedError + def get_full_stage_path(self, path: str): + if prefix := FQN.from_stage(self.stage).prefix: + return prefix + "." + path + return path + @dataclass class DefaultStagePathParts(StagePathParts): @@ -129,6 +150,10 @@ def get_directory_from_file_path(self, file_path: str) -> List[str]: class StageManager(SqlExecutionMixin): + def __init__(self): + super().__init__() + self._python_exe_procedure = None + @staticmethod def get_standard_stage_prefix(name: str) -> str: # Handle embedded stages @@ -290,22 +315,40 @@ def execute( filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f)) ) - sql_variables = self._parse_execute_variables(variables) + parsed_variables = parse_key_value_variables(variables) + sql_variables = self._parse_execute_variables(parsed_variables) + python_variables = {str(v.key): v.value for v in parsed_variables} results = [] + + if any(file.endswith(".py") for file in sorted_file_path_list): + self._python_exe_procedure = self._bootstrap_snowpark_execution_environment( + stage_path_parts + ) + for file_path in sorted_file_path_list: - results.append( - self._call_execute_immediate( - stage_path_parts=stage_path_parts, - file_path=file_path, + file_stage_path = stage_path_parts.add_stage_prefix(file_path) + if file_path.endswith(".py"): + result = self._execute_python( + file_stage_path=file_stage_path, + on_error=on_error, + variables=python_variables, + ) + else: + result = self._call_execute_immediate( + file_stage_path=file_stage_path, variables=sql_variables, on_error=on_error, ) - ) + results.append(result) return results - def _get_files_list_from_stage(self, stage_path_parts: StagePathParts) -> List[str]: - files_list_result = self.list_files(stage_path_parts.stage).fetchall() + def _get_files_list_from_stage( + self, stage_path_parts: StagePathParts, pattern: str | None = None + ) -> List[str]: + files_list_result = self.list_files( + stage_path_parts.stage, pattern=pattern + ).fetchall() if not files_list_result: raise ClickException(f"No files found on stage '{stage_path_parts.stage}'") @@ -327,9 +370,8 @@ def _filter_files_list( return filtered_files else: raise ClickException( - "Invalid file extension, only `.sql` files are allowed." + f"Invalid file extension, only {', '.join(EXECUTE_SUPPORTED_FILES_FORMATS)} files are allowed." ) - # Filter with fnmatch if contains `*` or `?` if glob.has_magic(stage_path): filtered_files = fnmatch.filter(files_on_stage, stage_path) @@ -343,34 +385,42 @@ def _filter_supported_files(files: List[str]) -> List[str]: return [f for f in files if Path(f).suffix in EXECUTE_SUPPORTED_FILES_FORMATS] @staticmethod - def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]: + def _parse_execute_variables(variables: List[Variable]) -> Optional[str]: if not variables: return None - - parsed_variables = parse_key_value_variables(variables) - query_parameters = [f"{v.key}=>{v.value}" for v in parsed_variables] + query_parameters = [f"{v.key}=>{v.value}" for v in variables] return f" using ({', '.join(query_parameters)})" + @staticmethod + def _success_result(file: str): + cli_console.warning(f"SUCCESS - {file}") + return {"File": file, "Status": "SUCCESS", "Error": None} + + @staticmethod + def _error_result(file: str, msg: str): + cli_console.warning(f"FAILURE - {file}") + return {"File": file, "Status": "FAILURE", "Error": msg} + + @staticmethod + def _handle_execution_exception(on_error: OnErrorType, exception: Exception): + if on_error == OnErrorType.BREAK: + raise exception + def _call_execute_immediate( self, - stage_path_parts: StagePathParts, - file_path: str, + file_stage_path: str, variables: Optional[str], on_error: OnErrorType, ) -> Dict: - file_stage_path = stage_path_parts.add_stage_prefix(file_path) try: query = f"execute immediate from {file_stage_path}" if variables: query += variables self._execute_query(query) - cli_console.step(f"SUCCESS - {file_stage_path}") - return {"File": file_stage_path, "Status": "SUCCESS", "Error": None} + return StageManager._success_result(file=file_stage_path) except ProgrammingError as e: - cli_console.warning(f"FAILURE - {file_stage_path}") - if on_error == OnErrorType.BREAK: - raise e - return {"File": file_stage_path, "Status": "FAILURE", "Error": e.msg} + StageManager._handle_execution_exception(on_error=on_error, exception=e) + return StageManager._error_result(file=file_stage_path, msg=e.msg) @staticmethod def _stage_path_part_factory(stage_path: str) -> StagePathParts: @@ -378,3 +428,102 @@ def _stage_path_part_factory(stage_path: str) -> StagePathParts: if stage_path.startswith(USER_STAGE_PREFIX): return UserStagePathParts(stage_path) return DefaultStagePathParts(stage_path) + + def _check_for_requirements_file( + self, stage_path_parts: StagePathParts + ) -> List[str]: + """Looks for requirements.txt file on stage.""" + req_files_on_stage = self._get_files_list_from_stage( + stage_path_parts, pattern=r".*requirements\.txt$" + ) + if not req_files_on_stage: + return [] + + # Construct all possible path for requirements file for this context + # We don't use os.path or pathlib to preserve compatibility on Windows + req_file_name = "requirements.txt" + path_parts = stage_path_parts.path.split("/") + possible_req_files = [] + + while path_parts: + current_file = "/".join([*path_parts, req_file_name]) + possible_req_files.append(str(current_file)) + path_parts = path_parts[:-1] + + # Now for every possible path check if the file exists on stage, + # if yes break, we use the first possible file + requirements_file = None + for req_file in possible_req_files: + if req_file in req_files_on_stage: + requirements_file = req_file + break + + # If we haven't found any matching requirements + if requirements_file is None: + return [] + + # req_file at this moment is the first found requirements file + with SecurePath.temporary_directory() as tmp_dir: + self.get( + stage_path_parts.get_full_stage_path(requirements_file), tmp_dir.path + ) + requirements = parse_requirements( + requirements_file=tmp_dir / "requirements.txt" + ) + + return [req.package_name for req in requirements] + + def _bootstrap_snowpark_execution_environment( + self, stage_path_parts: StagePathParts + ): + """Prepares Snowpark session for executing Python code remotely.""" + if sys.version_info >= PYTHON_3_12: + raise ClickException( + f"Executing python files is not supported in Python >= 3.12. Current version: {sys.version}" + ) + + from snowflake.snowpark.functions import sproc + + self.snowpark_session.add_packages("snowflake-snowpark-python") + self.snowpark_session.add_packages("snowflake.core") + requirements = self._check_for_requirements_file(stage_path_parts) + self.snowpark_session.add_packages(*requirements) + + @sproc(is_permanent=False) + def _python_execution_procedure( + _: Session, file_path: str, variables: Dict | None = None + ) -> None: + """Snowpark session-scoped stored procedure to execute content of provided python file.""" + import json + + from snowflake.snowpark.files import SnowflakeFile + + with SnowflakeFile.open(file_path, require_scoped_url=False) as f: + file_content: str = f.read() # type: ignore + + wrapper = dedent( + f"""\ + import os + os.environ.update({json.dumps(variables)}) + """ + ) + + exec(wrapper + file_content) + + return _python_execution_procedure + + def _execute_python( + self, file_stage_path: str, on_error: OnErrorType, variables: Dict + ): + """ + Executes Python file from stage using a Snowpark temporary procedure. + Currently, there's no option to pass input to the execution. + """ + from snowflake.snowpark.exceptions import SnowparkSQLException + + try: + self._python_exe_procedure(self.get_standard_stage_prefix(file_stage_path), variables) # type: ignore + return StageManager._success_result(file=file_stage_path) + except SnowparkSQLException as e: + StageManager._handle_execution_exception(on_error=on_error, exception=e) + return StageManager._error_result(file=file_stage_path, msg=e.message) diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 7118d13fc6..ab73ffa92e 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -1446,9 +1446,18 @@ | --on-error [break|continue] What to do when an error occurs. | | Defaults to break. | | [default: break] | - | --variable -D TEXT Variables for the template. For | - | example: `-D "="`, string | - | values must be in `''`. | + | --variable -D TEXT Variables for the execution context. | + | For example: `-D "="`. For | + | SQL files variables are use to expand | + | the template and any unknown variable | + | will cause an error. For Python files | + | variables are used to update | + | os.environ dictionary. Provided keys | + | are capitalized to adhere to best | + | practices.In case of SQL files string | + | values must be quoted in `''` | + | (consider embedding quoting in the | + | file). | | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ +- Connection configuration ---------------------------------------------------+ @@ -2023,9 +2032,16 @@ | git repository with templates. | | [default: | | https://github.com/snowflakedb/snowflake-c… | - | --variable -D TEXT Variables for the template. For example: | - | `-D "="`, string values must be | - | in `''`. | + | --variable -D TEXT Variables for the execution context. For | + | example: `-D "="`. For SQL | + | files variables are use to expand the | + | template and any unknown variable will | + | cause an error. For Python files variables | + | are used to update os.environ dictionary. | + | Provided keys are capitalized to adhere to | + | best practices.In case of SQL files string | + | values must be quoted in `''` (consider | + | embedding quoting in the file). | | --no-interactive Disable prompting. | | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ @@ -7208,9 +7224,18 @@ | --on-error [break|continue] What to do when an error occurs. | | Defaults to break. | | [default: break] | - | --variable -D TEXT Variables for the template. For | - | example: `-D "="`, string | - | values must be in `''`. | + | --variable -D TEXT Variables for the execution context. | + | For example: `-D "="`. For | + | SQL files variables are use to expand | + | the template and any unknown variable | + | will cause an error. For Python files | + | variables are used to update | + | os.environ dictionary. Provided keys | + | are capitalized to adhere to best | + | practices.In case of SQL files string | + | values must be quoted in `''` | + | (consider embedding quoting in the | + | file). | | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ +- Connection configuration ---------------------------------------------------+ diff --git a/tests/snowpark/test_function.py b/tests/snowpark/test_function.py index 60ad55849a..ec7218b546 100644 --- a/tests/snowpark/test_function.py +++ b/tests/snowpark/test_function.py @@ -57,7 +57,7 @@ def test_deploy_function( f" auto_compress=false parallel=4 overwrite=True", dedent( """\ - create or replace function MockDatabase.MockSchema.func1(a string default 'default value', b variant) + create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string default 'default value', b variant) copy grants returns string language python @@ -105,7 +105,7 @@ def test_deploy_function_with_external_access( f" auto_compress=false parallel=4 overwrite=True", dedent( """\ - create or replace function MockDatabase.MockSchema.func1(a string, b variant) + create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string, b variant) copy grants returns string language python @@ -225,7 +225,7 @@ def test_deploy_function_needs_update_because_packages_changes( f"put file://{Path(project_dir).resolve()}/app.zip @MockDatabase.MockSchema.dev_deployment/my_snowpark_project auto_compress=false parallel=4 overwrite=True", dedent( """\ - create or replace function MockDatabase.MockSchema.func1(a string default 'default value', b variant) + create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string default 'default value', b variant) copy grants returns string language python @@ -276,7 +276,7 @@ def test_deploy_function_needs_update_because_handler_changes( f" auto_compress=false parallel=4 overwrite=True", dedent( """\ - create or replace function MockDatabase.MockSchema.func1(a string default 'default value', b variant) + create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string default 'default value', b variant) copy grants returns string language python diff --git a/tests/snowpark/test_procedure.py b/tests/snowpark/test_procedure.py index a1fe3c9c25..221b3a8928 100644 --- a/tests/snowpark/test_procedure.py +++ b/tests/snowpark/test_procedure.py @@ -79,7 +79,7 @@ def test_deploy_procedure( f"put file://{Path(tmp).resolve()}/app.zip @MockDatabase.MockSchema.dev_deployment/my_snowpark_project auto_compress=false parallel=4 overwrite=True", dedent( """\ - create or replace procedure MockDatabase.MockSchema.procedureName(name string) + create or replace procedure IDENTIFIER('MockDatabase.MockSchema.procedureName')(name string) copy grants returns string language python @@ -91,7 +91,7 @@ def test_deploy_procedure( ).strip(), dedent( """\ - create or replace procedure MockDatabase.MockSchema.test() + create or replace procedure IDENTIFIER('MockDatabase.MockSchema.test')() copy grants returns string language python @@ -149,7 +149,7 @@ def test_deploy_procedure_with_external_access( f" auto_compress=false parallel=4 overwrite=True", dedent( """\ - create or replace procedure MockDatabase.MockSchema.procedureName(name string) + create or replace procedure IDENTIFIER('MockDatabase.MockSchema.procedureName')(name string) copy grants returns string language python diff --git a/tests/stage/__snapshots__/test_stage.ambr b/tests/stage/__snapshots__/test_stage.ambr index a4b189c607..823f08b8ca 100644 --- a/tests/stage/__snapshots__/test_stage.ambr +++ b/tests/stage/__snapshots__/test_stage.ambr @@ -259,16 +259,20 @@ # --- # name: test_execute_continue_on_error ''' + SUCCESS - @exe/p1.py + FAILURE - @exe/p2.py SUCCESS - @exe/s1.sql FAILURE - @exe/s2.sql SUCCESS - @exe/s3.sql - +-------------------------------+ - | File | Status | Error | - |-------------+---------+-------| - | @exe/s1.sql | SUCCESS | None | - | @exe/s2.sql | FAILURE | Error | - | @exe/s3.sql | SUCCESS | None | - +-------------------------------+ + +------------------------------------+ + | File | Status | Error | + |-------------+---------+------------| + | @exe/p1.py | SUCCESS | None | + | @exe/p2.py | FAILURE | Test error | + | @exe/s1.sql | SUCCESS | None | + | @exe/s2.sql | FAILURE | Error | + | @exe/s3.sql | SUCCESS | None | + +------------------------------------+ ''' # --- @@ -336,7 +340,7 @@ # name: test_execute_raise_invalid_file_extension_error ''' +- Error ----------------------------------------------------------------------+ - | Invalid file extension, only `.sql` files are allowed. | + | Invalid file extension, only .sql, .py files are allowed. | +------------------------------------------------------------------------------+ ''' diff --git a/tests/stage/test_stage.py b/tests/stage/test_stage.py index 49bdb4d707..395a2f2918 100644 --- a/tests/stage/test_stage.py +++ b/tests/stage/test_stage.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import sys from pathlib import Path from tempfile import TemporaryDirectory from unittest import mock @@ -30,6 +30,10 @@ STAGE_MANAGER = "snowflake.cli.plugins.stage.manager.StageManager" +skip_python_3_12 = pytest.mark.skipif( + sys.version_info >= (3, 12), reason="Snowpark is not supported in Python >= 3.12" +) + @mock.patch(f"{STAGE_MANAGER}._execute_query") def test_stage_list(mock_execute, runner, mock_cursor): @@ -848,8 +852,12 @@ def test_execute_from_user_stage( @mock.patch(f"{STAGE_MANAGER}._execute_query") -def test_execute_with_variables(mock_execute, mock_cursor, runner): - mock_execute.return_value = mock_cursor([{"name": "exe/s1.sql"}], []) +@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +@skip_python_3_12 +def test_execute_with_variables(mock_bootstrap, mock_execute, mock_cursor, runner): + mock_execute.return_value = mock_cursor( + [{"name": "exe/s1.sql"}, {"name": "exe/s2.py"}], [] + ) result = runner.invoke( [ @@ -861,7 +869,7 @@ def test_execute_with_variables(mock_execute, mock_cursor, runner): "-D", "key2=1", "-D", - "key3=TRUE", + "KEY3=TRUE", "-D", "key4=NULL", "-D", @@ -873,9 +881,19 @@ def test_execute_with_variables(mock_execute, mock_cursor, runner): assert mock_execute.mock_calls == [ mock.call("ls @exe", cursor_class=DictCursor), mock.call( - f"execute immediate from @exe/s1.sql using (key1=>'string value', key2=>1, key3=>TRUE, key4=>NULL, key5=>'var=value')" + f"execute immediate from @exe/s1.sql using (key1=>'string value', key2=>1, KEY3=>TRUE, key4=>NULL, key5=>'var=value')" ), ] + mock_bootstrap.return_value.assert_called_once_with( + "@exe/s2.py", + { + "key1": "'string value'", + "key2": "1", + "KEY3": "TRUE", + "key4": "NULL", + "key5": "'var=value'", + }, + ) @mock.patch(f"{STAGE_MANAGER}._execute_query") @@ -966,13 +984,17 @@ def test_execute_no_files_for_stage_path( @mock.patch(f"{STAGE_MANAGER}._execute_query") -def test_execute_stop_on_error(mock_execute, mock_cursor, runner): +@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +@skip_python_3_12 +def test_execute_stop_on_error(mock_bootstrap, mock_execute, mock_cursor, runner): error_message = "Error" mock_execute.side_effect = [ mock_cursor( [ {"name": "exe/s1.sql"}, + {"name": "exe/p1.py"}, {"name": "exe/s2.sql"}, + {"name": "exe/p2.py"}, {"name": "exe/s3.sql"}, ], [], @@ -989,18 +1011,28 @@ def test_execute_stop_on_error(mock_execute, mock_cursor, runner): mock.call(f"execute immediate from @exe/s1.sql"), mock.call(f"execute immediate from @exe/s2.sql"), ] + assert mock_bootstrap.return_value.mock_calls == [ + mock.call("@exe/p1.py", {}), + mock.call("@exe/p2.py", {}), + ] assert e.value.msg == error_message @mock.patch(f"{STAGE_MANAGER}._execute_query") +@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment") +@skip_python_3_12 def test_execute_continue_on_error( - mock_execute, mock_cursor, runner, os_agnostic_snapshot + mock_bootstrap, mock_execute, mock_cursor, runner, os_agnostic_snapshot ): + from snowflake.snowpark.exceptions import SnowparkSQLException + mock_execute.side_effect = [ mock_cursor( [ {"name": "exe/s1.sql"}, + {"name": "exe/p1.py"}, {"name": "exe/s2.sql"}, + {"name": "exe/p2.py"}, {"name": "exe/s3.sql"}, ], [], @@ -1010,6 +1042,8 @@ def test_execute_continue_on_error( mock_cursor([{"3": 3}], []), ] + mock_bootstrap.return_value.side_effect = ["ok", SnowparkSQLException("Test error")] + result = runner.invoke(["stage", "execute", "exe", "--on-error", "continue"]) assert result.exit_code == 0 @@ -1021,6 +1055,11 @@ def test_execute_continue_on_error( mock.call(f"execute immediate from @exe/s3.sql"), ] + assert mock_bootstrap.return_value.mock_calls == [ + mock.call("@exe/p1.py", {}), + mock.call("@exe/p2.py", {}), + ] + @mock.patch("snowflake.connector.connect") @pytest.mark.parametrize( @@ -1043,3 +1082,58 @@ def test_command_aliases(mock_connector, runner, mock_ctx, command, parameters): queries = ctx.get_queries() assert queries[0] == queries[1] + + +@pytest.mark.parametrize( + "files, selected, packages", + [ + ([], None, []), + (["my_stage/dir/parallel/requirements.txt"], None, []), + ( + ["my_stage/dir/files/requirements.txt"], + "db.schema.my_stage/dir/files/requirements.txt", + ["aaa", "bbb"], + ), + ( + [ + "my_stage/requirements.txt", + "my_stage/dir/requirements.txt", + "my_stage/dir/files/requirements.txt", + ], + "db.schema.my_stage/dir/files/requirements.txt", + ["aaa", "bbb"], + ), + ( + ["my_stage/requirements.txt"], + "db.schema.my_stage/requirements.txt", + ["aaa", "bbb"], + ), + ], +) +@pytest.mark.parametrize( + "input_path", ["@db.schema.my_stage/dir/files", "@db.schema.my_stage/dir/files/"] +) +def test_stage_manager_check_for_requirements_file( + files, selected, packages, input_path +): + class _MockGetter: + def __init__(self): + self.download_file = None + + def __call__(self, file_on_stage, target_dir): + self.download_file = file_on_stage + (target_dir / "requirements.txt").write_text("\n".join(packages)) + + get_mock = _MockGetter() + sm = StageManager() + with mock.patch.object( + sm, "_get_files_list_from_stage", lambda parts, pattern: files + ): + with mock.patch.object(StageManager, "get", get_mock) as get_mock: + result = sm._check_for_requirements_file( # noqa: SLF001 + stage_path_parts=sm._stage_path_part_factory(input_path) # noqa: SLF001 + ) + + assert result == packages + + assert get_mock.download_file == selected diff --git a/tests_integration/__snapshots__/test_stage.ambr b/tests_integration/__snapshots__/test_stage.ambr index f9a443fddd..a3f98c619d 100644 --- a/tests_integration/__snapshots__/test_stage.ambr +++ b/tests_integration/__snapshots__/test_stage.ambr @@ -77,6 +77,20 @@ }), ]) # --- +# name: test_stage_execute_python + list([ + dict({ + 'Error': None, + 'File': '@test_stage_execute/script1.py', + 'Status': 'SUCCESS', + }), + dict({ + 'Error': None, + 'File': '@test_stage_execute/script_template.py', + 'Status': 'SUCCESS', + }), + ]) +# --- # name: test_user_stage_execute list([ dict({ diff --git a/tests_integration/test_connection.py b/tests_integration/test_connection.py index e6bc70b2f7..4128c2e1fc 100644 --- a/tests_integration/test_connection.py +++ b/tests_integration/test_connection.py @@ -26,6 +26,7 @@ def test_connection_test_simple(runner): result = runner.invoke_with_connection_json(["connection", "test"]) assert result.exit_code == 0, result.output + assert result.json["Connection name"] == "integration" assert result.json["Status"] == "OK" diff --git a/tests_integration/test_data/projects/stage_execute/requirements.txt b/tests_integration/test_data/projects/stage_execute/requirements.txt new file mode 100644 index 0000000000..745306e644 --- /dev/null +++ b/tests_integration/test_data/projects/stage_execute/requirements.txt @@ -0,0 +1,4 @@ +scikit-learn + +# comment +matplotlib diff --git a/tests_integration/test_data/projects/stage_execute/script1.py b/tests_integration/test_data/projects/stage_execute/script1.py new file mode 100644 index 0000000000..c9f9b6de4b --- /dev/null +++ b/tests_integration/test_data/projects/stage_execute/script1.py @@ -0,0 +1 @@ +print("ok") diff --git a/tests_integration/test_data/projects/stage_execute/script_template.py b/tests_integration/test_data/projects/stage_execute/script_template.py new file mode 100644 index 0000000000..d1f7c80499 --- /dev/null +++ b/tests_integration/test_data/projects/stage_execute/script_template.py @@ -0,0 +1,19 @@ +import os +from snowflake.core import Root +from snowflake.core.database import DatabaseResource +from snowflake.core.schema import Schema +from snowflake.snowpark.session import Session + +session = Session.builder.getOrCreate() +database: DatabaseResource = Root(session).databases[os.environ["test_database_name"]] + +assert database.name.upper() == os.environ["test_database_name"].upper() + +# Make a side effect that we can check in tests +database.schemas.create(Schema(name=os.environ["TEST_ID"])) + +# Check if an external dependency works +from sklearn import show_versions +import matplotlib + +show_versions() diff --git a/tests_integration/test_snowpark.py b/tests_integration/test_snowpark.py index e6cfe7e08a..31d5a62f4a 100644 --- a/tests_integration/test_snowpark.py +++ b/tests_integration/test_snowpark.py @@ -14,6 +14,7 @@ from __future__ import annotations +import sys from pathlib import Path from textwrap import dedent @@ -715,6 +716,7 @@ def test_build_skip_version_check( @pytest.mark.integration +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Unknown issues") @pytest.mark.parametrize( "flags", [ @@ -760,6 +762,7 @@ def test_build_with_non_anaconda_dependencies( @pytest.mark.integration +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Unknown issues") def test_build_shared_libraries_error( runner, project_directory, alter_requirements_txt, test_database ): diff --git a/tests_integration/test_stage.py b/tests_integration/test_stage.py index 42d174ebd2..868218bbb7 100644 --- a/tests_integration/test_stage.py +++ b/tests_integration/test_stage.py @@ -14,10 +14,13 @@ import glob import os +import sys import tempfile +import time from pathlib import Path import pytest +from snowflake.connector import DictCursor from tests_integration.test_utils import ( contains_row_with, @@ -330,6 +333,57 @@ def test_user_stage_execute(runner, test_database, test_root_path, snapshot): assert result.json == snapshot +@pytest.mark.integration +@pytest.mark.skipif( + sys.version_info >= (3, 12), reason="Snowpark is not supported in Python >= 3.12" +) +def test_stage_execute_python( + snowflake_session, runner, test_database, test_root_path, snapshot +): + project_path = test_root_path / "test_data/projects/stage_execute" + stage_name = "test_stage_execute" + + result = runner.invoke_with_connection_json(["stage", "create", stage_name]) + assert contains_row_with( + result.json, + {"status": f"Stage area {stage_name.upper()} successfully created."}, + ) + + files = ["script1.py", "script_template.py", "requirements.txt"] + for name in files: + result = runner.invoke_with_connection_json( + [ + "stage", + "copy", + str(Path(project_path) / name), + f"@{stage_name}", + ] + ) + assert result.exit_code == 0, result.output + assert contains_row_with(result.json, {"status": "UPLOADED"}) + + test_id = f"FOO{time.time_ns()}" + result = runner.invoke_with_connection_json( + [ + "stage", + "execute", + f"{stage_name}/", + "-D", + f"test_database_name={test_database}", + "-D", + f"TEST_ID={test_id}", + ] + ) + assert result.exit_code == 0 + assert result.json == snapshot + + # Assert side effect created by executed script + *_, schemas = snowflake_session.execute_string( + f"show schemas like '{test_id}' in database {test_database};" + ) + assert len(list(schemas)) == 1 + + @pytest.mark.integration def test_stage_diff(runner, snowflake_session, test_database, tmp_path, snapshot): stage_name = "test_stage"