diff --git a/src/snowflake/cli/api/project/util.py b/src/snowflake/cli/api/project/util.py index 6b75c6233f..22196864d6 100644 --- a/src/snowflake/cli/api/project/util.py +++ b/src/snowflake/cli/api/project/util.py @@ -48,11 +48,17 @@ def is_valid_identifier(identifier: str) -> bool: ) -def is_valid_object_name(name: str) -> bool: +def is_valid_object_name(name: str, max_depth=2) -> bool: """ - Determines whether the given identifier is a valid object name in the form , ., or .. + Determines whether the given identifier is a valid object name in the form , ., or ... + Max_depth determines how many valid identifiers are allowed. For example, account level objects would have a max depth of 0 + because they cannot be qualified by a database or schema, just the single identifier. """ - pattern = rf"{VALID_IDENTIFIER_REGEX}(?:\.{VALID_IDENTIFIER_REGEX}){{0,2}}" + if max_depth < 0: + raise ValueError("max_depth must be non-negative") + pattern = ( + rf"{VALID_IDENTIFIER_REGEX}(?:\.{VALID_IDENTIFIER_REGEX}){{0,{max_depth}}}" + ) return re.fullmatch(pattern, name) is not None diff --git a/src/snowflake/cli/plugins/object/common.py b/src/snowflake/cli/plugins/object/common.py index 0d97db07e7..b6a67426e5 100644 --- a/src/snowflake/cli/plugins/object/common.py +++ b/src/snowflake/cli/plugins/object/common.py @@ -7,7 +7,6 @@ QUOTED_IDENTIFIER_REGEX, UNQUOTED_IDENTIFIER_REGEX, is_valid_identifier, - is_valid_object_name, to_string_literal, ) @@ -80,13 +79,3 @@ def comment_option(object_type: str): help=f"Comment for the {object_type}.", callback=_comment_callback, ) - - -def object_name_callback(name: str) -> str: - """ - Callback for arguments that should be an object name (e.g. 'id' or 'db.schema.id'). - Currently does not support object names with arguments such as UDFs or procedures. - """ - if not is_valid_object_name(name): - raise ClickException(f"{name} is not a valid object name.") - return name diff --git a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py index 1a176023c7..3c3d0c8a38 100644 --- a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py +++ b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py @@ -1,13 +1,15 @@ from typing import Optional import typer +from click import ClickException from snowflake.cli.api.commands.decorators import ( global_options_with_connection, with_output, ) from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS from snowflake.cli.api.output.types import CommandResult, SingleQueryResult -from snowflake.cli.plugins.object.common import comment_option, object_name_callback +from snowflake.cli.api.project.util import is_valid_object_name +from snowflake.cli.plugins.object.common import comment_option from snowflake.cli.plugins.spcs.common import validate_and_set_instances from snowflake.cli.plugins.spcs.compute_pool.manager import ComputePoolManager @@ -18,8 +20,17 @@ ) +def _compute_pool_name_callback(name: str) -> str: + """ + Verifies that compute pool name is a single valid identifier. + """ + if not is_valid_object_name(name, 0): + raise ClickException(f"{name} is not a valid compute pool name.") + return name + + ComputePoolNameArgument = typer.Argument( - ..., help="Name of the compute pool.", callback=object_name_callback + ..., help="Name of the compute pool.", callback=_compute_pool_name_callback ) diff --git a/tests/object/test_common.py b/tests/object/test_common.py index 4ca7aff386..41c487a6ec 100644 --- a/tests/object/test_common.py +++ b/tests/object/test_common.py @@ -1,14 +1,10 @@ -import uuid from snowflake.cli.plugins.object.common import ( _parse_tag, Tag, - object_name_callback, TagError, ) from typing import Tuple import pytest -from unittest import mock -from click import ClickException @pytest.mark.parametrize( @@ -50,18 +46,3 @@ def test_parse_tag_valid(value: str, expected: Tuple[str, str]): def test_parse_tag_invalid(value: str): with pytest.raises(TagError): _parse_tag(value) - - -@mock.patch("snowflake.cli.plugins.object.common.is_valid_object_name") -def test_object_name_callback(mock_is_valid): - name = f"id_{uuid.uuid4()}" - mock_is_valid.return_value = True - assert object_name_callback(name) == name - - -@mock.patch("snowflake.cli.plugins.object.common.is_valid_object_name") -def test_object_name_callback_invalid(mock_is_valid): - name = f"id_{uuid.uuid4()}" - mock_is_valid.return_value = False - with pytest.raises(ClickException): - object_name_callback(name) diff --git a/tests/project/test_util.py b/tests/project/test_util.py index e707e45080..f5025e8492 100644 --- a/tests/project/test_util.py +++ b/tests/project/test_util.py @@ -103,6 +103,8 @@ def test_is_valid_object_name(): names = [".".join(p) for p in list(permutations(valid_identifiers, num))] for name in names: assert is_valid_object_name(name) + if num > 1: + assert not is_valid_object_name(name, 0) # any combination with at least one invalid identifier is invalid for num in [1, 2, 3]: diff --git a/tests/spcs/test_compute_pool.py b/tests/spcs/test_compute_pool.py index c6035fbeb1..80c06aba97 100644 --- a/tests/spcs/test_compute_pool.py +++ b/tests/spcs/test_compute_pool.py @@ -1,9 +1,13 @@ from unittest.mock import Mock, patch from snowflake.cli.plugins.spcs.compute_pool.manager import ComputePoolManager +from snowflake.cli.plugins.spcs.compute_pool.commands import _compute_pool_name_callback from snowflake.connector.cursor import SnowflakeCursor from snowflake.cli.api.project.util import to_string_literal import json +import pytest + +from click import ClickException @patch( @@ -192,3 +196,19 @@ def test_resume_cli(mock_resume, mock_cursor, runner): result_json_parsed = json.loads(result_json.output) assert isinstance(result_json_parsed, dict) assert result_json_parsed == {"status": "Statement executed successfully."} + + +@patch("snowflake.cli.plugins.spcs.compute_pool.commands.is_valid_object_name") +def test_compute_pool_name_callback(mock_is_valid): + name = "test_pool" + mock_is_valid.return_value = True + assert _compute_pool_name_callback(name) == name + + +@patch("snowflake.cli.plugins.spcs.compute_pool.commands.is_valid_object_name") +def test_compute_pool_name_callback_invalid(mock_is_valid): + name = "test_pool" + mock_is_valid.return_value = False + with pytest.raises(ClickException) as e: + _compute_pool_name_callback(name) + assert "is not a valid compute pool name." in e.value.message