Skip to content

Commit

Permalink
feedback round 4
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bgoel committed Jul 11, 2024
1 parent 5d9c35d commit 03a89f6
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 51 deletions.
29 changes: 29 additions & 0 deletions src/snowflake/cli/api/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,35 @@ def use_role(self, new_role: str):
if is_different_role:
self._execute_query(f"use role {prev_role}")

@contextmanager
def use_warehouse(self, new_wh: str):
"""
Switches to a different warehouse for a while, then switches back.
This is a no-op if the requested warehouse is already active.
If there is no default warehouse in the account, it will throw an error.
"""

wh_result = self._execute_query(
f"select current_warehouse()", cursor_class=DictCursor
).fetchone()
# If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended.
try:
prev_wh = wh_result["CURRENT_WAREHOUSE()"]
except:
prev_wh = None

# new_wh is not None, and should already be a valid identifier, no additional check is performed here.
is_different_wh = new_wh != prev_wh
try:
if is_different_wh:
self._log.debug("Using warehouse: %s", new_wh)
self.use(object_type=ObjectType.WAREHOUSE, name=new_wh)
yield
finally:
if prev_wh and is_different_wh:
self._log.debug("Switching back to warehouse: %s", prev_wh)
self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh)

def create_password_secret(
self, name: str, username: str, password: str
) -> SnowflakeCursor:
Expand Down
64 changes: 26 additions & 38 deletions src/snowflake/cli/plugins/nativeapp/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import json
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import cached_property
from pathlib import Path
from textwrap import dedent
Expand All @@ -26,7 +25,6 @@
import jinja2
from click import ClickException
from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.constants import ObjectType
from snowflake.cli.api.errno import (
DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED,
DOES_NOT_EXIST_OR_NOT_AUTHORIZED,
Expand Down Expand Up @@ -83,7 +81,6 @@
)
from snowflake.cli.plugins.stage.manager import StageManager
from snowflake.connector import ProgrammingError
from snowflake.connector.cursor import DictCursor

ApplicationOwnedObject = TypedDict("ApplicationOwnedObject", {"name": str, "type": str})

Expand Down Expand Up @@ -221,10 +218,34 @@ def stage_schema(self) -> Optional[str]:
def package_warehouse(self) -> Optional[str]:
return self.na_project.package_warehouse

def use_package_warehouse(self) -> str:
if self.package_warehouse:
return self.package_warehouse
raise ClickException(
dedent(
f"""\
Application package warehouse cannot be empty.
Please provide a value for it in your connection information or your project definition file.
"""
)
)

@property
def application_warehouse(self) -> Optional[str]:
return self.na_project.application_warehouse

def use_application_warehouse(self) -> str:
if self.application_warehouse:
return self.application_warehouse
raise ClickException(
dedent(
f"""\
Application warehouse cannot be empty.
Please provide a value for it in your connection information or your project definition file.
"""
)
)

@property
def project_identifier(self) -> str:
return self.na_project.project_identifier
Expand Down Expand Up @@ -257,39 +278,6 @@ def app_post_deploy_hooks(self) -> Optional[List[ApplicationPostDeployHook]]:
def debug_mode(self) -> bool:
return self.na_project.debug_mode

@contextmanager
def use_warehouse(self, new_wh: Optional[str]):
"""
Switches to a different warehouse for a while, then switches back.
This is a no-op if the requested warehouse is already active.
If there is no default warehouse in the account, it will throw an error.
"""

if new_wh is None:
# The new_wh parameter is an Optional[str] as the project definition attributes are Optional[str], passed directly to this method.
raise ClickException("Requested warehouse cannot be None.")

wh_result = self._execute_query(
f"select current_warehouse()", cursor_class=DictCursor
).fetchone()
# If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended.
try:
prev_wh = wh_result["CURRENT_WAREHOUSE()"]
except:
prev_wh = None

# new_wh is not None, and should already be a valid identifier, no additional check is performed here.
is_different_wh = new_wh != prev_wh
try:
if is_different_wh:
self._log.debug("Using warehouse: %s", new_wh)
self.use(object_type=ObjectType.WAREHOUSE, name=new_wh)
yield
finally:
if prev_wh and is_different_wh:
self._log.debug("Switching back to warehouse: %s", prev_wh)
self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh)

@cached_property
def get_app_pkg_distribution_in_snowflake(self) -> str:
"""
Expand Down Expand Up @@ -520,7 +508,7 @@ def _application_object_to_str(self, obj: ApplicationOwnedObject) -> str:
def get_snowsight_url(self) -> str:
"""Returns the URL that can be used to visit this app via Snowsight."""
name = identifier_for_url(self.app_name)
with self.use_warehouse(self.application_warehouse):
with self.use_warehouse(self.use_application_warehouse()):
return make_snowsight_url(self._conn, f"/#/apps/application/{name}")

def create_app_package(self) -> None:
Expand Down Expand Up @@ -594,7 +582,7 @@ def _apply_package_scripts(self) -> None:
raise InvalidPackageScriptError(relpath, e)

# once we're sure all the templates expanded correctly, execute all of them
with self.use_warehouse(self.package_warehouse):
with self.use_warehouse(self.use_package_warehouse()):
try:
for i, queries in enumerate(queued_queries):
cc.step(f"Applying package script: {self.package_scripts[i]}")
Expand Down
100 changes: 92 additions & 8 deletions tests/nativeapp/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import json
import os
from pathlib import Path
from textwrap import dedent
from typing import Optional
from unittest import mock
from unittest.mock import call

import pytest
from click import ClickException
from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.cli.plugins.nativeapp.artifacts import BundleMap
Expand Down Expand Up @@ -80,8 +83,8 @@
}


def _get_na_manager():
dm = DefinitionManager()
def _get_na_manager(working_dir: Optional[str] = None):
dm = DefinitionManager(working_dir)
return NativeAppManager(
project_definition=dm.project_definition.native_app,
project_root=dm.project_root,
Expand Down Expand Up @@ -580,21 +583,41 @@ def test_get_existing_app_pkg_info_app_pkg_does_not_exist(
assert mock_execute.mock_calls == expected


# With connection warehouse, with PDF warehouse
# Without connection warehouse, with PDF warehouse
@mock.patch("snowflake.cli.plugins.connection.util.get_context")
@mock.patch("snowflake.cli.plugins.connection.util.get_account")
@mock.patch("snowflake.cli.plugins.connection.util.get_snowsight_host")
@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
@mock_connection()
def test_get_snowsight_url(
@pytest.mark.parametrize(
"warehouse, fallback_warehouse_call, fallback_side_effect",
[
(
"MockWarehouse",
[mock.call("use warehouse MockWarehouse")],
[None],
),
(
None,
[],
[],
),
],
)
def test_get_snowsight_url_with_pdf_warehouse(
mock_conn,
mock_execute_query,
mock_snowsight_host,
mock_account,
mock_context,
warehouse,
fallback_warehouse_call,
fallback_side_effect,
temp_dir,
mock_cursor,
):
mock_conn.return_value = MockConnectionCtx()
mock_conn.return_value = MockConnectionCtx(warehouse=warehouse)
mock_snowsight_host.return_value = "https://host"
mock_context.return_value = "organization"
mock_account.return_value = "account"
Expand All @@ -609,21 +632,82 @@ def test_get_snowsight_url(
side_effects, expected = mock_execute_helper(
[
(
mock_cursor([{"CURRENT_WAREHOUSE()": "old_wh"}], []),
mock_cursor([{"CURRENT_WAREHOUSE()": warehouse}], []),
mock.call("select current_warehouse()", cursor_class=DictCursor),
),
(None, mock.call("use warehouse app_warehouse")),
(None, mock.call("use warehouse old_wh")),
]
)
mock_execute_query.side_effect = side_effects
mock_execute_query.side_effect = side_effects + fallback_side_effect

native_app_manager = _get_na_manager()
assert (
native_app_manager.get_snowsight_url()
== "https://host/organization/account/#/apps/application/MYAPP"
)
assert mock_execute_query.mock_calls == expected
assert mock_execute_query.mock_calls == expected + fallback_warehouse_call


# With connection warehouse, without PDF warehouse
# Without connection warehouse, without PDF warehouse
@mock.patch("snowflake.cli.plugins.connection.util.get_context")
@mock.patch("snowflake.cli.plugins.connection.util.get_account")
@mock.patch("snowflake.cli.plugins.connection.util.get_snowsight_host")
@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
@mock_connection()
@pytest.mark.parametrize(
"project_definition_files, warehouse, expected_calls, fallback_side_effect",
[
(
"napp_project_1",
"MockWarehouse",
[mock.call("select current_warehouse()", cursor_class=DictCursor)],
[None],
),
(
"napp_project_1",
None,
[],
[],
),
],
indirect=["project_definition_files"],
)
def test_get_snowsight_url_without_pdf_warehouse(
mock_conn,
mock_execute_query,
mock_snowsight_host,
mock_account,
mock_context,
project_definition_files,
warehouse,
expected_calls,
fallback_side_effect,
mock_cursor,
):
mock_conn.return_value = MockConnectionCtx(warehouse=warehouse)
mock_snowsight_host.return_value = "https://host"
mock_context.return_value = "organization"
mock_account.return_value = "account"

working_dir: Path = project_definition_files[0].parent

mock_execute_query.side_effect = [
mock_cursor([{"CURRENT_WAREHOUSE()": warehouse}], [])
] + fallback_side_effect

native_app_manager = _get_na_manager(str(working_dir))
if warehouse:
assert (
native_app_manager.get_snowsight_url()
== "https://host/organization/account/#/apps/application/MYAPP_POLLY"
)
else:
with pytest.raises(ClickException) as err:
native_app_manager.get_snowsight_url()
assert "Application warehouse cannot be empty." in err.value.message

assert mock_execute_query.mock_calls == expected_calls


def test_ensure_correct_owner():
Expand Down
Loading

0 comments on commit 03a89f6

Please sign in to comment.