Skip to content

Commit

Permalink
Stage copy with directory structure
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astus committed Mar 28, 2024
1 parent 2247d20 commit 14abd3f
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 495 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
All dependencies will be downloaded from PyPi.
* new `--skip-version-check` skips comparing versions of dependencies between requirements and Anaconda.
* new `--index-url` flag sets up Base URL of the Python Package Index to use for package lookup.
* Added `--recursive` flag for copy from stage, it will reproduce the directory structure locally.

## Fixes and improvements
* Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands.
Expand Down
7 changes: 5 additions & 2 deletions src/snowflake/cli/api/output/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import typing as t

from snowflake.connector import DictCursor
from snowflake.connector.cursor import SnowflakeCursor


Expand Down Expand Up @@ -43,12 +44,14 @@ def result(self):


class QueryResult(CollectionResult):
def __init__(self, cursor: SnowflakeCursor):
def __init__(self, cursor: SnowflakeCursor | DictCursor):
self.column_names = [col.name for col in cursor.description]
super().__init__(elements=self._prepare_payload(cursor))
self._query = cursor.query

def _prepare_payload(self, cursor):
def _prepare_payload(self, cursor: SnowflakeCursor | DictCursor):
if isinstance(cursor, DictCursor):
return (k for k in cursor)
return ({k: v for k, v in zip(self.column_names, row)} for row in cursor)

@property
Expand Down
19 changes: 16 additions & 3 deletions src/snowflake/cli/plugins/object/stage/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.output.types import (
CommandResult,
MultipleResults,
ObjectResult,
QueryResult,
SingleQueryResult,
Expand Down Expand Up @@ -51,6 +52,10 @@ def copy(
4,
help="Number of parallel threads to use when uploading files.",
),
recursive: bool = typer.Option(
False,
help="Copy files recursively with directory structure.",
),
**options,
) -> CommandResult:
"""
Expand All @@ -68,12 +73,20 @@ def copy(
raise click.ClickException(
"Both source and target path are local. This operation is not supported."
)
if is_put and recursive:
raise click.ClickException("Recursive for PUT is not supported.")

if is_get:
target = Path(destination_path).resolve()
cursor = StageManager().get(
stage_path=source_path, dest_path=target, parallel=parallel
)
if recursive:
cursors = StageManager().get_recursive(
stage_path=source_path, dest_path=target, parallel=parallel
)
return MultipleResults([QueryResult(c) for c in cursors])
else:
cursor = StageManager().get(
stage_path=source_path, dest_path=target, parallel=parallel
)
else:
source = Path(source_path).resolve()
local_path = str(source) + "/*" if source.is_dir() else str(source)
Expand Down
37 changes: 31 additions & 6 deletions src/snowflake/cli/plugins/object/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import re
from contextlib import nullcontext
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

from snowflake.cli.api.project.util import to_string_literal
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.api.utils.path_utils import path_resolver
from snowflake.connector import DictCursor
from snowflake.connector.cursor import SnowflakeCursor

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,20 +61,18 @@ def _to_uri(self, local_path: str):
return uri
return to_string_literal(uri)

def list_files(
self, stage_name: str, pattern: str | None = None
) -> SnowflakeCursor:
def list_files(self, stage_name: str, pattern: str | None = None) -> DictCursor:
stage_name = self.get_standard_stage_prefix(stage_name)
query = f"ls {self.quote_stage_name(stage_name)}"
if pattern is not None:
query += f" pattern = '{pattern}'"
return self._execute_query(query)
return self._execute_query(query, cursor_class=DictCursor)

@staticmethod
def _assure_is_existing_directory(path: Path) -> None:
spath = SecurePath(path)
if not spath.exists():
spath.mkdir()
spath.mkdir(parents=True)
spath.assert_is_directory()

def get(
Expand All @@ -86,6 +85,32 @@ def get(
f"get {self.quote_stage_name(stage_path)} {self._to_uri(dest_directory)} parallel={parallel}"
)

def get_recursive(
self, stage_path: str, dest_path: Path, parallel: int = 4
) -> List[SnowflakeCursor]:
list_files_result = self.list_files(stage_path).fetchall()
files_on_stage = [f["name"] for f in list_files_result]

if stage_path.startswith("snow://"):
stage_path = stage_path[7:]

stage_parts_length = len(Path(stage_path).parts)
results = []
for file in files_on_stage:
dest_directory = dest_path / "/".join(
Path(file).parts[stage_parts_length:-1]
)
self._assure_is_existing_directory(Path(dest_directory))

stage_path_with_prefix = self.get_standard_stage_prefix(file)

result = self._execute_query(
f"get {self.quote_stage_name(stage_path_with_prefix)} {self._to_uri(f'{dest_directory}/')} parallel={parallel}"
)
results.append(result)

return results

def put(
self,
local_path: Union[str, Path],
Expand Down
Loading

0 comments on commit 14abd3f

Please sign in to comment.