Skip to content

Commit

Permalink
Add snowgit commands to snowcli (#899)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pczajka authored Mar 15, 2024
1 parent 9330831 commit 998873d
Show file tree
Hide file tree
Showing 23 changed files with 1,571 additions and 43 deletions.
1 change: 1 addition & 0 deletions src/snowflake/cli/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ObjectType(Enum):
IMAGE_REPOSITORY = ObjectNames(
"image-repository", "image repository", "image repositories"
)
GIT_REPOSITORY = ObjectNames("git-repository", "git repository", "git repositories")

def __str__(self):
"""This makes using this Enum easier in formatted string"""
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/cli/api/feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ class FeatureFlag(FeatureFlagMixin):
ENABLE_STREAMLIT_EMBEDDED_STAGE = BooleanFlag(
"ENABLE_STREAMLIT_EMBEDDED_STAGE", False
)
ENABLE_SNOWGIT = BooleanFlag("ENABLE_SNOWGIT", False)
25 changes: 25 additions & 0 deletions src/snowflake/cli/api/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,31 @@ def use_role(self, new_role: str):
if is_different_role:
self._execute_query(f"use role {prev_role}")

def create_password_secret(
self, name: str, username: str, password: str
) -> SnowflakeCursor:
return self._execute_query(
f"""
create secret {name}
type = password
username = '{username}'
password = '{password}'
"""
)

def create_api_integration(
self, name: str, api_provider: str, allowed_prefix: str, secret: Optional[str]
) -> SnowflakeCursor:
return self._execute_query(
f"""
create api integration {name}
api_provider = {api_provider}
api_allowed_prefixes = ('{allowed_prefix}')
allowed_authentication_secrets = ({secret if secret else ''})
enabled = true
"""
)

def _execute_schema_query(self, query: str, name: Optional[str] = None, **kwargs):
"""
Check that a database and schema are provided before executing the query. Useful for operating on schema level objects.
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/cli/api/utils/path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ def path_resolver(path_to_file: str) -> str:
if 0 < return_value <= BUFFER_SIZE:
return buffer.value
return path_to_file


def is_stage_path(path: str) -> bool:
return path.startswith("@") or path.startswith("snow://")
28 changes: 18 additions & 10 deletions src/snowflake/cli/app/commands_registration/builtin_plugins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from snowflake.cli.api.feature_flags import FeatureFlag
from snowflake.cli.plugins.connection import plugin_spec as connection_plugin_spec
from snowflake.cli.plugins.git import plugin_spec as git_plugin_spec
from snowflake.cli.plugins.nativeapp import plugin_spec as nativeapp_plugin_spec
from snowflake.cli.plugins.object import plugin_spec as object_plugin_spec
from snowflake.cli.plugins.render import plugin_spec as render_plugin_spec
Expand All @@ -7,14 +9,20 @@
from snowflake.cli.plugins.sql import plugin_spec as sql_plugin_spec
from snowflake.cli.plugins.streamlit import plugin_spec as streamlit_plugin_spec


# plugin name to plugin spec
builtin_plugin_name_to_plugin_spec = {
"connection": connection_plugin_spec,
"spcs": spcs_plugin_spec,
"nativeapp": nativeapp_plugin_spec,
"object": object_plugin_spec,
"render": render_plugin_spec,
"snowpark": snowpark_plugin_spec,
"sql": sql_plugin_spec,
"streamlit": streamlit_plugin_spec,
}
def get_builtin_plugin_name_to_plugin_spec():
plugin_specs = {
"connection": connection_plugin_spec,
"spcs": spcs_plugin_spec,
"nativeapp": nativeapp_plugin_spec,
"object": object_plugin_spec,
"render": render_plugin_spec,
"snowpark": snowpark_plugin_spec,
"sql": sql_plugin_spec,
"streamlit": streamlit_plugin_spec,
}
if FeatureFlag.ENABLE_SNOWGIT.is_enabled():
plugin_specs["git"] = git_plugin_spec

return plugin_specs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
LoadedExternalCommandPlugin,
)
from snowflake.cli.app.commands_registration.builtin_plugins import (
builtin_plugin_name_to_plugin_spec,
get_builtin_plugin_name_to_plugin_spec,
)
from snowflake.cli.app.commands_registration.exception_logging import exception_logging

Expand All @@ -31,7 +31,7 @@ def __init__(self):
self._loaded_command_paths: Dict[CommandPath, LoadedCommandPlugin] = {}

def register_builtin_plugins(self) -> None:
for (plugin_name, plugin) in builtin_plugin_name_to_plugin_spec.items():
for plugin_name, plugin in get_builtin_plugin_name_to_plugin_spec().items():
try:
self._plugin_manager.register(plugin=plugin, name=plugin_name)
except Exception as ex:
Expand All @@ -51,7 +51,7 @@ def register_external_plugins(self, plugin_names: List[str]) -> None:
)

def load_all_registered_plugins(self) -> List[LoadedCommandPlugin]:
for (plugin_name, plugin) in self._plugin_manager.list_name_plugin():
for plugin_name, plugin in self._plugin_manager.list_name_plugin():
self._load_plugin(plugin_name, plugin)
return list(self._loaded_plugins.values())

Expand Down Expand Up @@ -89,7 +89,7 @@ def _load_new_plugin(
def _load_plugin_spec(
self, plugin_name: str, plugin
) -> Optional[LoadedCommandPlugin]:
if plugin_name in builtin_plugin_name_to_plugin_spec.keys():
if plugin_name in get_builtin_plugin_name_to_plugin_spec().keys():
return self._load_builtin_plugin_spec(plugin_name, plugin)
else:
return self._load_external_plugin_spec(plugin_name, plugin)
Expand Down
Empty file.
237 changes: 237 additions & 0 deletions src/snowflake/cli/plugins/git/commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import logging
from pathlib import Path

import typer
from click import ClickException
from snowflake.cli.api.commands.flags import (
PatternOption,
identifier_argument,
like_option,
)
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.console.console import cli_console
from snowflake.cli.api.constants import ObjectType
from snowflake.cli.api.output.types import CommandResult, QueryResult
from snowflake.cli.api.utils.path_utils import is_stage_path
from snowflake.cli.plugins.git.manager import GitManager
from snowflake.cli.plugins.object.manager import ObjectManager

app = SnowTyper(
name="git",
help="Manages git repositories in Snowflake.",
)
log = logging.getLogger(__name__)


def _repo_path_argument_callback(path):
# All repository paths must start with repository scope:
# "@repo_name/tag/example_tag/*"
if not is_stage_path(path) or path.count("/") < 3:
raise ClickException(
"REPOSITORY_PATH should be a path to git repository stage with scope provided."
" Path to the repository root must end with '/'."
" For example: @my_repo/branches/main/"
)

return path


RepoNameArgument = identifier_argument(sf_object="git repository", example="my_repo")
RepoPathArgument = typer.Argument(
metavar="REPOSITORY_PATH",
help=(
"Path to git repository stage with scope provided."
" Path to the repository root must end with '/'."
" For example: @my_repo/branches/main/"
),
callback=_repo_path_argument_callback,
)


def _assure_repository_does_not_exist(om: ObjectManager, repository_name: str) -> None:
if om.object_exists(
object_type=ObjectType.GIT_REPOSITORY.value.cli_name, name=repository_name
):
raise ClickException(f"Repository '{repository_name}' already exists")


def _validate_origin_url(url: str) -> None:
if not url.startswith("https://"):
raise ClickException("Url address should start with 'https'")


@app.command("setup", requires_connection=True)
def setup(
repository_name: str = RepoNameArgument,
**options,
) -> CommandResult:
"""
Sets up a git repository object.
You will be prompted for:
* url - address of repository to be used for git clone operation
* secret - Snowflake secret containing authentication credentials. Not needed if origin repository does not require
authentication for RO operations (clone, fetch)
* API integration - object allowing Snowflake to interact with git repository.
"""
manager = GitManager()
om = ObjectManager()
_assure_repository_does_not_exist(om, repository_name)

url = typer.prompt("Origin url")
_validate_origin_url(url)

secret_needed = typer.confirm("Use secret for authentication?")
should_create_secret = False
secret_name = None
if secret_needed:
secret_name = f"{repository_name}_secret"
secret_name = typer.prompt(
"Secret identifier (will be created if not exists)", default=secret_name
)
if om.object_exists(
object_type=ObjectType.SECRET.value.cli_name, name=secret_name
):
cli_console.step(f"Using existing secret '{secret_name}'")
else:
should_create_secret = True
cli_console.step(f"Secret '{secret_name}' will be created")
secret_username = typer.prompt("username")
secret_password = typer.prompt("password/token", hide_input=True)

api_integration = f"{repository_name}_api_integration"
api_integration = typer.prompt(
"API integration identifier (will be created if not exists)",
default=api_integration,
)

if should_create_secret:
manager.create_password_secret(
name=secret_name, username=secret_username, password=secret_password
)
cli_console.step(f"Secret '{secret_name}' successfully created.")

if not om.object_exists(
object_type=ObjectType.INTEGRATION.value.cli_name, name=api_integration
):
manager.create_api_integration(
name=api_integration,
api_provider="git_https_api",
allowed_prefix=url,
secret=secret_name,
)
cli_console.step(f"API integration '{api_integration}' successfully created.")
else:
cli_console.step(f"Using existing API integration '{api_integration}'.")

return QueryResult(
manager.create(
repo_name=repository_name,
url=url,
api_integration=api_integration,
secret=secret_name,
)
)


@app.command(
"list-branches",
requires_connection=True,
)
def list_branches(
repository_name: str = RepoNameArgument,
like=like_option(
help_example='`list-branches --like "%_test"` lists all branches that end with "_test"'
),
**options,
) -> CommandResult:
"""
List all branches in the repository.
"""
return QueryResult(GitManager().show_branches(repo_name=repository_name, like=like))


@app.command(
"list-tags",
requires_connection=True,
)
def list_tags(
repository_name: str = RepoNameArgument,
like=like_option(
help_example='`list-tags --like "v2.0%"` lists all tags that start with "v2.0"'
),
**options,
) -> CommandResult:
"""
List all tags in the repository.
"""
return QueryResult(GitManager().show_tags(repo_name=repository_name, like=like))


@app.command(
"list-files",
requires_connection=True,
)
def list_files(
repository_path: str = RepoPathArgument,
pattern=PatternOption,
**options,
) -> CommandResult:
"""
List files from given state of git repository.
"""
return QueryResult(
GitManager().list_files(stage_name=repository_path, pattern=pattern)
)


@app.command(
"fetch",
requires_connection=True,
)
def fetch(
repository_name: str = RepoNameArgument,
**options,
) -> CommandResult:
"""
Fetch changes from origin to snowflake repository.
"""
return QueryResult(GitManager().fetch(repo_name=repository_name))


@app.command(
"copy",
requires_connection=True,
)
def copy(
repository_path: str = RepoPathArgument,
destination_path: str = typer.Argument(
help="Target path for copy operation. Should be a path to a directory on remote stage or local file system.",
),
parallel: int = typer.Option(
4,
help="Number of parallel threads to use when downloading files.",
),
**options,
):
"""
Copies all files from given state of repository to local directory or stage.
If the source path ends with '/', the command copies contents of specified directory.
Otherwise, it creates a new directory or file in the destination directory.
"""
is_copy = is_stage_path(destination_path)
if is_copy:
cursor = GitManager().copy_files(
source_path=repository_path, destination_path=destination_path
)
else:
cursor = GitManager().get(
stage_path=repository_path,
dest_path=Path(destination_path).resolve(),
parallel=parallel,
)
return QueryResult(cursor)
29 changes: 29 additions & 0 deletions src/snowflake/cli/plugins/git/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from textwrap import dedent

from snowflake.cli.plugins.object.stage.manager import StageManager
from snowflake.connector.cursor import SnowflakeCursor


class GitManager(StageManager):
def show_branches(self, repo_name: str, like: str) -> SnowflakeCursor:
return self._execute_query(f"show git branches like '{like}' in {repo_name}")

def show_tags(self, repo_name: str, like: str) -> SnowflakeCursor:
return self._execute_query(f"show git tags like '{like}' in {repo_name}")

def fetch(self, repo_name: str) -> SnowflakeCursor:
return self._execute_query(f"alter git repository {repo_name} fetch")

def create(
self, repo_name: str, api_integration: str, url: str, secret: str
) -> SnowflakeCursor:
query = dedent(
f"""
create git repository {repo_name}
api_integration = {api_integration}
origin = '{url}'
"""
)
if secret is not None:
query += f"git_credentials = {secret}\n"
return self._execute_query(query)
Loading

0 comments on commit 998873d

Please sign in to comment.