Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stage copy with directory structure #942

Merged
merged 4 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
* `snow snowpark build`:
* new `--skip-version-check` skips comparing versions of dependencies between requirements and Anaconda.
* new `--index-url` flag sets up Base URL of the Python Package Index to use for package lookup.
* Added `--recursive` flag for copy from stage, it will reproduce the directory structure locally.

## Fixes and improvements
* Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands.
Expand Down
7 changes: 5 additions & 2 deletions src/snowflake/cli/api/output/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import typing as t

from snowflake.connector import DictCursor
from snowflake.connector.cursor import SnowflakeCursor


Expand Down Expand Up @@ -43,12 +44,14 @@ def result(self):


class QueryResult(CollectionResult):
def __init__(self, cursor: SnowflakeCursor):
def __init__(self, cursor: SnowflakeCursor | DictCursor):
self.column_names = [col.name for col in cursor.description]
super().__init__(elements=self._prepare_payload(cursor))
self._query = cursor.query

def _prepare_payload(self, cursor):
def _prepare_payload(self, cursor: SnowflakeCursor | DictCursor):
if isinstance(cursor, DictCursor):
return (k for k in cursor)
return ({k: v for k, v in zip(self.column_names, row)} for row in cursor)

@property
Expand Down
77 changes: 64 additions & 13 deletions src/snowflake/cli/plugins/object/stage/commands.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import itertools
from os import path
from pathlib import Path

import click
import typer
from snowflake.cli.api.commands.flags import PatternOption
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.console import cli_console
from snowflake.cli.api.output.types import (
CollectionResult,
CommandResult,
ObjectResult,
QueryResult,
Expand Down Expand Up @@ -51,6 +55,10 @@ def copy(
4,
help="Number of parallel threads to use when uploading files.",
),
recursive: bool = typer.Option(
False,
help="Copy files recursively with directory structure.",
),
sfc-gh-mraba marked this conversation as resolved.
Show resolved Hide resolved
**options,
) -> CommandResult:
"""
Expand All @@ -70,21 +78,19 @@ def copy(
)

if is_get:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing order of this if can improve readability, WDYT?

Suggested change
if is_get:
if not is_get:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see that improvement 🤔 , for me it will not change anything. I added second if for put, WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not is_get:
        # PUT case
        source = Path(source_path).resolve()
        local_path = str(source) + "/*" if source.is_dir() else str(source)

        cursor = StageManager().put(
            local_path=local_path,
            stage_path=destination_path,
            overwrite=overwrite,
            parallel=parallel,
        )
        return QueryResult(cursor)

# GET case
if recursive:
    cursors = StageManager().get_recursive(
        stage_path=source_path, dest_path=target, parallel=parallel
    )
    result = MultipleResults([QueryResult(c) for c in cursors])
else:
    cli_console.warning(
        "Use `--recursive` flag, which copy files recursively with directory structure. This will be the default behavior in the future."
    )
    cursor = StageManager().get(
        stage_path=source_path, dest_path=target, parallel=parallel
    )
    result = QueryResult(cursor)

But if we foresee more changes in put then introducing helper methods may make more sense:

if is_get:
   return _get(...)
return _put(...)

WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like helper methods

target = Path(destination_path).resolve()
cursor = StageManager().get(
stage_path=source_path, dest_path=target, parallel=parallel
)
else:
source = Path(source_path).resolve()
local_path = str(source) + "/*" if source.is_dir() else str(source)

cursor = StageManager().put(
local_path=local_path,
stage_path=destination_path,
overwrite=overwrite,
return _get(
recursive=recursive,
source_path=source_path,
destination_path=destination_path,
parallel=parallel,
)
return QueryResult(cursor)
return _put(
recursive=recursive,
source_path=source_path,
destination_path=destination_path,
parallel=parallel,
overwrite=overwrite,
)


@app.command("create", requires_connection=True)
Expand Down Expand Up @@ -121,3 +127,48 @@ def stage_diff(
"""
diff: DiffResult = stage_diff(Path(folder_name), stage_name)
return ObjectResult(str(diff))


def _get(recursive: bool, source_path: str, destination_path: str, parallel: int):
target = Path(destination_path).resolve()
if not recursive:
cli_console.warning(
"Use `--recursive` flag, which copy files recursively with directory structure. This will be the default behavior in the future."
)
cursor = StageManager().get(
stage_path=source_path, dest_path=target, parallel=parallel
)
return QueryResult(cursor)

cursors = StageManager().get_recursive(
stage_path=source_path, dest_path=target, parallel=parallel
)
results = [list(QueryResult(c).result) for c in cursors]
flattened_results = list(itertools.chain.from_iterable(results))
sorted_results = sorted(
flattened_results,
key=lambda e: (path.dirname(e["file"]), path.basename(e["file"])),
)
return CollectionResult(sorted_results)


def _put(
recursive: bool,
source_path: str,
destination_path: str,
parallel: int,
overwrite: bool,
):
if recursive:
raise click.ClickException("Recursive flag for upload is not supported.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😢


source = Path(source_path).resolve()
local_path = str(source) + "/*" if source.is_dir() else str(source)

cursor = StageManager().put(
local_path=local_path,
stage_path=destination_path,
overwrite=overwrite,
parallel=parallel,
)
return QueryResult(cursor)
39 changes: 33 additions & 6 deletions src/snowflake/cli/plugins/object/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import re
from contextlib import nullcontext
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

from snowflake.cli.api.project.util import to_string_literal
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.api.utils.path_utils import path_resolver
from snowflake.connector import DictCursor
from snowflake.connector.cursor import SnowflakeCursor

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,20 +61,18 @@ def _to_uri(self, local_path: str):
return uri
return to_string_literal(uri)

def list_files(
self, stage_name: str, pattern: str | None = None
) -> SnowflakeCursor:
def list_files(self, stage_name: str, pattern: str | None = None) -> DictCursor:
stage_name = self.get_standard_stage_prefix(stage_name)
query = f"ls {self.quote_stage_name(stage_name)}"
if pattern is not None:
query += f" pattern = '{pattern}'"
return self._execute_query(query)
return self._execute_query(query, cursor_class=DictCursor)

@staticmethod
def _assure_is_existing_directory(path: Path) -> None:
spath = SecurePath(path)
if not spath.exists():
spath.mkdir()
spath.mkdir(parents=True)
spath.assert_is_directory()

def get(
Expand All @@ -86,6 +85,30 @@ def get(
f"get {self.quote_stage_name(stage_path)} {self._to_uri(dest_directory)} parallel={parallel}"
)

def get_recursive(
self, stage_path: str, dest_path: Path, parallel: int = 4
) -> List[SnowflakeCursor]:
stage_path_only = stage_path
if stage_path_only.startswith("snow://"):
stage_path_only = stage_path_only[7:]
stage_parts_length = len(Path(stage_path_only).parts)

results = []
for file in self.iter_stage(stage_path):
dest_directory = dest_path / "/".join(
Path(file).parts[stage_parts_length:-1]
)
self._assure_is_existing_directory(Path(dest_directory))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean the directory has to exists?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't get make sure it creates the directories? If I want to use this piece of code in other place, does it mean I have also to create all directories?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assure_is_existing_directory is already existing function which checks if directory exists and creates it if not. Normal get uses the same function.


stage_path_with_prefix = self.get_standard_stage_prefix(file)

result = self._execute_query(
f"get {self.quote_stage_name(stage_path_with_prefix)} {self._to_uri(f'{dest_directory}/')} parallel={parallel}"
)
results.append(result)

return results

def put(
self,
local_path: Union[str, Path],
Expand Down Expand Up @@ -137,3 +160,7 @@ def create(self, stage_name: str, comment: Optional[str] = None) -> SnowflakeCur
if comment:
query += f" comment='{comment}'"
return self._execute_query(query)

def iter_stage(self, stage_path: str):
for file in self.list_files(stage_path).fetchall():
yield file["name"]
Loading
Loading