Skip to content

Commit

Permalink
SNOW-1011772: Updating compute pool callback to only allow a single v…
Browse files Browse the repository at this point in the history
…alid identifier
  • Loading branch information
sfc-gh-davwang committed Feb 7, 2024
1 parent 14ee8d4 commit bf46e0a
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 35 deletions.
12 changes: 9 additions & 3 deletions src/snowflake/cli/api/project/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,17 @@ def is_valid_identifier(identifier: str) -> bool:
)


def is_valid_object_name(name: str) -> bool:
def is_valid_object_name(name: str, max_depth=2) -> bool:
"""
Determines whether the given identifier is a valid object name in the form <name>, <schema>.<name>, or <database>.<schema>.<name>
Determines whether the given identifier is a valid object name in the form <name>, <schema>.<name>, or <database>.<schema>.<name>.
Max_depth determines how many valid identifiers are allowed. For example, account level objects would have a max depth of 0
because they cannot be qualified by a database or schema, just the single identifier.
"""
pattern = rf"{VALID_IDENTIFIER_REGEX}(?:\.{VALID_IDENTIFIER_REGEX}){{0,2}}"
if max_depth < 0:
raise ValueError("max_depth must be non-negative")
pattern = (
rf"{VALID_IDENTIFIER_REGEX}(?:\.{VALID_IDENTIFIER_REGEX}){{0,{max_depth}}}"
)
return re.fullmatch(pattern, name) is not None


Expand Down
11 changes: 0 additions & 11 deletions src/snowflake/cli/plugins/object/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
QUOTED_IDENTIFIER_REGEX,
UNQUOTED_IDENTIFIER_REGEX,
is_valid_identifier,
is_valid_object_name,
to_string_literal,
)

Expand Down Expand Up @@ -80,13 +79,3 @@ def comment_option(object_type: str):
help=f"Comment for the {object_type}.",
callback=_comment_callback,
)


def object_name_callback(name: str) -> str:
"""
Callback for arguments that should be an object name (e.g. 'id' or 'db.schema.id').
Currently does not support object names with arguments such as UDFs or procedures.
"""
if not is_valid_object_name(name):
raise ClickException(f"{name} is not a valid object name.")
return name
15 changes: 13 additions & 2 deletions src/snowflake/cli/plugins/spcs/compute_pool/commands.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Optional

import typer
from click import ClickException
from snowflake.cli.api.commands.decorators import (
global_options_with_connection,
with_output,
)
from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS
from snowflake.cli.api.output.types import CommandResult, SingleQueryResult
from snowflake.cli.plugins.object.common import comment_option, object_name_callback
from snowflake.cli.api.project.util import is_valid_object_name
from snowflake.cli.plugins.object.common import comment_option
from snowflake.cli.plugins.spcs.common import validate_and_set_instances
from snowflake.cli.plugins.spcs.compute_pool.manager import ComputePoolManager

Expand All @@ -18,8 +20,17 @@
)


def _compute_pool_name_callback(name: str) -> str:
"""
Verifies that compute pool name is a single valid identifier.
"""
if not is_valid_object_name(name, 0):
raise ClickException(f"{name} is not a valid compute pool name.")
return name


ComputePoolNameArgument = typer.Argument(
..., help="Name of the compute pool.", callback=object_name_callback
..., help="Name of the compute pool.", callback=_compute_pool_name_callback
)


Expand Down
19 changes: 0 additions & 19 deletions tests/object/test_common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import uuid
from snowflake.cli.plugins.object.common import (
_parse_tag,
Tag,
object_name_callback,
TagError,
)
from typing import Tuple
import pytest
from unittest import mock
from click import ClickException


@pytest.mark.parametrize(
Expand Down Expand Up @@ -50,18 +46,3 @@ def test_parse_tag_valid(value: str, expected: Tuple[str, str]):
def test_parse_tag_invalid(value: str):
with pytest.raises(TagError):
_parse_tag(value)


@mock.patch("snowflake.cli.plugins.object.common.is_valid_object_name")
def test_object_name_callback(mock_is_valid):
name = f"id_{uuid.uuid4()}"
mock_is_valid.return_value = True
assert object_name_callback(name) == name


@mock.patch("snowflake.cli.plugins.object.common.is_valid_object_name")
def test_object_name_callback_invalid(mock_is_valid):
name = f"id_{uuid.uuid4()}"
mock_is_valid.return_value = False
with pytest.raises(ClickException):
object_name_callback(name)
2 changes: 2 additions & 0 deletions tests/project/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def test_is_valid_object_name():
names = [".".join(p) for p in list(permutations(valid_identifiers, num))]
for name in names:
assert is_valid_object_name(name)
if num > 1:
assert not is_valid_object_name(name, 0)

# any combination with at least one invalid identifier is invalid
for num in [1, 2, 3]:
Expand Down
20 changes: 20 additions & 0 deletions tests/spcs/test_compute_pool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from unittest.mock import Mock, patch

from snowflake.cli.plugins.spcs.compute_pool.manager import ComputePoolManager
from snowflake.cli.plugins.spcs.compute_pool.commands import _compute_pool_name_callback
from snowflake.connector.cursor import SnowflakeCursor
from snowflake.cli.api.project.util import to_string_literal
import json
import pytest

from click import ClickException


@patch(
Expand Down Expand Up @@ -192,3 +196,19 @@ def test_resume_cli(mock_resume, mock_cursor, runner):
result_json_parsed = json.loads(result_json.output)
assert isinstance(result_json_parsed, dict)
assert result_json_parsed == {"status": "Statement executed successfully."}


@patch("snowflake.cli.plugins.spcs.compute_pool.commands.is_valid_object_name")
def test_compute_pool_name_callback(mock_is_valid):
name = "test_pool"
mock_is_valid.return_value = True
assert _compute_pool_name_callback(name) == name


@patch("snowflake.cli.plugins.spcs.compute_pool.commands.is_valid_object_name")
def test_compute_pool_name_callback_invalid(mock_is_valid):
name = "test_pool"
mock_is_valid.return_value = False
with pytest.raises(ClickException) as e:
_compute_pool_name_callback(name)
assert "is not a valid compute pool name." in e.value.message

0 comments on commit bf46e0a

Please sign in to comment.