Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive depends on #1980

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
)
from snowflake.cli.api.metrics import CLICounterField
from snowflake.cli.api.project.schemas.entities.common import (
DependsOnBaseModel,
EntityModelBase,
Identifier,
PostDeployHook,
Expand Down Expand Up @@ -246,7 +247,7 @@ def events_to_share(
return sorted(list(set(events_names)))


class ApplicationEntityModel(EntityModelBase):
class ApplicationEntityModel(EntityModelBase, DependsOnBaseModel):
type: Literal["application"] = DiscriminatorField() # noqa A003
from_: TargetField[ApplicationPackageEntityModel] = Field(
alias="from",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED
from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError
from snowflake.cli.api.project.schemas.entities.common import (
DependsOnBaseModel,
EntityModelBase,
Identifier,
PostDeployHook,
Expand Down Expand Up @@ -145,7 +146,7 @@ class ApplicationPackageChildField(UpdatableModel):
)


class ApplicationPackageEntityModel(EntityModelBase):
class ApplicationPackageEntityModel(EntityModelBase, DependsOnBaseModel):
type: Literal["application package"] = DiscriminatorField() # noqa: A003
artifacts: List[Union[PathMapping, str]] = Field(
title="List of paths or file source/destination pairs to add to the deploy root",
Expand Down
242 changes: 236 additions & 6 deletions src/snowflake/cli/_plugins/snowpark/snowpark_entity.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,259 @@
from typing import Generic, TypeVar
from enum import Enum
from pathlib import Path
from typing import Generic, List, Optional, TypeVar

from click import ClickException
from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag
from snowflake.cli._plugins.snowpark import package_utils
from snowflake.cli._plugins.snowpark.common import DEFAULT_RUNTIME
from snowflake.cli._plugins.snowpark.package.anaconda_packages import (
AnacondaPackages,
AnacondaPackagesManager,
)
from snowflake.cli._plugins.snowpark.package_utils import (
DownloadUnavailablePackagesResult,
)
from snowflake.cli._plugins.snowpark.snowpark_entity_model import (
FunctionEntityModel,
ProcedureEntityModel,
)
from snowflake.cli.api.entities.common import EntityBase
from snowflake.cli._plugins.snowpark.zipper import zip_dir
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.entities.bundle_and_deploy import BundleAndDeploy
from snowflake.cli.api.secure_path import SecurePath
from snowflake.connector import ProgrammingError

T = TypeVar("T")


class SnowparkEntity(EntityBase[Generic[T]]):
pass
class CreateMode(
str, Enum
): # This should probably be moved to some common place, think where
create = "CREATE"
create_or_replace = "CREATE OR REPLACE"
create_if_not_exists = "CREATE IF NOT EXISTS"


class SnowparkEntity(BundleAndDeploy[Generic[T]]):
def __init__(self, *args, **kwargs):

if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled():
raise NotImplementedError("Snowpark entity is not implemented yet")
super().__init__(*args, **kwargs)

def action_bundle(
self,
action_ctx: ActionContext,
output_dir: Path | None,
ignore_anaconda: bool,
skip_version_check: bool,
index_url: str | None = None,
allow_shared_libraries: bool = False,
*args,
**kwargs,
) -> List[Path]:
return self.bundle(
output_dir,
ignore_anaconda,
skip_version_check,
index_url,
allow_shared_libraries,
)

def action_deploy(
self, action_ctx: ActionContext, mode: CreateMode, *args, **kwargs
):
# TODO: After introducing bundle map, we should introduce file copying part here
return self._execute_query(self.get_deploy_sql(mode))

def action_drop(self, action_ctx: ActionContext, *args, **kwargs):
return self._execute_query(self.get_drop_sql())

def action_describe(self, action_ctx: ActionContext, *args, **kwargs):
return self._execute_query(self.get_describe_sql())

def action_execute(
self,
action_ctx: ActionContext,
execution_arguments: List[str] | None = None,
*args,
**kwargs,
):
return self._execute_query(self.get_execute_sql(execution_arguments))

def bundle(
self,
output_dir: Path | None,
ignore_anaconda: bool,
skip_version_check: bool,
index_url: str | None = None,
allow_shared_libraries: bool = False,
) -> List[Path]:
"""
Bundles the entity artifacts and dependencies into a directory.
Parameters:
output_dir: The directory to output the bundled artifacts to. Defaults to output dir in project root
ignore_anaconda: If True, ignores anaconda chceck and tries to download all packages using pip
skip_version_check: If True, skips version check when downloading packages
index_url: The index URL to use when downloading packages, if none set - default pip index is used (in most cases- Pypi)
allow_shared_libraries: If not set to True, using dependency with .so/.dll files will raise an exception
Returns:
"""
# 0 Create a directory for the entity
if not output_dir:
output_dir = self.root / "output" / self.model.stage
output_dir.mkdir(parents=True, exist_ok=True) # type: ignore

output_files = []

# 1 Check if requirements exits
if (self.root / "requirements.txt").exists():
download_results = self._process_requirements(
bundle_dir=output_dir, # type: ignore
archive_name="dependencies.zip",
requirements_file=SecurePath(self.root / "requirements.txt"),
ignore_anaconda=ignore_anaconda,
skip_version_check=skip_version_check,
index_url=index_url,
allow_shared_libraries=allow_shared_libraries,
)

# 3 get the artifacts list
artifacts = self.model.artifacts

for artifact in artifacts:
output_file = output_dir / artifact.dest / artifact.src.name

if artifact.src.is_file():
output_file.mkdir(parents=True, exist_ok=True)
SecurePath(artifact.src).copy(output_file)
elif artifact.is_dir():
output_file.mkdir(parents=True, exist_ok=True)

output_files.append(output_file)

return output_files

def check_if_exists(
self, action_ctx: ActionContext
) -> bool: # TODO it should return current state, so we know if update is necessary
try:
current_state = self.action_describe(action_ctx)
return True
except ProgrammingError:
return False

def get_deploy_sql(self, mode: CreateMode):
query = [
f"{mode.value} {self.model.type.upper()} {self.identifier}",
"COPY GRANTS",
f"RETURNS {self.model.returns}",
f"LANGUAGE PYTHON",
f"RUNTIME_VERSION '{self.model.runtime or DEFAULT_RUNTIME}'",
f"IMPORTS={','.join(self.model.imports)}", # TODO: Add source files here after introducing bundlemap
f"HANDLER='{self.model.handler}'",
]

if self.model.external_access_integrations:
query.append(self.model.get_external_access_integrations_sql())

if self.model.secrets:
query.append(self.model.get_secrets_sql())

if self.model.type == "procedure" and self.model.execute_as_caller:
query.append("EXECUTE AS CALLER")

return "\n".join(query)

def get_execute_sql(self, execution_arguments: List[str] | None = None):
raise NotImplementedError

def _process_requirements( # TODO: maybe leave all the logic with requirements here - so download, write requirements file etc.
self,
bundle_dir: Path,
archive_name: str, # TODO: not the best name, think of something else
requirements_file: Optional[SecurePath],
ignore_anaconda: bool,
skip_version_check: bool = False,
index_url: Optional[str] = None,
allow_shared_libraries: bool = False,
) -> DownloadUnavailablePackagesResult:
"""
Processes the requirements file and downloads the dependencies
Parameters:

"""
anaconda_packages_manager = AnacondaPackagesManager()
with SecurePath.temporary_directory() as tmp_dir:
requirements = package_utils.parse_requirements(requirements_file)
anaconda_packages = (
AnacondaPackages.empty()
if ignore_anaconda
else anaconda_packages_manager.find_packages_available_in_snowflake_anaconda()
)
download_result = package_utils.download_unavailable_packages(
requirements=requirements,
target_dir=tmp_dir,
anaconda_packages=anaconda_packages,
skip_version_check=skip_version_check,
pip_index_url=index_url,
)

if download_result.anaconda_packages:
anaconda_packages.write_requirements_file_in_snowflake_format(
file_path=SecurePath(bundle_dir / "requirements.txt"),
requirements=download_result.anaconda_packages,
)

if download_result.downloaded_packages_details:
if (
package_utils.detect_and_log_shared_libraries(
download_result.downloaded_packages_details
)
and not allow_shared_libraries
):
raise ClickException(
"Some packages contain shared (.so/.dll) libraries. "
"Try again with allow_shared_libraries_flag."
)

zip_dir(
source=tmp_dir,
dest_zip=bundle_dir / archive_name,
)

return download_result


class FunctionEntity(SnowparkEntity[FunctionEntityModel]):
"""
A single UDF
"""

pass
# TO THINK OF
# Where will we get imports? Should we rely on bundle map? Or should it be self-sufficient in this matter?

def get_execute_sql(
self, execution_arguments: List[str] | None = None, *args, **kwargs
):
if not execution_arguments:
execution_arguments = []
return (
f"SELECT {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})"
)


class ProcedureEntity(SnowparkEntity[ProcedureEntityModel]):
"""
A stored procedure
"""

pass
def get_execute_sql(
self,
execution_arguments: List[str] | None = None,
):
if not execution_arguments:
execution_arguments = []
return (
f"CALL {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})"
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from pydantic import Field, field_validator
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.project.schemas.entities.common import (
Dependency, # noqa # noqa
DependsOnBaseModel,
EntityModelBase,
ExternalAccessBaseModel,
ImportsBaseModel,
Expand All @@ -44,7 +46,9 @@ class Config:
)


class SnowparkEntityModel(EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel):
class SnowparkEntityModel(
EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel, DependsOnBaseModel
):
handler: str = Field(
title="Function’s or procedure’s implementation of the object inside source module",
examples=["functions.hello_function"],
Expand Down
22 changes: 14 additions & 8 deletions src/snowflake/cli/_plugins/streamlit/streamlit_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
StreamlitEntityModel,
)
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.entities.common import EntityBase, get_sql_executor
from snowflake.cli.api.entities.bundle_and_deploy import BundleAndDeploy
from snowflake.cli.api.entities.common import get_sql_executor
from snowflake.cli.api.secure_path import SecurePath
from snowflake.connector.cursor import SnowflakeCursor


class StreamlitEntity(EntityBase[StreamlitEntityModel]):
class StreamlitEntity(BundleAndDeploy[StreamlitEntityModel]):
"""
A Streamlit app.
"""
Expand Down Expand Up @@ -45,9 +46,17 @@ def model(self):
return self._entity_model # noqa

def action_bundle(self, action_ctx: ActionContext, *args, **kwargs):

if self.model.depends_on:
dependent_entities = self.dependent_entities(action_ctx)
# TODO: we should sanitize the list to remove duplicates
for entity in dependent_entities:
if entity.supports("bundle"):
entity.bundle()

return self.bundle()

def action_deploy(self, action_ctx: ActionContext, *args, **kwargs):
def deploy(self, action_ctx: ActionContext, *args, **kwargs):
# After adding bundle map- we should use it's mapping here
# To copy artifacts to destination on stage.

Expand Down Expand Up @@ -149,15 +158,12 @@ def get_deploy_sql(

return query + ";"

def get_drop_sql(self):
return f"DROP STREAMLIT {self._entity_model.fqn};"
def get_share_sql(self, to_role: str) -> str:
return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};"

def get_execute_sql(self):
return f"EXECUTE STREAMLIT {self._entity_model.fqn}();"

def get_share_sql(self, to_role: str) -> str:
return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};"

def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None) -> str:
entity_id = self.entity_id
streamlit_name = f"{schema}.{entity_id}" if schema else entity_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from pydantic import Field, model_validator
from snowflake.cli.api.project.schemas.entities.common import (
DependsOnBaseModel,
EntityModelBase,
ExternalAccessBaseModel,
ImportsBaseModel,
Expand All @@ -27,7 +28,9 @@
)


class StreamlitEntityModel(EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel):
class StreamlitEntityModel(
EntityModelBase, ExternalAccessBaseModel, ImportsBaseModel, DependsOnBaseModel
):
type: Literal["streamlit"] = DiscriminatorField() # noqa: A003
title: Optional[str] = Field(
title="Human-readable title for the Streamlit dashboard", default=None
Expand Down
Loading
Loading