Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1011766: Added 'snow spcs image-repository url <repo_name>' command #708

Merged
merged 14 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
## New additions
* Added ability to specify scope of the `object list` command with the `--in <scope_type> <scope_name>` option.
* Introduced `snowflake.cli.api.console.cli_console` object with helper methods for intermediate output.
* Added convenience function `spcs image-repository url <repo_name>`.

## Fixes and improvements
* Restricted permissions of automatically created files
Expand Down
12 changes: 12 additions & 0 deletions src/snowflake/cli/api/project/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,15 @@ def validate_version(version: str):
raise ValueError(
f"Project definition version {version} is not supported by this version of Snowflake CLI. Supported versions: {SUPPORTED_VERSIONS}"
)


def escape_like_pattern(pattern: str, escape_sequence: str = r"\\") -> str:
"""
When used with LIKE in Snowflake, '%' and '_' are wildcard characters and must be escaped to be used literally.
The escape character is '\\' when used in SHOW LIKE and must be specified when used with string matching using the
following syntax: <subject> LIKE <pattern> [ ESCAPE <escape> ].
"""
pattern = pattern.replace("%", rf"{escape_sequence}%").replace(
"_", rf"{escape_sequence}_"
)
return pattern
4 changes: 2 additions & 2 deletions src/snowflake/cli/api/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def use_role(self, new_role: str):
if is_different_role:
self._execute_query(f"use role {prev_role}")

def _execute_schema_query(self, query: str):
def _execute_schema_query(self, query: str, **kwargs):
self.check_database_and_schema()
return self._execute_query(query)
return self._execute_query(query, **kwargs)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allows you to pass cursor_class with _execute_schema_query

def check_database_and_schema(self) -> None:
database = self._conn.database
Expand Down
42 changes: 35 additions & 7 deletions src/snowflake/cli/plugins/spcs/image_repository/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
with_output,
)
from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS
from snowflake.cli.api.output.types import CollectionResult
from snowflake.cli.api.output.types import CollectionResult, MessageResult
from snowflake.cli.api.project.util import is_valid_unquoted_identifier
from snowflake.cli.plugins.spcs.image_registry.manager import RegistryManager
from snowflake.cli.plugins.spcs.image_repository.manager import ImageRepositoryManager

Expand All @@ -21,13 +22,25 @@
)


def _repo_name_callback(name: str):
if not is_valid_unquoted_identifier(name):
raise ClickException(
"Repository name must be a valid unquoted identifier. Quoted names for special characters or case-sensitive names are not supported for image repositories."
)
return name


REPO_NAME_ARGUMENT = typer.Argument(
help="Name of the image repository. Only unquoted identifiers are supported for image repositories.",
callback=_repo_name_callback,
)


@app.command("list-images")
@with_output
@global_options_with_connection
def list_images(
repo_name: str = typer.Argument(
help="Name of the image repository shown by the `SHOW IMAGE REPOSITORIES` SQL command.",
),
repo_name: str = REPO_NAME_ARGUMENT,
**options,
) -> CollectionResult:
"""Lists images in given repository."""
Expand Down Expand Up @@ -72,9 +85,7 @@ def list_images(
@with_output
@global_options_with_connection
def list_tags(
repo_name: str = typer.Argument(
help="Name of the image repository shown by the `SHOW IMAGE REPOSITORIES` SQL command.",
),
repo_name: str = REPO_NAME_ARGUMENT,
image_name: str = typer.Option(
...,
"--image_name",
Expand Down Expand Up @@ -119,3 +130,20 @@ def list_tags(
tags_list.append({"tag": image_tag})

return CollectionResult(tags_list)


@app.command("url")
@with_output
@global_options_with_connection
def repo_url(
repo_name: str = REPO_NAME_ARGUMENT,
**options,
):
"""Returns the URL for the given repository."""
return MessageResult(
(
ImageRepositoryManager().get_repository_url(
repo_name=repo_name, with_scheme=False
)
)
)
61 changes: 36 additions & 25 deletions src/snowflake/cli/plugins/spcs/image_repository/manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Dict
from urllib.parse import urlparse

from click import ClickException
from snowflake.cli.api.project.util import (
escape_like_pattern,
is_valid_unquoted_identifier,
)
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.connector.cursor import SnowflakeCursor
from snowflake.connector.cursor import DictCursor


class ImageRepositoryManager(SqlExecutionMixin):
Expand All @@ -15,38 +20,44 @@ def get_schema(self):
def get_role(self):
return self._conn.role

def get_repository_url_list(self, repo_name: str) -> SnowflakeCursor:
role = self.get_role()
database = self.get_database()
schema = self.get_schema()

registry_query = f"""
use role {role};
use database {database};
use schema {schema};
show image repositories like '{repo_name}';
"""

return self._execute_query(registry_query)
def get_repository_row(self, repo_name: str) -> Dict:
if not is_valid_unquoted_identifier(repo_name):
raise ValueError(
f"repo_name '{repo_name}' is not a valid unquoted Snowflake identifier"
)

sfc-gh-davwang marked this conversation as resolved.
Show resolved Hide resolved
def get_repository_url(self, repo_name):
database = self.get_database()
schema = self.get_schema()
repo_name = repo_name.upper()

result_set = self.get_repository_url_list(repo_name=repo_name)
# because image repositories only support unquoted identifiers, SHOW LIKE should only return one or zero rows
repository_list_query = (
f"show image repositories like '{escape_like_pattern(repo_name)}'"
sfc-gh-davwang marked this conversation as resolved.
Show resolved Hide resolved
)

result_set = self._execute_schema_query(
repository_list_query, cursor_class=DictCursor
)
results = result_set.fetchall()

if len(results) == 0:
raise ClickException(
f"Specified repository name {repo_name} not found in database {database} and schema {schema}"
f"Image repository '{repo_name}' does not exist in database '{self.get_database()}' and schema '{self.get_schema()}' or not authorized."
)
sfc-gh-davwang marked this conversation as resolved.
Show resolved Hide resolved
else:
sfc-gh-davwang marked this conversation as resolved.
Show resolved Hide resolved
if len(results) > 1:
raise Exception(
f"Found more than one repositories with name {repo_name}. This is unexpected."
)
elif len(results) > 1:
raise ClickException(
f"Found more than one image repository with name matching '{repo_name}'. This is unexpected."
)
return results[0]

return f"https://{results[0][4]}"
def get_repository_url(self, repo_name: str, with_scheme: bool = True):
if not is_valid_unquoted_identifier(repo_name):
raise ValueError(
f"repo_name '{repo_name}' is not a valid unquoted Snowflake identifier"
)
repo_row = self.get_repository_row(repo_name)
if with_scheme:
return f"https://{repo_row['repository_url']}"
else:
return repo_row["repository_url"]

def get_repository_api_url(self, repo_url):
"""
Expand Down
15 changes: 15 additions & 0 deletions tests/project/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
is_valid_unquoted_identifier,
to_identifier,
to_string_literal,
escape_like_pattern,
)

VALID_UNQUOTED_IDENTIFIERS = (
Expand Down Expand Up @@ -162,3 +163,17 @@ def test_is_valid_string_literal(literal, valid):
)
def test_to_string_literal(raw_string, literal):
assert to_string_literal(raw_string) == literal


@pytest.mark.parametrize(
"raw_string, escaped",
[
(r"underscore_table", r"underscore\\_table"),
(r"percent%%table", r"percent\\%\\%table"),
(r"__many__under__scores__", r"\\_\\_many\\_\\_under\\_\\_scores\\_\\_"),
(r"mixed_underscore%percent", r"mixed\\_underscore\\%percent"),
(r"regular$table", r"regular$table"),
],
)
def test_escape_like_pattern(raw_string, escaped):
assert escape_like_pattern(raw_string) == escaped
Loading