Skip to content

Commit

Permalink
SNOW-1011766: Factoring out shared logic of getting a row from SHOW .…
Browse files Browse the repository at this point in the history
….. LIKE ... based on object name from NativeAppManager and ImageRepositoryManager to a mixin.
  • Loading branch information
sfc-gh-davwang committed Feb 6, 2024
1 parent 4437bdd commit d8b9b88
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 87 deletions.
11 changes: 9 additions & 2 deletions src/snowflake/cli/api/project/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ def validate_version(version: str):
)


def escape_like_pattern(pattern: str) -> str:
pattern = pattern.replace("%", r"\\%").replace("_", r"\\_")
def escape_like_pattern(pattern: str, escape_sequence: str = r"\\") -> str:
pattern = pattern.replace("%", rf"{escape_sequence}%").replace(
"_", rf"{escape_sequence}_"
)
return pattern


def identifier_to_show_like_pattern(identifier: str) -> str:
"""Unquotes and escapes special characters for an identifier to be used as a pattern for a 'SHOW <object_type> LIKE <identifier>' query."""
return f"'{escape_like_pattern(unquote_identifier(identifier))}'"
59 changes: 28 additions & 31 deletions src/snowflake/cli/plugins/nativeapp/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from snowflake.cli.api.project.util import (
extract_schema,
identifier_to_show_like_pattern,
to_identifier,
unquote_identifier,
)
Expand Down Expand Up @@ -98,7 +99,30 @@ def process(self, *args, **kwargs):
pass


class NativeAppManager(SqlExecutionMixin):
class FindObjectRowMixin:
def __init__(self):
if not isinstance(self, SqlExecutionMixin):
raise TypeError(
"FindObjectRowMixin can only be used with classes that also mix in SqlExecutionMixin."
)

def find_row_by_object_name(
self, object_type_plural: str, object_name: str, name_col: str = "name"
) -> Optional[dict]:
show_obj_query = f"show {object_type_plural} like {identifier_to_show_like_pattern(object_name)}"
show_obj_cursor = self._execute_query( # type: ignore
show_obj_query, cursor_class=DictCursor
)
if show_obj_cursor.rowcount is None:
raise SnowflakeSQLExecutionError(show_obj_query)
show_obj_row = find_first_row(
show_obj_cursor,
lambda row: row[name_col] == unquote_identifier(object_name),
)
return show_obj_row


class NativeAppManager(SqlExecutionMixin, FindObjectRowMixin):
"""
Base class with frequently used functionality already implemented and ready to be used by related subclasses.
"""
Expand Down Expand Up @@ -304,23 +328,7 @@ def get_existing_app_info(self) -> Optional[dict]:
It executes a 'show applications like' query and returns the result as single row, if one exists.
"""
with self.use_role(self.app_role):
show_obj_query = (
f"show applications like '{unquote_identifier(self.app_name)}'"
)
show_obj_cursor = self._execute_query(
show_obj_query,
cursor_class=DictCursor,
)

if show_obj_cursor.rowcount is None:
raise SnowflakeSQLExecutionError(show_obj_query)

show_obj_row = find_first_row(
show_obj_cursor,
lambda row: row[NAME_COL] == unquote_identifier(self.app_name),
)

return show_obj_row
return self.find_row_by_object_name("applications", self.app_name, NAME_COL)

def get_existing_app_pkg_info(self) -> Optional[dict]:
"""
Expand All @@ -329,21 +337,10 @@ def get_existing_app_pkg_info(self) -> Optional[dict]:
"""

with self.use_role(self.package_role):
show_obj_query = f"show application packages like '{unquote_identifier(self.package_name)}'"
show_obj_cursor = self._execute_query(
show_obj_query, cursor_class=DictCursor
return self.find_row_by_object_name(
"application packages", self.package_name, NAME_COL
)

if show_obj_cursor.rowcount is None:
raise SnowflakeSQLExecutionError(show_obj_query)

show_obj_row = find_first_row(
show_obj_cursor,
lambda row: row[NAME_COL] == unquote_identifier(self.package_name),
)

return show_obj_row # Can be None or a dict

def get_snowsight_url(self) -> str:
"""Returns the URL that can be used to visit this app via Snowsight."""
name = unquote_identifier(self.app_name)
Expand Down
28 changes: 6 additions & 22 deletions src/snowflake/cli/plugins/spcs/image_repository/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import click
from click import ClickException
from snowflake.cli.api.project.util import (
escape_like_pattern,
is_valid_unquoted_identifier,
)
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.connector.cursor import DictCursor
from snowflake.cli.plugins.nativeapp.manager import FindObjectRowMixin


class ImageRepositoryManager(SqlExecutionMixin):
class ImageRepositoryManager(SqlExecutionMixin, FindObjectRowMixin):
def get_database(self):
return self._conn.database

Expand All @@ -27,28 +26,13 @@ def get_repository_row(self, repo_name: str) -> Dict:
f"repo_name '{repo_name}' is not a valid unquoted Snowflake identifier"
)

repo_name = repo_name.upper()

# because image repositories only support unquoted identifiers, SHOW LIKE should only return one or zero rows
repository_list_query = (
f"show image repositories like '{escape_like_pattern(repo_name)}'"
)

result_set = self._execute_schema_query(
repository_list_query, cursor_class=DictCursor
)
results = result_set.fetchall()

colored_repo_name = click.style(f"'{repo_name}'", fg="green")
if len(results) == 0:
repo_row = self.find_row_by_object_name("image repositories", repo_name, "name")
if repo_row is None:
colored_repo_name = click.style(f"'{repo_name.upper()}'", fg="green")
raise ClickException(
f"Image repository {colored_repo_name} does not exist or not authorized."
)
elif len(results) > 1:
raise ClickException(
f"Found more than one image repository with name matching {colored_repo_name}. This is unexpected."
)
return results[0]
return repo_row

def get_repository_url(self, repo_name: str):
if not is_valid_unquoted_identifier(repo_name):
Expand Down
88 changes: 83 additions & 5 deletions tests/nativeapp/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import unittest
from textwrap import dedent

from unittest.mock import Mock
from snowflake.cli.plugins.nativeapp.constants import (
LOOSE_FILES_MAGIC_VERSION,
NAME_COL,
Expand All @@ -11,10 +10,11 @@
NativeAppManager,
SnowflakeSQLExecutionError,
ensure_correct_owner,
FindObjectRowMixin,
)
from snowflake.cli.plugins.object.stage.diff import DiffResult
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.connector import ProgrammingError
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.connector.cursor import DictCursor

from tests.nativeapp.patch_utils import (
Expand Down Expand Up @@ -415,7 +415,8 @@ def test_get_existing_app_pkg_info_app_pkg_exists(mock_execute, temp_dir, mock_c
[],
),
mock.call(
"show application packages like 'APP_PKG'", cursor_class=DictCursor
r"show application packages like 'APP\\_PKG'",
cursor_class=DictCursor,
),
),
(None, mock.call("use role old_role")),
Expand Down Expand Up @@ -451,7 +452,8 @@ def test_get_existing_app_pkg_info_app_pkg_does_not_exist(
(
mock_cursor([], []),
mock.call(
"show application packages like 'APP_PKG'", cursor_class=DictCursor
r"show application packages like 'APP\\_PKG'",
cursor_class=DictCursor,
),
),
(None, mock.call("use role old_role")),
Expand Down Expand Up @@ -510,3 +512,79 @@ def test_is_correct_owner_bad_owner():
test_row = {"name": "some_name", "owner": "wrong_role", "comment": "some_comment"}
with pytest.raises(UnexpectedOwnerError):
ensure_correct_owner(row=test_row, role="right_role", obj_name="some_name")


def test_find_object_row_mixin_needs_sql_execution_mixin_correct():
class CorrectUsage(SqlExecutionMixin, FindObjectRowMixin):
pass

assert isinstance(CorrectUsage(), FindObjectRowMixin)


def test_find_object_row_mixin_needs_sql_execution_mixin_incorrect():
with pytest.raises(TypeError) as expected_error:

class IncorrectUsage(FindObjectRowMixin):
pass

IncorrectUsage()
assert (
"FindObjectRowMixin can only be used with classes that also mix in SqlExecutionMixin"
in str(expected_error.value)
)


@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
def test_find_row_by_object_name(mock_execute, temp_dir, mock_cursor):
current_working_directory = os.getcwd()
create_named_file(
file_name="snowflake.yml",
dir=current_working_directory,
contents=[mock_snowflake_yml_file],
)
mock_columns = ["id", "created_on"]
mock_row_dict = {c: r for c, r in zip(mock_columns, ["EXAMPLE_ID", "dummy"])}
cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns)
mock_execute.return_value = cursor
result = _get_na_manager().find_row_by_object_name("objects", "example_id", "id")
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
assert result == mock_row_dict


@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
def test_find_row_by_object_name_no_match(mock_execute, temp_dir, mock_cursor):
current_working_directory = os.getcwd()
create_named_file(
file_name="snowflake.yml",
dir=current_working_directory,
contents=[mock_snowflake_yml_file],
)
mock_columns = ["id", "created_on"]
mock_row_dict = {c: r for c, r in zip(mock_columns, ["OTHER_ID", "dummy"])}
cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns)
mock_execute.return_value = cursor
result = _get_na_manager().find_row_by_object_name("objects", "example_id", "id")
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
assert result == None


@mock.patch(NATIVEAPP_MANAGER_EXECUTE)
def test_find_row_by_object_name_sql_execution_error(mock_execute, temp_dir):
current_working_directory = os.getcwd()
create_named_file(
file_name="snowflake.yml",
dir=current_working_directory,
contents=[mock_snowflake_yml_file],
)
cursor = Mock(spec=DictCursor)
cursor.rowcount = None
mock_execute.return_value = cursor
with pytest.raises(SnowflakeSQLExecutionError):
_get_na_manager().find_row_by_object_name("objects", "example_id", "id")
mock_execute.assert_called_once_with(
r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor
)
20 changes: 20 additions & 0 deletions tests/project/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
to_identifier,
to_string_literal,
escape_like_pattern,
identifier_to_show_like_pattern,
)

VALID_UNQUOTED_IDENTIFIERS = (
Expand Down Expand Up @@ -177,3 +178,22 @@ def test_to_string_literal(raw_string, literal):
)
def test_escape_like_pattern(raw_string, escaped):
assert escape_like_pattern(raw_string) == escaped


@pytest.mark.parametrize(
"identifier, pattern",
[
(r"underscore_table", r"'UNDERSCORE\\_TABLE'"),
(r"percent%%table", r"'PERCENT\\%\\%TABLE'"),
(r"__many__under__scores__", r"'\\_\\_MANY\\_\\_UNDER\\_\\_SCORES\\_\\_'"),
(r"mixed_underscore%percent", r"'MIXED\\_UNDERSCORE\\%PERCENT'"),
(r"regular$table", r"'REGULAR$TABLE'"),
(
r'"underscore_table"',
r"'underscore\\_table'",
), # quoted identifiers keep case
(r'"underscore_TABLE%"', r"'underscore\\_TABLE\\%'"),
],
)
def test_identifier_to_show_like_pattern(identifier, pattern):
assert identifier_to_show_like_pattern(identifier) == pattern
31 changes: 6 additions & 25 deletions tests/spcs/test_image_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,44 +134,25 @@ def test_get_repository_url_cli(mock_url, runner):


@mock.patch(
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._execute_schema_query"
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.find_row_by_object_name"
)
def test_get_repository_row(mock_execute, mock_cursor):
mock_execute.return_value = mock_cursor(
rows=MOCK_ROWS_DICT,
columns=MOCK_COLUMNS,
)
def test_get_repository_row(mock_find, mock_cursor):
mock_find.return_value = MOCK_ROWS_DICT[0]
result = ImageRepositoryManager().get_repository_row(MOCK_ROWS_DICT[0]["name"])
assert result == MOCK_ROWS_DICT[0]


@mock.patch(
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._execute_schema_query"
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.find_row_by_object_name"
)
def test_get_repository_row_no_repo_found(mock_execute, mock_cursor):
mock_execute.return_value = mock_cursor(
rows=[],
columns=MOCK_COLUMNS,
)
def test_get_repository_row_no_repo_found(mock_find, mock_cursor):
mock_find.return_value = None

with pytest.raises(ClickException) as expected_error:
ImageRepositoryManager().get_repository_row("IMAGES")
assert "does not exist or not authorized" in expected_error.value.message


@mock.patch(
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._execute_schema_query"
)
def test_get_repository_row_more_than_one_repo(mock_execute, mock_cursor):
mock_execute.return_value = mock_cursor(
rows=MOCK_ROWS_DICT + MOCK_ROWS_DICT,
columns=MOCK_COLUMNS,
)
with pytest.raises(ClickException) as expected_error:
ImageRepositoryManager().get_repository_row("IMAGES")
assert "Found more than one image repository" in expected_error.value.message


@mock.patch(
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.get_repository_row"
)
Expand Down
4 changes: 2 additions & 2 deletions tests_integration/spcs/test_image_repository.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from snowflake.cli.api.project.util import escape_like_pattern
from snowflake.cli.api.project.util import identifier_to_show_like_pattern

from tests_integration.test_utils import contains_row_with, row_from_snowflake_session
from tests_integration.testing_utils.naming_utils import ObjectNameProvider
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_get_repo_url(runner, snowflake_session, test_database):
snowflake_session.execute_string(f"create image repository {repo_name}")

created_repo = snowflake_session.execute_string(
f"show image repositories like '{escape_like_pattern(repo_name)}'"
f"show image repositories like {identifier_to_show_like_pattern(repo_name)}"
)
created_row = row_from_snowflake_session(created_repo)[0]
created_name = created_row["name"]
Expand Down

0 comments on commit d8b9b88

Please sign in to comment.