Skip to content

Commit

Permalink
move show_specific_object to SqlExecutionMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-cgorrie committed Feb 8, 2024
1 parent 2bc5392 commit d69854a
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 143 deletions.
30 changes: 28 additions & 2 deletions src/snowflake/cli/api/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
from functools import cached_property
from io import StringIO
from textwrap import dedent
from typing import Iterable
from typing import Iterable, Optional

from click import ClickException
from snowflake.cli.api.cli_global_context import cli_context
from snowflake.connector.cursor import DictCursor, SnowflakeCursor
from snowflake.connector.errors import ProgrammingError
from snowflake.cli.api.cli_global_context import cli_context
from snowflake.cli.api.project.util import (
identifier_to_show_like_pattern,
unquote_identifier,
)
from snowflake.cli.api.utils.cursor import find_first_row
from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError


class SqlExecutionMixin:
Expand Down Expand Up @@ -137,3 +143,23 @@ def to_fully_qualified_name(self, name: str):

schema = self._conn.schema or "public"
return f"{self._conn.database}.{schema}.{name}".upper()

def show_specific_object(
self,
object_type_plural: str,
object_name: str,
name_col: str = "name",
in_clause: str = "",
) -> Optional[dict]:
# TODO: deal with fully-qualified names
show_obj_query = f"show {object_type_plural} like {identifier_to_show_like_pattern(object_name)} {in_clause}".strip()
show_obj_cursor = self._execute_query( # type: ignore
show_obj_query, cursor_class=DictCursor
)
if show_obj_cursor.rowcount is None:
raise SnowflakeSQLExecutionError(show_obj_query)
show_obj_row = find_first_row(
show_obj_cursor,
lambda row: row[name_col] == unquote_identifier(object_name),
)
return show_obj_row
18 changes: 18 additions & 0 deletions src/snowflake/cli/api/utils/cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Callable, List, Optional

from snowflake.connector.cursor import DictCursor


def _rows_generator(cursor: DictCursor, predicate: Callable[[dict], bool]):
return (row for row in cursor.fetchall() if predicate(row))


def find_all_rows(cursor: DictCursor, predicate: Callable[[dict], bool]) -> List[dict]:
return list(_rows_generator(cursor, predicate))


def find_first_row(
cursor: DictCursor, predicate: Callable[[dict], bool]
) -> Optional[dict]:
"""Returns the first row that matches the predicate, or None."""
return next(_rows_generator(cursor, predicate), None)
34 changes: 6 additions & 28 deletions src/snowflake/cli/plugins/nativeapp/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
OWNER_COL,
)
from snowflake.cli.plugins.nativeapp.exceptions import UnexpectedOwnerError
from snowflake.cli.plugins.nativeapp.utils import find_first_row
from snowflake.cli.plugins.object.stage.diff import (
DiffResult,
stage_diff,
Expand Down Expand Up @@ -99,30 +98,7 @@ def process(self, *args, **kwargs):
pass


class FindObjectRowMixin:
def __init__(self):
if not isinstance(self, SqlExecutionMixin):
raise TypeError(
"FindObjectRowMixin can only be used with classes that also mix in SqlExecutionMixin."
)

def find_row_by_object_name(
self, object_type_plural: str, object_name: str, name_col: str = "name"
) -> Optional[dict]:
show_obj_query = f"show {object_type_plural} like {identifier_to_show_like_pattern(object_name)}"
show_obj_cursor = self._execute_query( # type: ignore
show_obj_query, cursor_class=DictCursor
)
if show_obj_cursor.rowcount is None:
raise SnowflakeSQLExecutionError(show_obj_query)
show_obj_row = find_first_row(
show_obj_cursor,
lambda row: row[name_col] == unquote_identifier(object_name),
)
return show_obj_row


class NativeAppManager(SqlExecutionMixin, FindObjectRowMixin):
class NativeAppManager(SqlExecutionMixin):
"""
Base class with frequently used functionality already implemented and ready to be used by related subclasses.
"""
Expand Down Expand Up @@ -328,7 +304,9 @@ def get_existing_app_info(self) -> Optional[dict]:
It executes a 'show applications like' query and returns the result as single row, if one exists.
"""
with self.use_role(self.app_role):
return self.find_row_by_object_name("applications", self.app_name, NAME_COL)
return self.show_specific_object(
"applications", self.app_name, name_col=NAME_COL
)

def get_existing_app_pkg_info(self) -> Optional[dict]:
"""
Expand All @@ -337,8 +315,8 @@ def get_existing_app_pkg_info(self) -> Optional[dict]:
"""

with self.use_role(self.package_role):
return self.find_row_by_object_name(
"application packages", self.package_name, NAME_COL
return self.show_specific_object(
"application packages", self.package_name, name_col=NAME_COL
)

def get_snowsight_url(self) -> str:
Expand Down
23 changes: 6 additions & 17 deletions src/snowflake/cli/plugins/nativeapp/run_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from click import UsageError
from rich import print
from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError
from snowflake.cli.api.project.util import unquote_identifier
from snowflake.cli.plugins.nativeapp.constants import (
COMMENT_COL,
INTERNAL_DISTRIBUTION,
Expand All @@ -29,11 +28,10 @@
generic_sql_error_handler,
)
from snowflake.cli.plugins.nativeapp.policy import PolicyBase
from snowflake.cli.plugins.nativeapp.utils import find_first_row
from snowflake.cli.plugins.object.stage.diff import DiffResult
from snowflake.cli.plugins.object.stage.manager import StageManager
from snowflake.connector import ProgrammingError
from snowflake.connector.cursor import DictCursor, SnowflakeCursor
from snowflake.connector.cursor import SnowflakeCursor

UPGRADE_RESTRICTION_CODES = {93044, 93055, 93045, 93046}

Expand Down Expand Up @@ -227,28 +225,19 @@ def get_existing_version_info(self, version: str) -> Optional[dict]:
It executes a 'show versions like ... in application package' query and returns the result as single row, if one exists.
"""
with self.use_role(self.package_role):
show_obj_query = f"show versions like '{unquote_identifier(version)}' in application package {self.package_name}"

try:
show_obj_cursor = self._execute_query(
show_obj_query, cursor_class=DictCursor
return self.show_specific_object(
"versions",
version,
name_col=VERSION_COL,
in_clause=f"in application package {self.package_name}",
)
except ProgrammingError as err:
if err.msg.__contains__("does not exist or not authorized"):
raise ApplicationPackageDoesNotExistError(self.package_name)
else:
generic_sql_error_handler(err=err, role=self.package_role)

if show_obj_cursor.rowcount is None:
raise SnowflakeSQLExecutionError(show_obj_query)

show_obj_row = find_first_row(
show_obj_cursor,
lambda row: row[VERSION_COL] == unquote_identifier(version),
)

return show_obj_row

def drop_application_before_upgrade(self, policy: PolicyBase, is_interactive: bool):
"""
This method will attempt to drop an application if a previous upgrade fails.
Expand Down
18 changes: 0 additions & 18 deletions src/snowflake/cli/plugins/nativeapp/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,4 @@
from sys import stdin, stdout
from typing import Callable, List, Optional

from snowflake.connector.cursor import DictCursor


def _rows_generator(cursor: DictCursor, predicate: Callable[[dict], bool]):
return (row for row in cursor.fetchall() if predicate(row))


def find_all_rows(cursor: DictCursor, predicate: Callable[[dict], bool]) -> List[dict]:
return list(_rows_generator(cursor, predicate))


def find_first_row(
cursor: DictCursor, predicate: Callable[[dict], bool]
) -> Optional[dict]:
"""Returns the first row that matches the predicate, or None."""
return next(_rows_generator(cursor, predicate), None)


def needs_confirmation(needs_confirm: bool, auto_yes: bool) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from snowflake.cli.plugins.nativeapp.policy import PolicyBase
from snowflake.cli.plugins.nativeapp.run_processor import NativeAppRunProcessor
from snowflake.cli.plugins.nativeapp.utils import (
from snowflake.cli.api.utils.cursor import (
find_all_rows,
find_first_row,
)
Expand Down
77 changes: 0 additions & 77 deletions tests/nativeapp/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
NativeAppManager,
SnowflakeSQLExecutionError,
ensure_correct_owner,
FindObjectRowMixin,
)
from snowflake.cli.plugins.object.stage.diff import DiffResult
from snowflake.cli.api.project.definition_manager import DefinitionManager
Expand Down Expand Up @@ -512,79 +511,3 @@ def test_is_correct_owner_bad_owner():
test_row = {"name": "some_name", "owner": "wrong_role", "comment": "some_comment"}
with pytest.raises(UnexpectedOwnerError):
ensure_correct_owner(row=test_row, role="right_role", obj_name="some_name")


def test_find_object_row_mixin_needs_sql_execution_mixin_correct():
class CorrectUsage(SqlExecutionMixin, FindObjectRowMixin):
pass

assert isinstance(CorrectUsage(), FindObjectRowMixin)


def test_find_object_row_mixin_needs_sql_execution_mixin_incorrect():
with pytest.raises(TypeError) as expected_error:

class IncorrectUsage(FindObjectRowMixin):
pass

IncorrectUsage()
assert (
"FindObjectRowMixin can only be used with classes that also mix in SqlExecutionMixin"
in str(expected_error.value)
)


@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
def test_find_row_by_object_name(mock_execute, temp_dir, mock_cursor):
current_working_directory = os.getcwd()
create_named_file(
file_name="snowflake.yml",
dir=current_working_directory,
contents=[mock_snowflake_yml_file],
)
mock_columns = ["id", "created_on"]
mock_row_dict = {c: r for c, r in zip(mock_columns, ["EXAMPLE_ID", "dummy"])}
cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns)
mock_execute.return_value = cursor
result = _get_na_manager().find_row_by_object_name("objects", "example_id", "id")
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
assert result == mock_row_dict


@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
def test_find_row_by_object_name_no_match(mock_execute, temp_dir, mock_cursor):
current_working_directory = os.getcwd()
create_named_file(
file_name="snowflake.yml",
dir=current_working_directory,
contents=[mock_snowflake_yml_file],
)
mock_columns = ["id", "created_on"]
mock_row_dict = {c: r for c, r in zip(mock_columns, ["OTHER_ID", "dummy"])}
cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns)
mock_execute.return_value = cursor
result = _get_na_manager().find_row_by_object_name("objects", "example_id", "id")
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
assert result == None


@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
def test_find_row_by_object_name_sql_execution_error(mock_execute, temp_dir):
current_working_directory = os.getcwd()
create_named_file(
file_name="snowflake.yml",
dir=current_working_directory,
contents=[mock_snowflake_yml_file],
)
cursor = Mock(spec=DictCursor)
cursor.rowcount = None
mock_execute.return_value = cursor
with pytest.raises(SnowflakeSQLExecutionError):
_get_na_manager().find_row_by_object_name("objects", "example_id", "id")
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
45 changes: 45 additions & 0 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import mock
from snowflake.connector.cursor import DictCursor
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError

import pytest

Expand Down Expand Up @@ -106,3 +109,45 @@ def test_sql_overrides_connection_configuration(mock_conn, runner, mock_cursor):
role="rolenameValue",
password="passFromTest",
)


@mock.patch("snowflake.cli.plugins.sql.manager.SqlExecutionMixin._execute_query")
def test_show_specific_object(mock_execute, mock_cursor):
mock_columns = ["id", "created_on"]
mock_row_dict = {c: r for c, r in zip(mock_columns, ["EXAMPLE_ID", "dummy"])}
cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns)
mock_execute.return_value = cursor
result = SqlExecutionMixin().show_specific_object(
"objects", "example_id", name_col="id"
)
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
assert result == mock_row_dict


@mock.patch("snowflake.cli.plugins.sql.manager.SqlExecutionMixin._execute_query")
def test_show_specific_object_no_match(mock_execute, mock_cursor):
mock_columns = ["id", "created_on"]
mock_row_dict = {c: r for c, r in zip(mock_columns, ["OTHER_ID", "dummy"])}
cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns)
mock_execute.return_value = cursor
result = SqlExecutionMixin().show_specific_object(
"objects", "example_id", name_col="id"
)
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
assert result == None


@mock.patch("snowflake.cli.plugins.sql.manager.SqlExecutionMixin._execute_query")
def test_show_specific_object_sql_execution_error(mock_execute):
cursor = mock.Mock(spec=DictCursor)
cursor.rowcount = None
mock_execute.return_value = cursor
with pytest.raises(SnowflakeSQLExecutionError):
SqlExecutionMixin().show_specific_object("objects", "example_id", name_col="id")
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)

0 comments on commit d69854a

Please sign in to comment.