Skip to content

Commit

Permalink
SNOW-1758715 Move small helper methods out of ApplicationEntity class (
Browse files Browse the repository at this point in the history
…#1800)

Moves `application_objects_to_str()` and `get_account_event_table()` out from the `ApplicationEntity` class to be module-level functions since they're not part of the public API of the entity
  • Loading branch information
sfc-gh-fcampbell authored Oct 25, 2024
1 parent aadbc0f commit 6a00ae3
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 94 deletions.
58 changes: 25 additions & 33 deletions src/snowflake/cli/_plugins/nativeapp/entities/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from snowflake.cli._plugins.nativeapp.same_account_install_method import (
SameAccountInstallMethod,
)
from snowflake.cli._plugins.nativeapp.sf_facade import get_snowflake_facade
from snowflake.cli._plugins.nativeapp.utils import needs_confirmation
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.cli_global_context import get_cli_context
Expand Down Expand Up @@ -325,19 +326,19 @@ def action_drop(
console.message(cascade_true_message)
with console.indented():
for obj in application_objects:
console.message(self.application_object_to_str(obj))
console.message(_application_object_to_str(obj))
elif cascade is False:
# If the user explicitly passed the --no-cascade flag
console.message(cascade_false_message)
with console.indented():
for obj in application_objects:
console.message(self.application_object_to_str(obj))
console.message(_application_object_to_str(obj))
elif interactive:
# If the user didn't pass any cascade flag and the session is interactive
console.message(message_prefix)
with console.indented():
for obj in application_objects:
console.message(self.application_object_to_str(obj))
console.message(_application_object_to_str(obj))
user_response = typer.prompt(
interactive_prompt,
show_default=False,
Expand All @@ -354,7 +355,7 @@ def action_drop(
console.message(message_prefix)
with console.indented():
for obj in application_objects:
console.message(self.application_object_to_str(obj))
console.message(_application_object_to_str(obj))
console.message(non_interactive_abort)
raise typer.Abort()
elif cascade is None:
Expand Down Expand Up @@ -436,25 +437,6 @@ def get_objects_owned_by_application(self) -> List[ApplicationOwnedObject]:
).fetchall()
return [{"name": row[1], "type": row[2]} for row in results]

@classmethod
def application_objects_to_str(
cls, application_objects: list[ApplicationOwnedObject]
) -> str:
"""
Returns a list in an "(Object Type) Object Name" format. Database-level and schema-level object names are fully qualified:
(COMPUTE_POOL) POOL_NAME
(DATABASE) DB_NAME
(SCHEMA) DB_NAME.PUBLIC
...
"""
return "\n".join(
[cls.application_object_to_str(obj) for obj in application_objects]
)

@staticmethod
def application_object_to_str(obj: ApplicationOwnedObject) -> str:
return f"({obj['type']}) {obj['name']}"

def create_or_upgrade_app(
self,
package_model: ApplicationPackageEntityModel,
Expand Down Expand Up @@ -652,7 +634,7 @@ def drop_application_before_upgrade(
if cascade:
try:
if application_objects := self.get_objects_owned_by_application():
application_objects_str = self.application_objects_to_str(
application_objects_str = _application_objects_to_str(
application_objects
)
console.message(
Expand Down Expand Up @@ -716,8 +698,8 @@ def get_events(
if first >= 0 and last >= 0:
raise ValueError("first and last cannot be used together")

account_event_table = self.get_account_event_table()
if not account_event_table or account_event_table == "NONE":
account_event_table = get_snowflake_facade().get_account_event_table()
if account_event_table is None:
raise NoEventTableForAccount()

# resource_attributes uses the unquoted/uppercase app and package name
Expand Down Expand Up @@ -849,13 +831,6 @@ def stream_events(
except KeyboardInterrupt:
return

@staticmethod
def get_account_event_table():
query = "show parameters like 'event_table' in account"
sql_executor = get_sql_executor()
results = sql_executor.execute_query(query, cursor_class=DictCursor)
return next((r["value"] for r in results if r["key"] == "EVENT_TABLE"), "")

def get_snowsight_url(self) -> str:
"""Returns the URL that can be used to visit this app via Snowsight."""
name = identifier_for_url(self._entity_model.fqn.name)
Expand Down Expand Up @@ -885,3 +860,20 @@ def _new_events_only(previous_events: list[dict], new_events: list[dict]) -> lis
# either be in both lists or in new_events only
new_events.remove(event)
return new_events


def _application_objects_to_str(
application_objects: list[ApplicationOwnedObject],
) -> str:
"""
Returns a list in an "(Object Type) Object Name" format. Database-level and schema-level object names are fully qualified:
(COMPUTE_POOL) POOL_NAME
(DATABASE) DB_NAME
(SCHEMA) DB_NAME.PUBLIC
...
"""
return "\n".join([_application_object_to_str(obj) for obj in application_objects])


def _application_object_to_str(obj: ApplicationOwnedObject) -> str:
return f"({obj['type']}) {obj['name']}"
16 changes: 15 additions & 1 deletion src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from snowflake.cli.api.project.util import to_identifier
from snowflake.cli.api.sql_execution import SqlExecutor
from snowflake.connector import ProgrammingError
from snowflake.connector import DictCursor, ProgrammingError


class SnowflakeSQLFacade:
Expand Down Expand Up @@ -193,3 +193,17 @@ def execute_user_script(
raise UserScriptError(script_name, err.msg) from err
except Exception as err:
handle_unclassified_error(err, f"Failed to run script {script_name}.")

def get_account_event_table(self, role: str | None = None) -> str | None:
"""
Returns the name of the event table for the account.
If the account has no event table set up or the event table is set to NONE, returns None.
@param [Optional] role: Role to switch to while running this script. Current role will be used if no role is passed in.
"""
query = "show parameters like 'event_table' in account"
with self._use_role_optional(role):
results = self._sql_executor.execute_query(query, cursor_class=DictCursor)
table = next((r["value"] for r in results if r["key"] == "EVENT_TABLE"), None)
if table is None or table == "NONE":
return None
return table
66 changes: 8 additions & 58 deletions tests/nativeapp/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@
mock_get_app_pkg_distribution_in_sf,
)
from tests.nativeapp.utils import (
APP_ENTITY_GET_ACCOUNT_EVENT_TABLE,
APP_PACKAGE_ENTITY_DEPLOY,
APP_PACKAGE_ENTITY_GET_EXISTING_APP_PKG_INFO,
APP_PACKAGE_ENTITY_IS_DISTRIBUTION_SAME,
ENTITIES_UTILS_MODULE,
SQL_EXECUTOR_EXECUTE,
SQL_FACADE_GET_ACCOUNT_EVENT_TABLE,
mock_execute_helper,
mock_snowflake_yml_file_v2,
quoted_override_yml_file_v2,
Expand Down Expand Up @@ -1460,55 +1460,6 @@ def test_validate_raw_returns_data(mock_execute, temp_dir, mock_cursor):
assert mock_execute.mock_calls == expected


@mock.patch(SQL_EXECUTOR_EXECUTE)
def test_account_event_table(mock_execute, temp_dir, mock_cursor):
create_named_file(
file_name="snowflake.yml",
dir_name=temp_dir,
contents=[mock_snowflake_yml_file_v2],
)

event_table = "db.schema.event_table"
side_effects, expected = mock_execute_helper(
[
(
mock_cursor([dict(key="EVENT_TABLE", value=event_table)], []),
mock.call(
"show parameters like 'event_table' in account",
cursor_class=DictCursor,
),
),
]
)
mock_execute.side_effect = side_effects

assert ApplicationEntity.get_account_event_table() == event_table


@mock.patch(SQL_EXECUTOR_EXECUTE)
def test_account_event_table_not_set_up(mock_execute, temp_dir, mock_cursor):
create_named_file(
file_name="snowflake.yml",
dir_name=temp_dir,
contents=[mock_snowflake_yml_file_v2],
)

side_effects, expected = mock_execute_helper(
[
(
mock_cursor([], []),
mock.call(
"show parameters like 'event_table' in account",
cursor_class=DictCursor,
),
),
]
)
mock_execute.side_effect = side_effects

assert ApplicationEntity.get_account_event_table() == ""


@pytest.mark.parametrize(
["since", "expected_since_clause"],
[
Expand Down Expand Up @@ -1615,7 +1566,7 @@ def test_account_event_table_not_set_up(mock_execute, temp_dir, mock_cursor):
],
)
@mock.patch(
APP_ENTITY_GET_ACCOUNT_EVENT_TABLE,
SQL_FACADE_GET_ACCOUNT_EVENT_TABLE,
return_value="db.schema.event_table",
)
@mock.patch(SQL_EXECUTOR_EXECUTE)
Expand Down Expand Up @@ -1707,7 +1658,7 @@ def get_events():


@mock.patch(
APP_ENTITY_GET_ACCOUNT_EVENT_TABLE,
SQL_FACADE_GET_ACCOUNT_EVENT_TABLE,
return_value="db.schema.event_table",
)
@mock.patch(SQL_EXECUTOR_EXECUTE)
Expand Down Expand Up @@ -1762,12 +1713,11 @@ def test_get_events_quoted_app_name(
assert mock_execute.mock_calls == expected


@pytest.mark.parametrize("return_value", [None, "NONE"])
@mock.patch(APP_ENTITY_GET_ACCOUNT_EVENT_TABLE)
@mock.patch(SQL_FACADE_GET_ACCOUNT_EVENT_TABLE)
def test_get_events_no_event_table(
mock_account_event_table, return_value, temp_dir, mock_cursor, workspace_context
mock_account_event_table, temp_dir, mock_cursor, workspace_context
):
mock_account_event_table.return_value = return_value
mock_account_event_table.return_value = None
create_named_file(
file_name="snowflake.yml",
dir_name=temp_dir,
Expand All @@ -1783,7 +1733,7 @@ def test_get_events_no_event_table(


@mock.patch(
APP_ENTITY_GET_ACCOUNT_EVENT_TABLE,
SQL_FACADE_GET_ACCOUNT_EVENT_TABLE,
return_value="db.schema.non_existent_event_table",
)
@mock.patch(SQL_EXECUTOR_EXECUTE)
Expand Down Expand Up @@ -1845,7 +1795,7 @@ def test_get_events_event_table_dne_or_unauthorized(


@mock.patch(
APP_ENTITY_GET_ACCOUNT_EVENT_TABLE,
SQL_FACADE_GET_ACCOUNT_EVENT_TABLE,
return_value="db.schema.event_table",
)
@mock.patch(SQL_EXECUTOR_EXECUTE)
Expand Down
35 changes: 34 additions & 1 deletion tests/nativeapp/test_sf_sql_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED,
NO_WAREHOUSE_SELECTED_IN_SESSION,
)
from snowflake.connector import DatabaseError, Error
from snowflake.connector import DatabaseError, DictCursor, Error
from snowflake.connector.errors import (
InternalServerError,
ProgrammingError,
Expand Down Expand Up @@ -856,3 +856,36 @@ def test_use_db_bubbles_errors(
pass

assert error_message in str(err)


@mock.patch(SQL_EXECUTOR_EXECUTE)
@pytest.mark.parametrize(
"parameter_value,event_table",
[
["db.schema.event_table", "db.schema.event_table"],
[None, None],
["NONE", None],
],
)
def test_account_event_table(
mock_execute_query, mock_cursor, parameter_value, event_table
):
query_result = (
[dict(key="EVENT_TABLE", value=parameter_value)]
if parameter_value is not None
else []
)
side_effects, expected = mock_execute_helper(
[
(
mock_cursor(query_result, []),
mock.call(
"show parameters like 'event_table' in account",
cursor_class=DictCursor,
),
),
]
)
mock_execute_query.side_effect = side_effects

assert sql_facade.get_account_event_table() == event_table
5 changes: 4 additions & 1 deletion tests/nativeapp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
APP_ENTITY_GET_OBJECTS_OWNED_BY_APPLICATION = (
f"{APP_ENTITY}.get_objects_owned_by_application"
)
APP_ENTITY_GET_ACCOUNT_EVENT_TABLE = f"{APP_ENTITY}.get_account_event_table"

APP_PACKAGE_ENTITY = "snowflake.cli._plugins.nativeapp.entities.application_package.ApplicationPackageEntity"
APP_PACKAGE_ENTITY_DEPLOY = f"{APP_PACKAGE_ENTITY}._deploy"
Expand All @@ -63,6 +62,10 @@
SQL_EXECUTOR_EXECUTE = f"{ENTITIES_COMMON_MODULE}.SqlExecutor._execute_query"
SQL_EXECUTOR_EXECUTE_QUERIES = f"{ENTITIES_COMMON_MODULE}.SqlExecutor._execute_queries"

SQL_FACADE_MODULE = "snowflake.cli._plugins.nativeapp.sf_facade"
SQL_FACADE = f"{SQL_FACADE_MODULE}.SnowflakeSQLFacade"
SQL_FACADE_GET_ACCOUNT_EVENT_TABLE = f"{SQL_FACADE}.get_account_event_table"

mock_snowflake_yml_file = dedent(
"""\
definition_version: 1
Expand Down

0 comments on commit 6a00ae3

Please sign in to comment.