diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index 9a22007f2c..cb7b89a3fb 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -99,13 +99,12 @@ def use_role(self, new_role: str): prev_role = role_result["CURRENT_ROLE()"] is_different_role = new_role.lower() != prev_role.lower() if is_different_role: - self._log.debug("Temporarily switching to a different role: %s", new_role) + self._log.debug("Assuming different role: %s", new_role) self._execute_query(f"use role {new_role}") try: yield finally: if is_different_role: - self._log.debug("Switching back to the original role: %s", prev_role) self._execute_query(f"use role {prev_role}") def create_password_secret( diff --git a/src/snowflake/cli/plugins/nativeapp/manager.py b/src/snowflake/cli/plugins/nativeapp/manager.py index 445b35a1f5..276cea9d93 100644 --- a/src/snowflake/cli/plugins/nativeapp/manager.py +++ b/src/snowflake/cli/plugins/nativeapp/manager.py @@ -40,7 +40,6 @@ from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.project.util import ( identifier_for_url, - to_identifier, unquote_identifier, ) from snowflake.cli.api.sql_execution import SqlExecutionMixin @@ -265,61 +264,31 @@ def use_warehouse(self, new_wh: Optional[str]): This is a no-op if the requested warehouse is already active. If there is no default warehouse in the account, it will throw an error. """ + + if new_wh is None: + # The new_wh parameter is an Optional[str] as the project definition attributes are Optional[str], passed directly to this method. + raise ClickException("Requested warehouse cannot be None.") + wh_result = self._execute_query( f"select current_warehouse()", cursor_class=DictCursor ).fetchone() # If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended. - prev_wh = wh_result["CURRENT_WAREHOUSE()"] - - if new_wh is not None: - new_wh_id = to_identifier(new_wh) - - if prev_wh is None and new_wh is None: - raise ClickException( - "Could not find both the connection and the requested warehouse in the Snowflake account for this user." - ) - - elif prev_wh is None: - self._log.debug("Using warehouse: %s", new_wh_id) - self.use(object_type=ObjectType.WAREHOUSE, name=new_wh_id) - try: - yield - finally: - self._log.debug( - "Continuing to use warehouse unless requested otherwise: %s", - new_wh_id, - ) - - elif new_wh is None: - self._log.debug( - "Requested warehouse is empty, continuing to use the connection warehouse: %s", - prev_wh, - ) - # Activate a suspended warehouse, will be a no-op will already active - self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh) - try: - yield - finally: - self._log.debug( - "Continuing to use warehouse unless requested otherwise: %s", - prev_wh, - ) + try: + prev_wh = wh_result["CURRENT_WAREHOUSE()"] + except: + prev_wh = None - else: - is_different_wh = new_wh_id != prev_wh - self._log.debug("Using warehouse: %s", new_wh_id) + # new_wh is not None, and should already be a valid identifier, no additional check is performed here. + is_different_wh = new_wh != prev_wh + try: if is_different_wh: - # Activate the new warehouse - self.use(object_type=ObjectType.WAREHOUSE, name=new_wh_id) - else: - # Activate a suspended warehouse, will be a no-op will already active + self._log.debug("Using warehouse: %s", new_wh) + self.use(object_type=ObjectType.WAREHOUSE, name=new_wh) + yield + finally: + if prev_wh and is_different_wh: + self._log.debug("Switching back to warehouse: %s", prev_wh) self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh) - try: - yield - finally: - if is_different_wh: - self._log.debug("Switching back to warehouse: %s", prev_wh) - self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh) @cached_property def get_app_pkg_distribution_in_snowflake(self) -> str: diff --git a/src/snowflake/cli/plugins/nativeapp/project_model.py b/src/snowflake/cli/plugins/nativeapp/project_model.py index 00602b6261..7b90f0dcde 100644 --- a/src/snowflake/cli/plugins/nativeapp/project_model.py +++ b/src/snowflake/cli/plugins/nativeapp/project_model.py @@ -107,14 +107,18 @@ def package_warehouse(self) -> Optional[str]: if self.definition.package and self.definition.package.warehouse: return to_identifier(self.definition.package.warehouse) else: - return to_identifier(cli_context.connection.warehouse) + if cli_context.connection.warehouse: + return to_identifier(cli_context.connection.warehouse) + return None @cached_property def application_warehouse(self) -> Optional[str]: if self.definition.application and self.definition.application.warehouse: return to_identifier(self.definition.application.warehouse) else: - return to_identifier(cli_context.connection.warehouse) + if cli_context.connection.warehouse: + return to_identifier(cli_context.connection.warehouse) + return None @cached_property def project_identifier(self) -> str: diff --git a/src/snowflake/cli/plugins/nativeapp/run_processor.py b/src/snowflake/cli/plugins/nativeapp/run_processor.py index 4f5897d857..92e620ca26 100644 --- a/src/snowflake/cli/plugins/nativeapp/run_processor.py +++ b/src/snowflake/cli/plugins/nativeapp/run_processor.py @@ -138,7 +138,8 @@ def __init__(self, project_definition: NativeApp, project_root: Path): def _execute_sql_script(self, sql_script_path): """ Executing the SQL script in the provided file path after expanding template variables. - "use warehouse" and "use database" will be executed first if they are set in definition file or in the current connection. + This assumes that a relevant warehouse is already active. + Consequently, "use database" will be executed first if it is set in definition file or in the current connection. """ with open(sql_script_path) as f: sql_script = f.read() diff --git a/tests/nativeapp/test_package_scripts.py b/tests/nativeapp/test_package_scripts.py index 6d749f9d8e..b49ecb5d7f 100644 --- a/tests/nativeapp/test_package_scripts.py +++ b/tests/nativeapp/test_package_scripts.py @@ -56,7 +56,6 @@ def _get_na_manager(working_dir): "napp_project_1", [ mock.call("select current_warehouse()", cursor_class=DictCursor), - mock.call("use warehouse MockWarehouse"), ], ), ( @@ -205,6 +204,7 @@ def test_package_scripts_w_warehouse_access_exception( msg="Object does not exist, or operation cannot be performed.", errno=DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, ), + None, ] mock_conn.return_value = MockConnectionCtx() diff --git a/tests/nativeapp/test_run_processor.py b/tests/nativeapp/test_run_processor.py index baca5d507e..eff91a08f9 100644 --- a/tests/nativeapp/test_run_processor.py +++ b/tests/nativeapp/test_run_processor.py @@ -114,6 +114,10 @@ def test_create_dev_app_w_warehouse_access_exception( ), mock.call("use warehouse app_warehouse"), ), + ( + None, + mock.call("use warehouse old_wh"), + ), (None, mock.call("use role old_role")), ] ) @@ -1110,6 +1114,10 @@ def test_upgrade_app_warehouse_error( ), mock.call("use warehouse app_warehouse"), ), + ( + None, + mock.call("use warehouse old_wh"), + ), (None, mock.call("use role old_role")), ] )