diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index cb7b89a3fb..f5f069f0e3 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -107,6 +107,35 @@ def use_role(self, new_role: str): if is_different_role: self._execute_query(f"use role {prev_role}") + @contextmanager + def use_warehouse(self, new_wh: str): + """ + Switches to a different warehouse for a while, then switches back. + This is a no-op if the requested warehouse is already active. + If there is no default warehouse in the account, it will throw an error. + """ + + wh_result = self._execute_query( + f"select current_warehouse()", cursor_class=DictCursor + ).fetchone() + # If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended. + try: + prev_wh = wh_result["CURRENT_WAREHOUSE()"] + except: + prev_wh = None + + # new_wh is not None, and should already be a valid identifier, no additional check is performed here. + is_different_wh = new_wh != prev_wh + try: + if is_different_wh: + self._log.debug("Using warehouse: %s", new_wh) + self.use(object_type=ObjectType.WAREHOUSE, name=new_wh) + yield + finally: + if prev_wh and is_different_wh: + self._log.debug("Switching back to warehouse: %s", prev_wh) + self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh) + def create_password_secret( self, name: str, username: str, password: str ) -> SnowflakeCursor: diff --git a/src/snowflake/cli/plugins/nativeapp/manager.py b/src/snowflake/cli/plugins/nativeapp/manager.py index 276cea9d93..9852f8632e 100644 --- a/src/snowflake/cli/plugins/nativeapp/manager.py +++ b/src/snowflake/cli/plugins/nativeapp/manager.py @@ -17,7 +17,6 @@ import json import os from abc import ABC, abstractmethod -from contextlib import contextmanager from functools import cached_property from pathlib import Path from textwrap import dedent @@ -26,7 +25,6 @@ import jinja2 from click import ClickException from snowflake.cli.api.console import cli_console as cc -from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.errno import ( DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, DOES_NOT_EXIST_OR_NOT_AUTHORIZED, @@ -83,7 +81,6 @@ ) from snowflake.cli.plugins.stage.manager import StageManager from snowflake.connector import ProgrammingError -from snowflake.connector.cursor import DictCursor ApplicationOwnedObject = TypedDict("ApplicationOwnedObject", {"name": str, "type": str}) @@ -221,10 +218,34 @@ def stage_schema(self) -> Optional[str]: def package_warehouse(self) -> Optional[str]: return self.na_project.package_warehouse + def use_package_warehouse(self) -> str: + if self.package_warehouse: + return self.package_warehouse + raise ClickException( + dedent( + f"""\ + Application package warehouse cannot be empty. + Please provide a value for it in your connection information or your project definition file. + """ + ) + ) + @property def application_warehouse(self) -> Optional[str]: return self.na_project.application_warehouse + def use_application_warehouse(self) -> str: + if self.application_warehouse: + return self.application_warehouse + raise ClickException( + dedent( + f"""\ + Application warehouse cannot be empty. + Please provide a value for it in your connection information or your project definition file. + """ + ) + ) + @property def project_identifier(self) -> str: return self.na_project.project_identifier @@ -257,39 +278,6 @@ def app_post_deploy_hooks(self) -> Optional[List[ApplicationPostDeployHook]]: def debug_mode(self) -> bool: return self.na_project.debug_mode - @contextmanager - def use_warehouse(self, new_wh: Optional[str]): - """ - Switches to a different warehouse for a while, then switches back. - This is a no-op if the requested warehouse is already active. - If there is no default warehouse in the account, it will throw an error. - """ - - if new_wh is None: - # The new_wh parameter is an Optional[str] as the project definition attributes are Optional[str], passed directly to this method. - raise ClickException("Requested warehouse cannot be None.") - - wh_result = self._execute_query( - f"select current_warehouse()", cursor_class=DictCursor - ).fetchone() - # If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended. - try: - prev_wh = wh_result["CURRENT_WAREHOUSE()"] - except: - prev_wh = None - - # new_wh is not None, and should already be a valid identifier, no additional check is performed here. - is_different_wh = new_wh != prev_wh - try: - if is_different_wh: - self._log.debug("Using warehouse: %s", new_wh) - self.use(object_type=ObjectType.WAREHOUSE, name=new_wh) - yield - finally: - if prev_wh and is_different_wh: - self._log.debug("Switching back to warehouse: %s", prev_wh) - self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh) - @cached_property def get_app_pkg_distribution_in_snowflake(self) -> str: """ @@ -520,7 +508,7 @@ def _application_object_to_str(self, obj: ApplicationOwnedObject) -> str: def get_snowsight_url(self) -> str: """Returns the URL that can be used to visit this app via Snowsight.""" name = identifier_for_url(self.app_name) - with self.use_warehouse(self.application_warehouse): + with self.use_warehouse(self.use_application_warehouse()): return make_snowsight_url(self._conn, f"/#/apps/application/{name}") def create_app_package(self) -> None: @@ -594,7 +582,7 @@ def _apply_package_scripts(self) -> None: raise InvalidPackageScriptError(relpath, e) # once we're sure all the templates expanded correctly, execute all of them - with self.use_warehouse(self.package_warehouse): + with self.use_warehouse(self.use_package_warehouse()): try: for i, queries in enumerate(queued_queries): cc.step(f"Applying package script: {self.package_scripts[i]}") diff --git a/tests/nativeapp/test_manager.py b/tests/nativeapp/test_manager.py index 8bdc03c3b7..18e053bfb2 100644 --- a/tests/nativeapp/test_manager.py +++ b/tests/nativeapp/test_manager.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import json import os from pathlib import Path from textwrap import dedent +from typing import Optional from unittest import mock from unittest.mock import call import pytest +from click import ClickException from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.cli.plugins.nativeapp.artifacts import BundleMap @@ -80,8 +83,8 @@ } -def _get_na_manager(): - dm = DefinitionManager() +def _get_na_manager(working_dir: Optional[str] = None): + dm = DefinitionManager(working_dir) return NativeAppManager( project_definition=dm.project_definition.native_app, project_root=dm.project_root, @@ -580,21 +583,41 @@ def test_get_existing_app_pkg_info_app_pkg_does_not_exist( assert mock_execute.mock_calls == expected +# With connection warehouse, with PDF warehouse +# Without connection warehouse, with PDF warehouse @mock.patch("snowflake.cli.plugins.connection.util.get_context") @mock.patch("snowflake.cli.plugins.connection.util.get_account") @mock.patch("snowflake.cli.plugins.connection.util.get_snowsight_host") @mock.patch(NATIVEAPP_MANAGER_EXECUTE) @mock_connection() -def test_get_snowsight_url( +@pytest.mark.parametrize( + "warehouse, fallback_warehouse_call, fallback_side_effect", + [ + ( + "MockWarehouse", + [mock.call("use warehouse MockWarehouse")], + [None], + ), + ( + None, + [], + [], + ), + ], +) +def test_get_snowsight_url_with_pdf_warehouse( mock_conn, mock_execute_query, mock_snowsight_host, mock_account, mock_context, + warehouse, + fallback_warehouse_call, + fallback_side_effect, temp_dir, mock_cursor, ): - mock_conn.return_value = MockConnectionCtx() + mock_conn.return_value = MockConnectionCtx(warehouse=warehouse) mock_snowsight_host.return_value = "https://host" mock_context.return_value = "organization" mock_account.return_value = "account" @@ -609,21 +632,82 @@ def test_get_snowsight_url( side_effects, expected = mock_execute_helper( [ ( - mock_cursor([{"CURRENT_WAREHOUSE()": "old_wh"}], []), + mock_cursor([{"CURRENT_WAREHOUSE()": warehouse}], []), mock.call("select current_warehouse()", cursor_class=DictCursor), ), (None, mock.call("use warehouse app_warehouse")), - (None, mock.call("use warehouse old_wh")), ] ) - mock_execute_query.side_effect = side_effects + mock_execute_query.side_effect = side_effects + fallback_side_effect native_app_manager = _get_na_manager() assert ( native_app_manager.get_snowsight_url() == "https://host/organization/account/#/apps/application/MYAPP" ) - assert mock_execute_query.mock_calls == expected + assert mock_execute_query.mock_calls == expected + fallback_warehouse_call + + +# With connection warehouse, without PDF warehouse +# Without connection warehouse, without PDF warehouse +@mock.patch("snowflake.cli.plugins.connection.util.get_context") +@mock.patch("snowflake.cli.plugins.connection.util.get_account") +@mock.patch("snowflake.cli.plugins.connection.util.get_snowsight_host") +@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock_connection() +@pytest.mark.parametrize( + "project_definition_files, warehouse, expected_calls, fallback_side_effect", + [ + ( + "napp_project_1", + "MockWarehouse", + [mock.call("select current_warehouse()", cursor_class=DictCursor)], + [None], + ), + ( + "napp_project_1", + None, + [], + [], + ), + ], + indirect=["project_definition_files"], +) +def test_get_snowsight_url_without_pdf_warehouse( + mock_conn, + mock_execute_query, + mock_snowsight_host, + mock_account, + mock_context, + project_definition_files, + warehouse, + expected_calls, + fallback_side_effect, + mock_cursor, +): + mock_conn.return_value = MockConnectionCtx(warehouse=warehouse) + mock_snowsight_host.return_value = "https://host" + mock_context.return_value = "organization" + mock_account.return_value = "account" + + working_dir: Path = project_definition_files[0].parent + + mock_execute_query.side_effect = [ + mock_cursor([{"CURRENT_WAREHOUSE()": warehouse}], []) + ] + fallback_side_effect + + native_app_manager = _get_na_manager(str(working_dir)) + if warehouse: + assert ( + native_app_manager.get_snowsight_url() + == "https://host/organization/account/#/apps/application/MYAPP_POLLY" + ) + else: + with pytest.raises(ClickException) as err: + native_app_manager.get_snowsight_url() + assert "Application warehouse cannot be empty." in err.value.message + + assert mock_execute_query.mock_calls == expected_calls def test_ensure_correct_owner(): diff --git a/tests/nativeapp/test_package_scripts.py b/tests/nativeapp/test_package_scripts.py index b49ecb5d7f..3fea552439 100644 --- a/tests/nativeapp/test_package_scripts.py +++ b/tests/nativeapp/test_package_scripts.py @@ -17,6 +17,7 @@ from unittest import mock import pytest +from click import ClickException from snowflake.cli.api.errno import ( DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, NO_WAREHOUSE_SELECTED_IN_SESSION, @@ -53,13 +54,13 @@ def _get_na_manager(working_dir): "project_definition_files, expected_calls", [ ( - "napp_project_1", + "napp_project_1", # With connection warehouse, without PDF warehouse [ mock.call("select current_warehouse()", cursor_class=DictCursor), ], ), ( - "napp_project_with_pkg_warehouse", + "napp_project_with_pkg_warehouse", # With connection warehouse, with PDF warehouse [ mock.call("select current_warehouse()", cursor_class=DictCursor), mock.call("use warehouse myapp_pkg_warehouse"), @@ -69,7 +70,7 @@ def _get_na_manager(working_dir): ], indirect=["project_definition_files"], ) -def test_package_scripts( +def test_package_scripts_with_conn_info( mock_conn, mock_execute_query, mock_execute_queries, @@ -115,6 +116,83 @@ def test_package_scripts( ] +# Without connection warehouse, without PDF warehouse +@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock_connection() +@pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) +def test_package_scripts_without_conn_info_throws_error( + mock_conn, + mock_execute_query, + mock_execute_queries, + project_definition_files, + mock_cursor, +): + mock_conn.return_value = MockConnectionCtx(warehouse=None) + working_dir: Path = project_definition_files[0].parent + mock_execute_query.return_value = mock_cursor([{"CURRENT_WAREHOUSE()": None}], []) + native_app_manager = _get_na_manager(str(working_dir)) + with pytest.raises(ClickException) as err: + native_app_manager._apply_package_scripts() # noqa: SLF001 + + assert "Application package warehouse cannot be empty." in err.value.message + assert mock_execute_query.mock_calls == [] + assert mock_execute_queries.mock_calls == [] + + +# Without connection warehouse, with PDF warehouse +@mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) +@mock.patch(NATIVEAPP_MANAGER_EXECUTE) +@mock_connection() +@pytest.mark.parametrize( + "project_definition_files", ["napp_project_with_pkg_warehouse"], indirect=True +) +def test_package_scripts_without_conn_info_succeeds( + mock_conn, + mock_execute_query, + mock_execute_queries, + project_definition_files, + mock_cursor, +): + mock_conn.return_value = MockConnectionCtx(warehouse=None) + working_dir: Path = project_definition_files[0].parent + mock_execute_query.return_value = mock_cursor([{"CURRENT_WAREHOUSE()": None}], []) + native_app_manager = _get_na_manager(str(working_dir)) + native_app_manager._apply_package_scripts() # noqa: SLF001 + + assert mock_execute_query.mock_calls == [ + mock.call("select current_warehouse()", cursor_class=DictCursor), + mock.call("use warehouse myapp_pkg_warehouse"), + ] + assert mock_execute_queries.mock_calls == [ + mock.call( + dedent( + f"""\ + -- package script (1/2) + + create schema if not exists myapp_pkg_polly.my_shared_content; + grant usage on schema myapp_pkg_polly.my_shared_content + to share in application package myapp_pkg_polly; + """ + ) + ), + mock.call( + dedent( + f"""\ + -- package script (2/2) + + create or replace table myapp_pkg_polly.my_shared_content.shared_table ( + col1 number, + col2 varchar + ); + grant select on table myapp_pkg_polly.my_shared_content.shared_table + to share in application package myapp_pkg_polly; + """ + ) + ), + ] + + @mock.patch(NATIVEAPP_MANAGER_EXECUTE_QUERIES) @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_missing_package_script(mock_execute, project_definition_files): diff --git a/tests/testing_utils/fixtures.py b/tests/testing_utils/fixtures.py index 119c0d7e7c..1c8bd7c538 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -109,12 +109,20 @@ def _mock_connection_ctx_factory(cursor=mock_cursor(["row"], []), **kwargs): class MockConnectionCtx(mock.MagicMock): - def __init__(self, cursor=None, role: Optional[str] = "MockRole", *args, **kwargs): + def __init__( + self, + cursor=None, + role: Optional[str] = "MockRole", + warehouse: Optional[str] = "MockWarehouse", + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.queries: List[str] = [] self.cs = cursor self._checkout_count = 0 self._role = role + self._warehouse = warehouse def get_query(self): return "\n".join(self.queries) @@ -124,7 +132,7 @@ def get_queries(self): @property def warehouse(self): - return "MockWarehouse" + return self._warehouse @property def database(self):