Skip to content

Commit

Permalink
feedback round 3
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bgoel committed Jul 11, 2024
1 parent 0d92999 commit 5d9c35d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 55 deletions.
3 changes: 1 addition & 2 deletions src/snowflake/cli/api/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,12 @@ def use_role(self, new_role: str):
prev_role = role_result["CURRENT_ROLE()"]
is_different_role = new_role.lower() != prev_role.lower()
if is_different_role:
self._log.debug("Temporarily switching to a different role: %s", new_role)
self._log.debug("Assuming different role: %s", new_role)
self._execute_query(f"use role {new_role}")
try:
yield
finally:
if is_different_role:
self._log.debug("Switching back to the original role: %s", prev_role)
self._execute_query(f"use role {prev_role}")

def create_password_secret(
Expand Down
67 changes: 18 additions & 49 deletions src/snowflake/cli/plugins/nativeapp/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping
from snowflake.cli.api.project.util import (
identifier_for_url,
to_identifier,
unquote_identifier,
)
from snowflake.cli.api.sql_execution import SqlExecutionMixin
Expand Down Expand Up @@ -265,61 +264,31 @@ def use_warehouse(self, new_wh: Optional[str]):
This is a no-op if the requested warehouse is already active.
If there is no default warehouse in the account, it will throw an error.
"""

if new_wh is None:
# The new_wh parameter is an Optional[str] as the project definition attributes are Optional[str], passed directly to this method.
raise ClickException("Requested warehouse cannot be None.")

wh_result = self._execute_query(
f"select current_warehouse()", cursor_class=DictCursor
).fetchone()
# If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended.
prev_wh = wh_result["CURRENT_WAREHOUSE()"]

if new_wh is not None:
new_wh_id = to_identifier(new_wh)

if prev_wh is None and new_wh is None:
raise ClickException(
"Could not find both the connection and the requested warehouse in the Snowflake account for this user."
)

elif prev_wh is None:
self._log.debug("Using warehouse: %s", new_wh_id)
self.use(object_type=ObjectType.WAREHOUSE, name=new_wh_id)
try:
yield
finally:
self._log.debug(
"Continuing to use warehouse unless requested otherwise: %s",
new_wh_id,
)

elif new_wh is None:
self._log.debug(
"Requested warehouse is empty, continuing to use the connection warehouse: %s",
prev_wh,
)
# Activate a suspended warehouse, will be a no-op will already active
self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh)
try:
yield
finally:
self._log.debug(
"Continuing to use warehouse unless requested otherwise: %s",
prev_wh,
)
try:
prev_wh = wh_result["CURRENT_WAREHOUSE()"]
except:
prev_wh = None

else:
is_different_wh = new_wh_id != prev_wh
self._log.debug("Using warehouse: %s", new_wh_id)
# new_wh is not None, and should already be a valid identifier, no additional check is performed here.
is_different_wh = new_wh != prev_wh
try:
if is_different_wh:
# Activate the new warehouse
self.use(object_type=ObjectType.WAREHOUSE, name=new_wh_id)
else:
# Activate a suspended warehouse, will be a no-op will already active
self._log.debug("Using warehouse: %s", new_wh)
self.use(object_type=ObjectType.WAREHOUSE, name=new_wh)
yield
finally:
if prev_wh and is_different_wh:
self._log.debug("Switching back to warehouse: %s", prev_wh)
self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh)
try:
yield
finally:
if is_different_wh:
self._log.debug("Switching back to warehouse: %s", prev_wh)
self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh)

@cached_property
def get_app_pkg_distribution_in_snowflake(self) -> str:
Expand Down
8 changes: 6 additions & 2 deletions src/snowflake/cli/plugins/nativeapp/project_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,18 @@ def package_warehouse(self) -> Optional[str]:
if self.definition.package and self.definition.package.warehouse:
return to_identifier(self.definition.package.warehouse)
else:
return to_identifier(cli_context.connection.warehouse)
if cli_context.connection.warehouse:
return to_identifier(cli_context.connection.warehouse)
return None

@cached_property
def application_warehouse(self) -> Optional[str]:
if self.definition.application and self.definition.application.warehouse:
return to_identifier(self.definition.application.warehouse)
else:
return to_identifier(cli_context.connection.warehouse)
if cli_context.connection.warehouse:
return to_identifier(cli_context.connection.warehouse)
return None

@cached_property
def project_identifier(self) -> str:
Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/cli/plugins/nativeapp/run_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def __init__(self, project_definition: NativeApp, project_root: Path):
def _execute_sql_script(self, sql_script_path):
"""
Executing the SQL script in the provided file path after expanding template variables.
"use warehouse" and "use database" will be executed first if they are set in definition file or in the current connection.
This assumes that a relevant warehouse is already active.
Consequently, "use database" will be executed first if it is set in definition file or in the current connection.
"""
with open(sql_script_path) as f:
sql_script = f.read()
Expand Down
2 changes: 1 addition & 1 deletion tests/nativeapp/test_package_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _get_na_manager(working_dir):
"napp_project_1",
[
mock.call("select current_warehouse()", cursor_class=DictCursor),
mock.call("use warehouse MockWarehouse"),
],
),
(
Expand Down Expand Up @@ -205,6 +204,7 @@ def test_package_scripts_w_warehouse_access_exception(
msg="Object does not exist, or operation cannot be performed.",
errno=DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED,
),
None,
]

mock_conn.return_value = MockConnectionCtx()
Expand Down
8 changes: 8 additions & 0 deletions tests/nativeapp/test_run_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def test_create_dev_app_w_warehouse_access_exception(
),
mock.call("use warehouse app_warehouse"),
),
(
None,
mock.call("use warehouse old_wh"),
),
(None, mock.call("use role old_role")),
]
)
Expand Down Expand Up @@ -1110,6 +1114,10 @@ def test_upgrade_app_warehouse_error(
),
mock.call("use warehouse app_warehouse"),
),
(
None,
mock.call("use warehouse old_wh"),
),
(None, mock.call("use role old_role")),
]
)
Expand Down

0 comments on commit 5d9c35d

Please sign in to comment.