diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application.py b/src/snowflake/cli/_plugins/nativeapp/entities/application.py index b43d4ee8ea..1622fe7eff 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application.py @@ -75,6 +75,7 @@ ) from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.entities.common import ( + DependsOnBaseModel, EntityModelBase, Identifier, PostDeployHook, @@ -246,7 +247,7 @@ def events_to_share( return sorted(list(set(events_names))) -class ApplicationEntityModel(EntityModelBase): +class ApplicationEntityModel(EntityModelBase, DependsOnBaseModel): type: Literal["application"] = DiscriminatorField() # noqa A003 from_: TargetField[ApplicationPackageEntityModel] = Field( alias="from", diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index 1da13ca9d5..5cfdadceb5 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -81,6 +81,7 @@ from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError from snowflake.cli.api.project.schemas.entities.common import ( + DependsOnBaseModel, EntityModelBase, Identifier, PostDeployHook, @@ -145,7 +146,7 @@ class ApplicationPackageChildField(UpdatableModel): ) -class ApplicationPackageEntityModel(EntityModelBase): +class ApplicationPackageEntityModel(EntityModelBase, DependsOnBaseModel): type: Literal["application package"] = DiscriminatorField() # noqa: A003 artifacts: List[Union[PathMapping, str]] = Field( title="List of paths or file source/destination pairs to add to the deploy root", diff --git a/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py b/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py index 55d1113748..2d2b210561 100644 --- a/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py +++ b/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py @@ -1,16 +1,228 @@ -from typing import Generic, TypeVar +from enum import Enum +from pathlib import Path +from typing import Generic, List, Optional, TypeVar +from click import ClickException +from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag +from snowflake.cli._plugins.snowpark import package_utils +from snowflake.cli._plugins.snowpark.common import DEFAULT_RUNTIME +from snowflake.cli._plugins.snowpark.package.anaconda_packages import ( + AnacondaPackages, + AnacondaPackagesManager, +) +from snowflake.cli._plugins.snowpark.package_utils import ( + DownloadUnavailablePackagesResult, +) from snowflake.cli._plugins.snowpark.snowpark_entity_model import ( FunctionEntityModel, ProcedureEntityModel, ) -from snowflake.cli.api.entities.common import EntityBase +from snowflake.cli._plugins.snowpark.zipper import zip_dir +from snowflake.cli._plugins.workspace.context import ActionContext +from snowflake.cli.api.entities.bundle_and_deploy import BundleAndDeploy +from snowflake.cli.api.secure_path import SecurePath +from snowflake.connector import ProgrammingError T = TypeVar("T") -class SnowparkEntity(EntityBase[Generic[T]]): - pass +class CreateMode( + str, Enum +): # This should probably be moved to some common place, think where + create = "CREATE" + create_or_replace = "CREATE OR REPLACE" + create_if_not_exists = "CREATE IF NOT EXISTS" + + +class SnowparkEntity(BundleAndDeploy[Generic[T]]): + def __init__(self, *args, **kwargs): + + if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled(): + raise NotImplementedError("Snowpark entity is not implemented yet") + super().__init__(*args, **kwargs) + + def action_bundle( + self, + action_ctx: ActionContext, + output_dir: Path | None, + ignore_anaconda: bool, + skip_version_check: bool, + index_url: str | None = None, + allow_shared_libraries: bool = False, + *args, + **kwargs, + ) -> List[Path]: + return self.bundle( + output_dir, + ignore_anaconda, + skip_version_check, + index_url, + allow_shared_libraries, + ) + + def action_deploy( + self, action_ctx: ActionContext, mode: CreateMode, *args, **kwargs + ): + # TODO: After introducing bundle map, we should introduce file copying part here + return self._execute_query(self.get_deploy_sql(mode)) + + def action_drop(self, action_ctx: ActionContext, *args, **kwargs): + return self._execute_query(self.get_drop_sql()) + + def action_describe(self, action_ctx: ActionContext, *args, **kwargs): + return self._execute_query(self.get_describe_sql()) + + def action_execute( + self, + action_ctx: ActionContext, + execution_arguments: List[str] | None = None, + *args, + **kwargs, + ): + return self._execute_query(self.get_execute_sql(execution_arguments)) + + def bundle( + self, + output_dir: Path | None, + ignore_anaconda: bool, + skip_version_check: bool, + index_url: str | None = None, + allow_shared_libraries: bool = False, + ) -> List[Path]: + """ + Bundles the entity artifacts and dependencies into a directory. + Parameters: + output_dir: The directory to output the bundled artifacts to. Defaults to output dir in project root + ignore_anaconda: If True, ignores anaconda chceck and tries to download all packages using pip + skip_version_check: If True, skips version check when downloading packages + index_url: The index URL to use when downloading packages, if none set - default pip index is used (in most cases- Pypi) + allow_shared_libraries: If not set to True, using dependency with .so/.dll files will raise an exception + Returns: + """ + # 0 Create a directory for the entity + if not output_dir: + output_dir = self.root / "output" / self.model.stage + output_dir.mkdir(parents=True, exist_ok=True) # type: ignore + + output_files = [] + + # 1 Check if requirements exits + if (self.root / "requirements.txt").exists(): + download_results = self._process_requirements( + bundle_dir=output_dir, # type: ignore + archive_name="dependencies.zip", + requirements_file=SecurePath(self.root / "requirements.txt"), + ignore_anaconda=ignore_anaconda, + skip_version_check=skip_version_check, + index_url=index_url, + allow_shared_libraries=allow_shared_libraries, + ) + + # 3 get the artifacts list + artifacts = self.model.artifacts + + for artifact in artifacts: + output_file = output_dir / artifact.dest / artifact.src.name + + if artifact.src.is_file(): + output_file.mkdir(parents=True, exist_ok=True) + SecurePath(artifact.src).copy(output_file) + elif artifact.is_dir(): + output_file.mkdir(parents=True, exist_ok=True) + + output_files.append(output_file) + + return output_files + + def check_if_exists( + self, action_ctx: ActionContext + ) -> bool: # TODO it should return current state, so we know if update is necessary + try: + current_state = self.action_describe(action_ctx) + return True + except ProgrammingError: + return False + + def get_deploy_sql(self, mode: CreateMode): + query = [ + f"{mode.value} {self.model.type.upper()} {self.identifier}", + "COPY GRANTS", + f"RETURNS {self.model.returns}", + f"LANGUAGE PYTHON", + f"RUNTIME_VERSION '{self.model.runtime or DEFAULT_RUNTIME}'", + f"IMPORTS={','.join(self.model.imports)}", # TODO: Add source files here after introducing bundlemap + f"HANDLER='{self.model.handler}'", + ] + + if self.model.external_access_integrations: + query.append(self.model.get_external_access_integrations_sql()) + + if self.model.secrets: + query.append(self.model.get_secrets_sql()) + + if self.model.type == "procedure" and self.model.execute_as_caller: + query.append("EXECUTE AS CALLER") + + return "\n".join(query) + + def get_execute_sql(self, execution_arguments: List[str] | None = None): + raise NotImplementedError + + def _process_requirements( # TODO: maybe leave all the logic with requirements here - so download, write requirements file etc. + self, + bundle_dir: Path, + archive_name: str, # TODO: not the best name, think of something else + requirements_file: Optional[SecurePath], + ignore_anaconda: bool, + skip_version_check: bool = False, + index_url: Optional[str] = None, + allow_shared_libraries: bool = False, + ) -> DownloadUnavailablePackagesResult: + """ + Processes the requirements file and downloads the dependencies + Parameters: + + """ + anaconda_packages_manager = AnacondaPackagesManager() + with SecurePath.temporary_directory() as tmp_dir: + requirements = package_utils.parse_requirements(requirements_file) + anaconda_packages = ( + AnacondaPackages.empty() + if ignore_anaconda + else anaconda_packages_manager.find_packages_available_in_snowflake_anaconda() + ) + download_result = package_utils.download_unavailable_packages( + requirements=requirements, + target_dir=tmp_dir, + anaconda_packages=anaconda_packages, + skip_version_check=skip_version_check, + pip_index_url=index_url, + ) + + if download_result.anaconda_packages: + anaconda_packages.write_requirements_file_in_snowflake_format( + file_path=SecurePath(bundle_dir / "requirements.txt"), + requirements=download_result.anaconda_packages, + ) + + if download_result.downloaded_packages_details: + if ( + package_utils.detect_and_log_shared_libraries( + download_result.downloaded_packages_details + ) + and not allow_shared_libraries + ): + raise ClickException( + "Some packages contain shared (.so/.dll) libraries. " + "Try again with allow_shared_libraries_flag." + ) + + zip_dir( + source=tmp_dir, + dest_zip=bundle_dir / archive_name, + ) + + return download_result class FunctionEntity(SnowparkEntity[FunctionEntityModel]): @@ -18,7 +230,17 @@ class FunctionEntity(SnowparkEntity[FunctionEntityModel]): A single UDF """ - pass + # TO THINK OF + # Where will we get imports? Should we rely on bundle map? Or should it be self-sufficient in this matter? + + def get_execute_sql( + self, execution_arguments: List[str] | None = None, *args, **kwargs + ): + if not execution_arguments: + execution_arguments = [] + return ( + f"SELECT {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})" + ) class ProcedureEntity(SnowparkEntity[ProcedureEntityModel]): @@ -26,4 +248,12 @@ class ProcedureEntity(SnowparkEntity[ProcedureEntityModel]): A stored procedure """ - pass + def get_execute_sql( + self, + execution_arguments: List[str] | None = None, + ): + if not execution_arguments: + execution_arguments = [] + return ( + f"CALL {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})" + ) diff --git a/src/snowflake/cli/_plugins/snowpark/snowpark_entity_model.py b/src/snowflake/cli/_plugins/snowpark/snowpark_entity_model.py index a92716280c..a6819faf44 100644 --- a/src/snowflake/cli/_plugins/snowpark/snowpark_entity_model.py +++ b/src/snowflake/cli/_plugins/snowpark/snowpark_entity_model.py @@ -20,6 +20,8 @@ from pydantic import Field, field_validator from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.project.schemas.entities.common import ( + Dependency, # noqa # noqa + DependsOnBaseModel, EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel, @@ -44,7 +46,9 @@ class Config: ) -class SnowparkEntityModel(EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel): +class SnowparkEntityModel( + EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel, DependsOnBaseModel +): handler: str = Field( title="Function’s or procedure’s implementation of the object inside source module", examples=["functions.hello_function"], diff --git a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py index c0c6786822..70bfc884e8 100644 --- a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py +++ b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py @@ -9,12 +9,13 @@ StreamlitEntityModel, ) from snowflake.cli._plugins.workspace.context import ActionContext -from snowflake.cli.api.entities.common import EntityBase, get_sql_executor +from snowflake.cli.api.entities.bundle_and_deploy import BundleAndDeploy +from snowflake.cli.api.entities.common import get_sql_executor from snowflake.cli.api.secure_path import SecurePath from snowflake.connector.cursor import SnowflakeCursor -class StreamlitEntity(EntityBase[StreamlitEntityModel]): +class StreamlitEntity(BundleAndDeploy[StreamlitEntityModel]): """ A Streamlit app. """ @@ -45,9 +46,17 @@ def model(self): return self._entity_model # noqa def action_bundle(self, action_ctx: ActionContext, *args, **kwargs): + + if self.model.depends_on: + dependent_entities = self.dependent_entities(action_ctx) + # TODO: we should sanitize the list to remove duplicates + for entity in dependent_entities: + if entity.supports("bundle"): + entity.bundle() + return self.bundle() - def action_deploy(self, action_ctx: ActionContext, *args, **kwargs): + def deploy(self, action_ctx: ActionContext, *args, **kwargs): # After adding bundle map- we should use it's mapping here # To copy artifacts to destination on stage. @@ -149,15 +158,12 @@ def get_deploy_sql( return query + ";" - def get_drop_sql(self): - return f"DROP STREAMLIT {self._entity_model.fqn};" + def get_share_sql(self, to_role: str) -> str: + return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};" def get_execute_sql(self): return f"EXECUTE STREAMLIT {self._entity_model.fqn}();" - def get_share_sql(self, to_role: str) -> str: - return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};" - def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None) -> str: entity_id = self.entity_id streamlit_name = f"{schema}.{entity_id}" if schema else entity_id diff --git a/src/snowflake/cli/_plugins/streamlit/streamlit_entity_model.py b/src/snowflake/cli/_plugins/streamlit/streamlit_entity_model.py index 55068adb5a..915bf505a3 100644 --- a/src/snowflake/cli/_plugins/streamlit/streamlit_entity_model.py +++ b/src/snowflake/cli/_plugins/streamlit/streamlit_entity_model.py @@ -18,6 +18,7 @@ from pydantic import Field, model_validator from snowflake.cli.api.project.schemas.entities.common import ( + DependsOnBaseModel, EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel, @@ -27,7 +28,9 @@ ) -class StreamlitEntityModel(EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel): +class StreamlitEntityModel( + EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel, DependsOnBaseModel +): type: Literal["streamlit"] = DiscriminatorField() # noqa: A003 title: Optional[str] = Field( title="Human-readable title for the Streamlit dashboard", default=None diff --git a/src/snowflake/cli/api/entities/bundle_and_deploy.py b/src/snowflake/cli/api/entities/bundle_and_deploy.py new file mode 100644 index 0000000000..dd073774cf --- /dev/null +++ b/src/snowflake/cli/api/entities/bundle_and_deploy.py @@ -0,0 +1,38 @@ +from typing import TypeVar + +from snowflake.cli._plugins.workspace.context import ActionContext +from snowflake.cli.api.entities.common import EntityActions, EntityBase + +T = TypeVar("T") + + +class BundleAndDeploy(EntityBase[T]): + """ + Base class for entities that can be bundled and deployed + Provides basic action logic and abstract methods for bundle and deploy- to be implemented + using subclass specific logic + """ + + def action_bundle(self, action_ctx: ActionContext, *args, **kwargs): + if dependent_entities := self.dependent_entities(action_ctx): + for dependency in dependent_entities: + entity = action_ctx.get_entity(dependency.entity_id) + # TODO think how to pass arguments for dependencies + if entity.supports(EntityActions.BUNDLE): + entity.bundle() + + return self.bundle(*args, **kwargs) + + def action_deploy(self, action_ctx: ActionContext, *args, **kwargs): + if dependent_entities := self.dependent_entities(action_ctx): + for dependency in dependent_entities: + entity = action_ctx.get_entity(dependency.entity_id) + if entity.supports(EntityActions.DEPLOY): + entity.deploy() + return self.deploy(action_ctx, *args, **kwargs) + + def bundle(self, *args, **kwargs): + raise NotImplementedError("Bundle method should be implemented in subclass") + + def deploy(self, *args, **kwargs): + raise NotImplementedError("Deploy method should be implemented in subclass") diff --git a/src/snowflake/cli/api/entities/common.py b/src/snowflake/cli/api/entities/common.py index cbedb87825..dabaabed76 100644 --- a/src/snowflake/cli/api/entities/common.py +++ b/src/snowflake/cli/api/entities/common.py @@ -1,9 +1,17 @@ +import functools from enum import Enum -from typing import Generic, Type, TypeVar, get_args +from pathlib import Path +from typing import Any, Generic, List, Type, TypeVar, get_args from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext from snowflake.cli.api.cli_global_context import span +from snowflake.cli.api.exceptions import CycleDetectedError +from snowflake.cli.api.identifiers import FQN +from snowflake.cli.api.project.schemas.entities.common import Dependency from snowflake.cli.api.sql_execution import SqlExecutor +from snowflake.cli.api.utils.graph import Graph, Node +from snowflake.connector import SnowflakeConnection +from snowflake.connector.cursor import SnowflakeCursor class EntityActions(str, Enum): @@ -66,8 +74,37 @@ def __init__(self, entity_model: T, workspace_ctx: WorkspaceContext): self._workspace_ctx = workspace_ctx @property - def entity_id(self): - return self._entity_model.entity_id + def entity_id(self) -> str: + return self._entity_model.entity_id # type: ignore + + @property + def root(self) -> Path: + return self._workspace_ctx.project_root + + @property + def identifier(self) -> str: + return self.model.fqn.sql_identifier + + @property + def fqn(self) -> FQN: + return self._entity_model.fqn # type: ignore[attr-defined] + + @functools.cached_property + def _sql_executor( + self, + ) -> SqlExecutor: + return get_sql_executor() + + def _execute_query(self, sql: str) -> SnowflakeCursor: + return self._sql_executor.execute_query(sql) + + @functools.cached_property + def _conn(self) -> SnowflakeConnection: + return self._sql_executor._conn # noqa + + @property + def model(self): + return self._entity_model @classmethod def get_entity_model_type(cls) -> Type[T]: @@ -92,7 +129,94 @@ def perform( """ return getattr(self, action)(action_ctx, *args, **kwargs) + def dependent_entities(self, action_ctx: ActionContext) -> List[Dependency]: + """ + Returns a list of entities that this entity depends on. + """ + graph = self._create_dependency_graph(action_ctx) + sorted_dependecies = self._check_dependency_graph_for_cycles(graph) + + return sorted_dependecies + + def _create_dependency_graph(self, action_ctx: ActionContext) -> Graph[Dependency]: + """ + Creates a graph for dependencies. We need the graph, instead of a simple list, because we need to check if + calling dependencies actions in selected order is possible. + """ + graph = Graph() + depends_on = self._entity_model.depends_on or [] # type: ignore + self_dependency = Dependency(id=self.model.entity_id) # type: ignore + resolved_nodes = set() + + graph.add(Node(key=self_dependency.entity_id, data=self_dependency)) + + def _resolve_dependencies(parent_id: str, dependency: Dependency) -> None: + + if not graph.contains_node(dependency.entity_id): + dependency_node = Node(key=dependency.entity_id, data=dependency) + graph.add(dependency_node) + + graph.add_directed_edge(parent_id, dependency.entity_id) + + resolved_nodes.add(dependency_node.key) + + for child_dependency in action_ctx.get_entity( + dependency.entity_id + ).model.depends_on: + if child_dependency.entity_id not in resolved_nodes: + _resolve_dependencies(dependency_node.key, child_dependency) + else: + graph.add_directed_edge( + dependency_node.key, child_dependency.entity_id + ) + + for dependency in depends_on: + _resolve_dependencies(self_dependency.entity_id, dependency) + + return graph + + def _check_dependency_graph_for_cycles( + self, graph: Graph[Dependency] + ) -> List[Dependency]: + """ + This function is used to check, if dependency graph have any cycles, that would make it impossible to + deploy entities in correct order. + If cycle is detected, it raises CycleDetectedError + The result list, shows entities this one depends on, in order they should be called. + Last item is removed from the result list, as it is this entity itself. + """ + result = [] + + def _on_cycle(node: Node[T]) -> None: + raise CycleDetectedError( + f"Cycle detected in entity dependencies: {node.key}" + ) + + def _on_visit(node: Node[T]) -> None: + result.append(node.data) + + graph.dfs(on_cycle_action=_on_cycle, visit_action=_on_visit) + + return clear_duplicates_from_list(result)[:-1] + + def get_usage_grant_sql(self, app_role: str) -> str: + return f"GRANT USAGE ON {self.model.type.upper()} {self.identifier} TO ROLE {app_role};" + + def get_describe_sql(self) -> str: + return f"DESCRIBE {self.model.type.upper()} {self.identifier};" + + def get_drop_sql(self) -> str: + return f"DROP {self.model.type.upper()} {self.identifier};" + def get_sql_executor() -> SqlExecutor: """Returns an SQL Executor that uses the connection from the current CLI context""" return SqlExecutor() + + +def clear_duplicates_from_list(input_list: list[Any]) -> list[Any]: + """ + Removes duplicates from the input list and returns a new list. + """ + seen = set() + return [x for x in input_list if not (x in seen or seen.add(x))] # type: ignore diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index d9036d9a4c..b824d3f268 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -162,3 +162,23 @@ def get_secrets_sql(self) -> str | None: return None secrets = ", ".join(f"'{key}'={value}" for key, value in self.secrets.items()) return f"secrets=({secrets})" + + +class Dependency(UpdatableModel): + entity_id: str = Field(title="Id of the entity", alias="id") + arguments: Optional[str] = Field( + title="Arguments that will be passed to entity build and deploy actions", + default="", + ) + + def __eq__(self, other): + return self.entity_id == other.entity_id + + def __hash__(self): + return hash(self.entity_id) + + +class DependsOnBaseModel: + depends_on: Optional[List[Dependency]] = Field( + title="Entities that need to be deployed before this one", default=[] + ) diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 2b0f4f5cf0..edbaeba871 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -28,6 +28,7 @@ ) from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.project.schemas.entities.common import ( + Dependency, # noqa # fmt: off # noqa TargetField, ) from snowflake.cli.api.project.schemas.entities.entities import ( @@ -262,6 +263,20 @@ def _merge_mixins_with_entity( data = cls._merge_data(data, entity) return data + @model_validator(mode="after") + def validate_dependencies(self): + """ + Checks if entities listed in depends_on section exist in the project + """ + + for entity_id, entity in self.entities.items(): + if entity.depends_on: + for dependency in entity.depends_on: + if dependency.entity_id not in self.entities: + raise ValueError( + f"Entity {entity_id} depends on non-existing entity {dependency.entity_id}" + ) + @classmethod def _merge_data( cls, diff --git a/src/snowflake/cli/api/utils/graph.py b/src/snowflake/cli/api/utils/graph.py index 4c818fe66d..078f70eaea 100644 --- a/src/snowflake/cli/api/utils/graph.py +++ b/src/snowflake/cli/api/utils/graph.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Callable, Generic, TypeVar +from typing import Callable, Dict, Generic, Set, TypeVar T = TypeVar("T") @@ -43,6 +43,9 @@ class Graph(Generic[T]): def __init__(self): self._graph_nodes_map: dict[str, Node[T]] = {} + def contains_node(self, key: str) -> bool: + return self.__contains__(key) + def get(self, key: str) -> Node[T]: if key in self._graph_nodes_map: return self._graph_nodes_map[key] @@ -93,5 +96,24 @@ def dfs( for node in self._graph_nodes_map.values(): Graph._dfs_visit(nodes_status, node, visit_action, on_cycle_action) + def layers( + self, starting_node_id: str + ): # TODO: probably can be removed as it is not used + """ + Creates a list of graph layers, relative to the starting node + """ + result: Dict[int, Set] = {} + + def _add_layer(node: Node[T], layer_number: int): + if not result.get(layer_number): + result[layer_number] = set() + result[layer_number].add(node) + for neighbor in node.neighbors: + _add_layer(neighbor, layer_number + 1) + + _add_layer(self.get(starting_node_id), 0) + + return [tuple(result[key]) for key in sorted(result.keys())] + def __contains__(self, key: str) -> bool: return key in self._graph_nodes_map diff --git a/tests/api/utils/test_graph.py b/tests/api/utils/test_graph.py index 202ba8564c..bbe8b59ff7 100644 --- a/tests/api/utils/test_graph.py +++ b/tests/api/utils/test_graph.py @@ -122,3 +122,23 @@ def cycle_detected_action(node): graph.dfs(on_cycle_action=cycle_detected_action) assert cycles_detected["count"] == 1 + + +def test_layers(nodes): + graph = Graph() + for i in range(5): + graph.add(nodes[i]) + + graph.add_directed_edge(nodes[0].key, nodes[1].key) + graph.add_directed_edge(nodes[0].key, nodes[2].key) + graph.add_directed_edge(nodes[2].key, nodes[3].key) + graph.add_directed_edge(nodes[2].key, nodes[4].key) + + result = graph.layers(nodes[0].key) + + assert len(result) == 3 + assert nodes[0] in result[0] + assert nodes[1] in result[1] + assert nodes[2] in result[1] + assert nodes[3] in result[2] + assert nodes[4] in result[2] diff --git a/tests/project/test_depends_on.py b/tests/project/test_depends_on.py new file mode 100644 index 0000000000..891ce1c095 --- /dev/null +++ b/tests/project/test_depends_on.py @@ -0,0 +1,86 @@ +from unittest import mock + +import pytest +from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity +from snowflake.cli._plugins.workspace.context import WorkspaceContext +from snowflake.cli._plugins.workspace.manager import WorkspaceManager +from snowflake.cli.api.exceptions import CycleDetectedError +from snowflake.cli.api.project.definition import load_project +from snowflake.cli.api.project.errors import SchemaValidationError + +from tests.testing_utils.mock_config import mock_config_key + + +@pytest.fixture +def example_workspace(project_directory): + # TODO: try to make a common fixture for all entities + def _workspace_fixture(project_name: str, entity_name: str, entity_class: type): + with mock_config_key("enable_native_app_children", True): + with project_directory(project_name) as pdir: + project = load_project([pdir / "snowflake.yml"]) + + model = project.project_definition.entities.get(entity_name) + + workspace_context = WorkspaceContext( + console=mock.MagicMock(), + project_root=pdir, + get_default_role=lambda: "test_role", + get_default_warehouse=lambda: "test_warehouse", + ) + + workspace_manager = WorkspaceManager(project.project_definition, pdir) + + return ( + entity_class(workspace_ctx=workspace_context, entity_model=model), + workspace_manager.action_ctx, + ) + + return _workspace_fixture + + +@pytest.fixture +def cyclic_dependency_workspace(example_workspace): + return example_workspace( + "depends_on_with_cyclic_dependency", "test_streamlit", StreamlitEntity + ) + + +@pytest.fixture +def basic_workspace(example_workspace): + return example_workspace("depends_on_basic", "test_streamlit", StreamlitEntity) + + +def test_cyclic(cyclic_dependency_workspace): + with pytest.raises(CycleDetectedError) as err: + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = cyclic_dependency_workspace + _ = entity.dependent_entities(action_ctx) + assert err.value.message == "Cycle detected in entity dependencies: test_function" + + +def test_dependencies_must_exist_in_project_file( + project_directory, alter_snowflake_yml +): + with project_directory("depends_on_with_cyclic_dependency") as pdir: + alter_snowflake_yml( + snowflake_yml_path=pdir / "snowflake.yml", + parameter_path="entities.test_streamlit.depends_on.0.id", + value="foo", + ) + with pytest.raises(SchemaValidationError) as err: + project = load_project([pdir / "snowflake.yml"]) + + assert ( + "Entity test_streamlit depends on non-existing entity foo" in err.value.message + ) + + +def test_dependencies_basic(basic_workspace): + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = basic_workspace + result = entity.dependent_entities(action_ctx) + + assert len(result) == 3 + assert result[0].entity_id == "test_function2" + assert result[1].entity_id == "test_function" + assert result[2].entity_id == "test_procedure" diff --git a/tests/snowpark/__snapshots__/test_snowpark_entity.ambr b/tests/snowpark/__snapshots__/test_snowpark_entity.ambr new file mode 100644 index 0000000000..a486d2f274 --- /dev/null +++ b/tests/snowpark/__snapshots__/test_snowpark_entity.ambr @@ -0,0 +1,63 @@ +# serializer version: 1 +# name: test_action_execute[None] + 'SELECT func1()' +# --- +# name: test_action_execute[execution_arguments1] + 'SELECT func1(arg1, arg2)' +# --- +# name: test_action_execute[execution_arguments2] + 'SELECT func1(foo, 42, bar)' +# --- +# name: test_function_get_execute_sql[None] + 'SELECT func1()' +# --- +# name: test_function_get_execute_sql[execution_arguments1] + 'SELECT func1(arg1, arg2)' +# --- +# name: test_function_get_execute_sql[execution_arguments2] + 'SELECT func1(foo, 42, bar)' +# --- +# name: test_get_deploy_sql[CREATE IF NOT EXISTS] + ''' + CREATE IF NOT EXISTS FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- +# name: test_get_deploy_sql[CREATE OR REPLACE] + ''' + CREATE OR REPLACE FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- +# name: test_get_deploy_sql[CREATE] + ''' + CREATE FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- +# name: test_nativeapp_children_interface + ''' + CREATE FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- diff --git a/tests/snowpark/test_snowpark_entity.py b/tests/snowpark/test_snowpark_entity.py new file mode 100644 index 0000000000..419c6aa56c --- /dev/null +++ b/tests/snowpark/test_snowpark_entity.py @@ -0,0 +1,176 @@ +from pathlib import Path +from unittest import mock + +import pytest +import yaml +from snowflake.cli._plugins.snowpark.package.anaconda_packages import ( + AnacondaPackages, + AvailablePackage, +) +from snowflake.cli._plugins.snowpark.snowpark_entity import ( + CreateMode, + FunctionEntity, + ProcedureEntity, +) +from snowflake.cli._plugins.snowpark.snowpark_entity_model import FunctionEntityModel +from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext + +from tests.testing_utils.mock_config import mock_config_key + +CONNECTOR = "snowflake.connector.connect" +CONTEXT = "" +EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query" +ANACONDA_PACKAGES = "snowflake.cli._plugins.snowpark.package.anaconda_packages.AnacondaPackagesManager.find_packages_available_in_snowflake_anaconda" + + +@pytest.fixture +def example_function_workspace( + project_directory, +): # TODO: try to make a common fixture for all entities + with mock_config_key("enable_native_app_children", True): + with project_directory("snowpark_functions_v2") as pdir: + with Path(pdir / "snowflake.yml").open() as definition_file: + definition = yaml.safe_load(definition_file) + model = FunctionEntityModel( + **definition.get("entities", {}).get("func1") + ) + + workspace_context = WorkspaceContext( + console=mock.MagicMock(), + project_root=pdir, + get_default_role=lambda: "test_role", + get_default_warehouse=lambda: "test_warehouse", + ) + + return ( + FunctionEntity(workspace_ctx=workspace_context, entity_model=model), + ActionContext( + get_entity=lambda *args: None, + ), + ) + + +def test_cannot_instantiate_without_feature_flag(): + with pytest.raises(NotImplementedError) as err: + FunctionEntity() + assert str(err.value) == "Snowpark entity is not implemented yet" + + with pytest.raises(NotImplementedError) as err: + ProcedureEntity() + assert str(err.value) == "Snowpark entity is not implemented yet" + + +@mock.patch(ANACONDA_PACKAGES) +def test_nativeapp_children_interface( + mock_anaconda, example_function_workspace, snapshot +): + mock_anaconda.return_value = AnacondaPackages( + { + "pandas": AvailablePackage("pandas", "1.2.3"), + "numpy": AvailablePackage("numpy", "1.2.3"), + "snowflake_snowpark_python": AvailablePackage( + "snowflake_snowpark_python", "1.2.3" + ), + } + ) + + sl, action_context = example_function_workspace + + sl.bundle(None, False, False, None, False) + bundle_artifact = ( + sl.root / "output" / sl.model.stage / "my_snowpark_project" / "app.py" + ) + deploy_sql_str = sl.get_deploy_sql(CreateMode.create) + grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") + + assert bundle_artifact.exists() + assert deploy_sql_str == snapshot + assert ( + grant_sql_str + == f"GRANT USAGE ON FUNCTION IDENTIFIER('func1') TO ROLE app_role;" + ) + + +@mock.patch(EXECUTE_QUERY) +def test_action_describe(mock_execute, example_function_workspace): + entity, action_context = example_function_workspace + result = entity.action_describe(action_context) + + mock_execute.assert_called_with("DESCRIBE FUNCTION IDENTIFIER('func1');") + + +@mock.patch(EXECUTE_QUERY) +def test_action_drop(mock_execute, example_function_workspace): + entity, action_context = example_function_workspace + result = entity.action_drop(action_context) + + mock_execute.assert_called_with("DROP FUNCTION IDENTIFIER('func1');") + + +@pytest.mark.parametrize( + "execution_arguments", [None, ["arg1", "arg2"], ["foo", 42, "bar"]] +) +@mock.patch(EXECUTE_QUERY) +def test_action_execute( + mock_execute, execution_arguments, example_function_workspace, snapshot +): + entity, action_context = example_function_workspace + result = entity.action_execute(action_context, execution_arguments) + + mock_execute.assert_called_with(snapshot) + + +@mock.patch(ANACONDA_PACKAGES) +def test_bundle(mock_anaconda, example_function_workspace): + mock_anaconda.return_value = AnacondaPackages( + { + "pandas": AvailablePackage("pandas", "1.2.3"), + "numpy": AvailablePackage("numpy", "1.2.3"), + "snowflake_snowpark_python": AvailablePackage( + "snowflake_snowpark_python", "1.2.3" + ), + } + ) + entity, action_context = example_function_workspace + entity.action_bundle(action_context, None, False, False, None, False) + + output = entity.root / "output" / entity._entity_model.stage # noqa + assert output.exists() + assert (output / "my_snowpark_project" / "app.py").exists() + + +def test_describe_function_sql(example_function_workspace): + entity, _ = example_function_workspace + assert entity.get_describe_sql() == "DESCRIBE FUNCTION IDENTIFIER('func1');" + + +def test_drop_function_sql(example_function_workspace): + entity, _ = example_function_workspace + assert entity.get_drop_sql() == "DROP FUNCTION IDENTIFIER('func1');" + + +@pytest.mark.parametrize( + "execution_arguments", [None, ["arg1", "arg2"], ["foo", 42, "bar"]] +) +def test_function_get_execute_sql( + execution_arguments, example_function_workspace, snapshot +): + entity, _ = example_function_workspace + assert entity.get_execute_sql(execution_arguments) == snapshot + + +@pytest.mark.parametrize( + "mode", + [CreateMode.create, CreateMode.create_or_replace, CreateMode.create_if_not_exists], +) +def test_get_deploy_sql(mode, example_function_workspace, snapshot): + entity, _ = example_function_workspace + assert entity.get_deploy_sql(mode) == snapshot + + +def test_get_usage_grant_sql(example_function_workspace): + entity, _ = example_function_workspace + assert ( + entity.get_usage_grant_sql("test_role") + == "GRANT USAGE ON FUNCTION IDENTIFIER('func1') TO ROLE test_role;" + ) diff --git a/tests/streamlit/test_streamlit_entity.py b/tests/streamlit/test_streamlit_entity.py index 4574939f35..e1587fb599 100644 --- a/tests/streamlit/test_streamlit_entity.py +++ b/tests/streamlit/test_streamlit_entity.py @@ -32,6 +32,7 @@ def example_streamlit_workspace(project_directory): model = StreamlitEntityModel( **definition.get("entities", {}).get("test_streamlit") ) + model.set_entity_id("test_streamlit") workspace_context = WorkspaceContext( console=mock.MagicMock(), @@ -67,7 +68,8 @@ def test_nativeapp_children_interface(example_streamlit_workspace, snapshot): assert bundle_artifact.exists() assert deploy_sql_str == snapshot assert ( - grant_sql_str == f"GRANT USAGE ON STREAMLIT None TO APPLICATION ROLE app_role;" + grant_sql_str + == f"GRANT USAGE ON STREAMLIT test_streamlit TO APPLICATION ROLE app_role;" ) @@ -98,7 +100,7 @@ def test_drop(mock_execute, example_streamlit_workspace): entity, action_ctx = example_streamlit_workspace entity.action_drop(action_ctx) - mock_execute.assert_called_with(f"DROP STREAMLIT {STREAMLIT_NAME};") + mock_execute.assert_called_with(f"DROP STREAMLIT IDENTIFIER('{STREAMLIT_NAME}');") @mock.patch(CONNECTOR) @@ -163,7 +165,7 @@ def test_get_drop_sql(example_streamlit_workspace): entity, action_ctx = example_streamlit_workspace drop_sql = entity.get_drop_sql() - assert drop_sql == f"DROP STREAMLIT {STREAMLIT_NAME};" + assert drop_sql == f"DROP STREAMLIT IDENTIFIER('{STREAMLIT_NAME}');" @pytest.mark.parametrize( diff --git a/tests/test_data/projects/depends_on_basic/app.py b/tests/test_data/projects/depends_on_basic/app.py new file mode 100644 index 0000000000..3cdb8feab8 --- /dev/null +++ b/tests/test_data/projects/depends_on_basic/app.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import sys + +from snowflake.snowpark import Session + + +def hello(name: str) -> str: + return f"Hello {name}!" + + +def test(session: Session) -> str: + return "Test procedure" + + +# For local debugging. Be aware you may need to type-convert arguments if +# you add input parameters +if __name__ == "__main__": + if len(sys.argv) > 1: + print(hello(sys.argv[1])) # type: ignore + else: + print(hello("world")) diff --git a/tests/test_data/projects/depends_on_basic/environment.yml b/tests/test_data/projects/depends_on_basic/environment.yml new file mode 100644 index 0000000000..ac8feac3e8 --- /dev/null +++ b/tests/test_data/projects/depends_on_basic/environment.yml @@ -0,0 +1,5 @@ +name: sf_env +channels: + - snowflake +dependencies: + - pandas diff --git a/tests/test_data/projects/depends_on_basic/pages/my_page.py b/tests/test_data/projects/depends_on_basic/pages/my_page.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_data/projects/depends_on_basic/requirements.txt b/tests/test_data/projects/depends_on_basic/requirements.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_data/projects/depends_on_basic/snowflake.yml b/tests/test_data/projects/depends_on_basic/snowflake.yml new file mode 100644 index 0000000000..f09cd6c0d0 --- /dev/null +++ b/tests/test_data/projects/depends_on_basic/snowflake.yml @@ -0,0 +1,60 @@ +definition_version: 2 +entities: + test_streamlit: + artifacts: + - streamlit_app.py + - environment.yml + - pages + identifier: + name: test_streamlit + main_file: streamlit_app.py + query_warehouse: test_warehouse + stage: streamlit + type: streamlit + depends_on: + - id: test_procedure + + + test_procedure: + artifacts: + - app/ + handler: hello + identifier: + name: test_procedure + returns: string + signature: + - name: "name" + type: "string" + stage: dev_deployment + type: procedure + depends_on: + - id: test_function + - id: test_function2 + + test_function: + artifacts: + - app/ + handler: hello + identifier: + name: test_function + returns: string + signature: + - name: "name" + type: "string" + stage: dev_deployment + type: function + depends_on: + - id: test_function2 + + test_function2: + artifacts: + - app/ + handler: hello + identifier: + name: test_function + returns: string + signature: + - name: "name" + type: "string" + stage: dev_deployment + type: function diff --git a/tests/test_data/projects/depends_on_basic/streamlit_app.py b/tests/test_data/projects/depends_on_basic/streamlit_app.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_data/projects/depends_on_with_cyclic_dependency/app.py b/tests/test_data/projects/depends_on_with_cyclic_dependency/app.py new file mode 100644 index 0000000000..3cdb8feab8 --- /dev/null +++ b/tests/test_data/projects/depends_on_with_cyclic_dependency/app.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import sys + +from snowflake.snowpark import Session + + +def hello(name: str) -> str: + return f"Hello {name}!" + + +def test(session: Session) -> str: + return "Test procedure" + + +# For local debugging. Be aware you may need to type-convert arguments if +# you add input parameters +if __name__ == "__main__": + if len(sys.argv) > 1: + print(hello(sys.argv[1])) # type: ignore + else: + print(hello("world")) diff --git a/tests/test_data/projects/depends_on_with_cyclic_dependency/environment.yml b/tests/test_data/projects/depends_on_with_cyclic_dependency/environment.yml new file mode 100644 index 0000000000..ac8feac3e8 --- /dev/null +++ b/tests/test_data/projects/depends_on_with_cyclic_dependency/environment.yml @@ -0,0 +1,5 @@ +name: sf_env +channels: + - snowflake +dependencies: + - pandas diff --git a/tests/test_data/projects/depends_on_with_cyclic_dependency/pages/my_page.py b/tests/test_data/projects/depends_on_with_cyclic_dependency/pages/my_page.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_data/projects/depends_on_with_cyclic_dependency/requirements.txt b/tests/test_data/projects/depends_on_with_cyclic_dependency/requirements.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_data/projects/depends_on_with_cyclic_dependency/snowflake.yml b/tests/test_data/projects/depends_on_with_cyclic_dependency/snowflake.yml new file mode 100644 index 0000000000..e6abb745f2 --- /dev/null +++ b/tests/test_data/projects/depends_on_with_cyclic_dependency/snowflake.yml @@ -0,0 +1,45 @@ +definition_version: 2 +entities: + test_streamlit: + artifacts: + - streamlit_app.py + - environment.yml + - pages + identifier: + name: test_streamlit + main_file: streamlit_app.py + query_warehouse: test_warehouse + stage: streamlit + type: streamlit + depends_on: + - id: test_procedure + + test_procedure: + artifacts: + - app/ + handler: hello + identifier: + name: test_procedure + returns: string + signature: + - name: "name" + type: "string" + stage: dev_deployment + type: procedure + depends_on: + - id: test_function + + test_function: + artifacts: + - app/ + handler: hello + identifier: + name: test_function + returns: string + signature: + - name: "name" + type: "string" + stage: dev_deployment + type: function + depends_on: + - id: test_procedure diff --git a/tests/test_data/projects/depends_on_with_cyclic_dependency/streamlit_app.py b/tests/test_data/projects/depends_on_with_cyclic_dependency/streamlit_app.py new file mode 100644 index 0000000000..e69de29bb2