Skip to content

Commit

Permalink
Stage copy with directory structure
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astus committed Apr 4, 2024
1 parent ce5e603 commit 5e28362
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 43 deletions.
80 changes: 50 additions & 30 deletions src/snowflake/cli/plugins/object/stage/commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import itertools
from os import path
from pathlib import Path

import click
Expand All @@ -8,8 +10,8 @@
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.console import cli_console
from snowflake.cli.api.output.types import (
CollectionResult,
CommandResult,
MultipleResults,
ObjectResult,
QueryResult,
SingleQueryResult,
Expand Down Expand Up @@ -76,35 +78,8 @@ def copy(
)

if is_get:
target = Path(destination_path).resolve()
if recursive:
cursors = StageManager().get_recursive(
stage_path=source_path, dest_path=target, parallel=parallel
)
return MultipleResults([QueryResult(c) for c in cursors])
else:
cli_console.warning(
"Use `--recursive` flag, which copy files recursively with directory structure. This will be the default behavior in the future."
)
cursor = StageManager().get(
stage_path=source_path, dest_path=target, parallel=parallel
)
return QueryResult(cursor)

if is_put:
if recursive:
raise click.ClickException("Recursive flag for upload is not supported.")
else:
source = Path(source_path).resolve()
local_path = str(source) + "/*" if source.is_dir() else str(source)

cursor = StageManager().put(
local_path=local_path,
stage_path=destination_path,
overwrite=overwrite,
parallel=parallel,
)
return QueryResult(cursor)
return _get(recursive, source_path, destination_path, parallel)
return _put(recursive, source_path, destination_path, parallel, overwrite)


@app.command("create", requires_connection=True)
Expand Down Expand Up @@ -141,3 +116,48 @@ def stage_diff(
"""
diff: DiffResult = stage_diff(Path(folder_name), stage_name)
return ObjectResult(str(diff))


def _get(recursive: bool, source_path: str, destination_path: str, parallel: int):
target = Path(destination_path).resolve()
if recursive:
cursors = StageManager().get_recursive(
stage_path=source_path, dest_path=target, parallel=parallel
)
results = [list(QueryResult(c).result) for c in cursors]
flattened_results = list(itertools.chain.from_iterable(results))
sorted_results = sorted(
flattened_results,
key=lambda e: (path.dirname(e["file"]), path.basename(e["file"])),
)
return CollectionResult(sorted_results)

cli_console.warning(
"Use `--recursive` flag, which copy files recursively with directory structure. This will be the default behavior in the future."
)
cursor = StageManager().get(
stage_path=source_path, dest_path=target, parallel=parallel
)
return QueryResult(cursor)


def _put(
recursive: bool,
source_path: str,
destination_path: str,
parallel: int,
overwrite: bool,
):
if recursive:
raise click.ClickException("Recursive flag for upload is not supported.")

source = Path(source_path).resolve()
local_path = str(source) + "/*" if source.is_dir() else str(source)

cursor = StageManager().put(
local_path=local_path,
stage_path=destination_path,
overwrite=overwrite,
parallel=parallel,
)
return QueryResult(cursor)
11 changes: 0 additions & 11 deletions tests/object/stage/__snapshots__/test_stage.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,11 @@
# ---
# name: test_stage_print_result_for_get_all_files_from_stage_recursive
'''
SELECT A MOCK QUERY
+-----------------------------------------+
| file | size | status | message |
|-----------+------+------------+---------|
| file1.txt | 10 | DOWNLOADED | |
+-----------------------------------------+
SELECT A MOCK QUERY
+-----------------------------------------+
| file | size | status | message |
|-----------+------+------------+---------|
| file2.txt | 10 | DOWNLOADED | |
+-----------------------------------------+
SELECT A MOCK QUERY
+-----------------------------------------+
| file | size | status | message |
|-----------+------+------------+---------|
| file3.txt | 10 | DOWNLOADED | |
+-----------------------------------------+

Expand Down
4 changes: 2 additions & 2 deletions tests/object/stage/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_stage_copy_remote_to_local_quoted_stage_recursive(
):
mock_execute.side_effect = [
mock_cursor([{"name": '"stage name"/file'}], []),
mock_cursor(["row"], []),
mock_cursor([("file")], ["file"]),
]
with TemporaryDirectory() as tmp_dir:
result = runner.invoke(
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_stage_copy_remote_to_local_quoted_uri_recursive(
):
mock_execute.side_effect = [
mock_cursor([{"name": "stageName/file"}], []),
mock_cursor(["row"], []),
mock_cursor([(raw_path)], ["file"]),
]
with TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir).resolve()
Expand Down

0 comments on commit 5e28362

Please sign in to comment.