From ab95923394574d3be118778357143b7a009922ab Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Wed, 4 Sep 2024 00:18:24 -0600 Subject: [PATCH] Update stage execute to use stage directory to filter --- src/snowflake/cli/_plugins/git/manager.py | 10 +++++- src/snowflake/cli/_plugins/stage/manager.py | 37 ++++++++++++++++----- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/src/snowflake/cli/_plugins/git/manager.py b/src/snowflake/cli/_plugins/git/manager.py index 0edb6aecb6..a411aa7193 100644 --- a/src/snowflake/cli/_plugins/git/manager.py +++ b/src/snowflake/cli/_plugins/git/manager.py @@ -43,15 +43,23 @@ def __init__(self, stage_path: str): def path(self) -> str: return f"{self.stage_name.rstrip('/')}/{self.directory}" + @classmethod + def get_directory(cls, stage_path: str) -> str: + return "/".join(Path(stage_path).parts[3:]) + @property def full_path(self) -> str: return f"{self.stage.rstrip('/')}/{self.directory}" - def add_stage_prefix(self, file_path: str) -> str: + def replace_stage_prefix(self, file_path: str) -> str: stage = Path(self.stage).parts[0] file_path_without_prefix = Path(file_path).parts[1:] return f"{stage}/{'/'.join(file_path_without_prefix)}" + def add_stage_prefix(self, file_path: str) -> str: + stage = self.stage.rstrip("/") + return f"{stage}/{file_path.lstrip('/')}" + def get_directory_from_file_path(self, file_path: str) -> List[str]: stage_path_length = len(Path(self.directory).parts) return list(Path(file_path).parts[3 + stage_path_length : -1]) diff --git a/src/snowflake/cli/_plugins/stage/manager.py b/src/snowflake/cli/_plugins/stage/manager.py index f01612343d..90621e7011 100644 --- a/src/snowflake/cli/_plugins/stage/manager.py +++ b/src/snowflake/cli/_plugins/stage/manager.py @@ -65,8 +65,8 @@ class StagePathParts: stage_name: str is_directory: bool - @staticmethod - def get_directory(stage_path: str) -> str: + @classmethod + def get_directory(cls, stage_path: str) -> str: return "/".join(Path(stage_path).parts[1:]) @property @@ -77,6 +77,9 @@ def path(self) -> str: def full_path(self) -> str: raise NotImplementedError + def replace_stage_prefix(self, file_path: str) -> str: + raise NotImplementedError + def add_stage_prefix(self, file_path: str) -> str: raise NotImplementedError @@ -128,11 +131,15 @@ def path(self) -> str: def full_path(self) -> str: return f"{self.stage.rstrip('/')}/{self.directory}" - def add_stage_prefix(self, file_path: str) -> str: + def replace_stage_prefix(self, file_path: str) -> str: stage = Path(self.stage).parts[0] file_path_without_prefix = Path(file_path).parts[1:] return f"{stage}/{'/'.join(file_path_without_prefix)}" + def add_stage_prefix(self, file_path: str) -> str: + stage = self.stage.rstrip("/") + return f"{stage}/{file_path.lstrip('/')}" + def get_directory_from_file_path(self, file_path: str) -> List[str]: stage_path_length = len(Path(self.directory).parts) return list(Path(file_path).parts[1 + stage_path_length : -1]) @@ -149,10 +156,16 @@ class UserStagePathParts(StagePathParts): def __init__(self, stage_path: str): self.directory = self.get_directory(stage_path) - self.stage = "@~" - self.stage_name = "@~" + self.stage = USER_STAGE_PREFIX + self.stage_name = USER_STAGE_PREFIX self.is_directory = True if stage_path.endswith("/") else False + @classmethod + def get_directory(cls, stage_path: str) -> str: + if Path(stage_path).parts[0] == USER_STAGE_PREFIX: + return super().get_directory(stage_path) + return stage_path + @property def path(self) -> str: return f"{self.directory}" @@ -161,6 +174,11 @@ def path(self) -> str: def full_path(self) -> str: return f"{self.stage}/{self.directory}" + def replace_stage_prefix(self, file_path: str) -> str: + if Path(file_path).parts[0] == self.stage_name: + return file_path + return f"{self.stage}/{file_path}" + def add_stage_prefix(self, file_path: str) -> str: return f"{self.stage}/{file_path}" @@ -248,7 +266,7 @@ def get_recursive( self._assure_is_existing_directory(dest_directory) result = self._execute_query( - f"get {self.quote_stage_name(stage_path_parts.add_stage_prefix(file_path))} {self._to_uri(f'{dest_directory}/')} parallel={parallel}" + f"get {self.quote_stage_name(stage_path_parts.replace_stage_prefix(file_path))} {self._to_uri(f'{dest_directory}/')} parallel={parallel}" ) results.append(result) @@ -327,8 +345,9 @@ def execute( ): stage_path_parts = self._stage_path_part_factory(stage_path) all_files_list = self._get_files_list_from_stage(stage_path_parts) + all_files_with_stage_name_prefix = [ - stage_path_parts.add_stage_prefix(file) for file in all_files_list + stage_path_parts.get_directory(file) for file in all_files_list ] # filter files from stage if match stage_path pattern @@ -355,7 +374,7 @@ def execute( ) for file_path in sorted_file_path_list: - file_stage_path = self.get_standard_stage_prefix(file_path) + file_stage_path = stage_path_parts.add_stage_prefix(file_path) if file_path.endswith(".py"): result = self._execute_python( file_stage_path=file_stage_path, @@ -390,7 +409,7 @@ def _filter_files_list( if not stage_path_parts.directory: return self._filter_supported_files(files_on_stage) - stage_path = stage_path_parts.full_path + stage_path = stage_path_parts.directory # Exact file path was provided if stage_path in file list if stage_path in files_on_stage: