Skip to content

Commit

Permalink
Solution
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jsikorski committed Jan 9, 2025
1 parent 5d1c94f commit 801b67d
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 206 deletions.
60 changes: 10 additions & 50 deletions src/snowflake/cli/_plugins/snowpark/snowpark_entity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import functools
from enum import Enum
from pathlib import Path
from typing import Generic, List, Optional, TypeVar

from click import ClickException

from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag
from snowflake.cli._plugins.snowpark import package_utils
from snowflake.cli._plugins.snowpark.common import DEFAULT_RUNTIME
Expand All @@ -21,14 +19,14 @@
)
from snowflake.cli._plugins.snowpark.zipper import zip_dir
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.entities.common import EntityBase, get_sql_executor
from snowflake.cli.api.entities.common import EntityBase
from snowflake.cli.api.secure_path import SecurePath
from snowflake.connector import ProgrammingError

T = TypeVar("T")


class DeployMode(
class CreateMode(
str, Enum
): # This should probably be moved to some common place, think where
create = "CREATE"
Expand All @@ -37,39 +35,12 @@ class DeployMode(


class SnowparkEntity(EntityBase[Generic[T]]):

def __init__(self, *args, **kwargs):

if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled():
raise NotImplementedError("Streamlit entity is not implemented yet")
raise NotImplementedError("Snowpark entity is not implemented yet")
super().__init__(*args, **kwargs)

@property
def root(self):
return self._workspace_ctx.project_root

@property
def identifier(self):
return self.model.fqn.sql_identifier

@property
def fqn(self):
return self.model.fqn

@functools.cached_property
def _sql_executor(
self,
): # maybe this could be moved to parent class, as it is used in streamlit entity as well
return get_sql_executor()

@functools.cached_property
def _conn(self):
return self._sql_executor._conn # noqa

@property
def model(self):
return self._entity_model # noqa

def action_bundle(
self,
action_ctx: ActionContext,
Expand All @@ -80,7 +51,7 @@ def action_bundle(
allow_shared_libraries: bool = False,
*args,
**kwargs,
):
) -> List[Path]:
return self.bundle(
output_dir,
ignore_anaconda,
Expand All @@ -90,16 +61,16 @@ def action_bundle(
)

def action_deploy(
self, action_ctx: ActionContext, mode: DeployMode, *args, **kwargs
self, action_ctx: ActionContext, mode: CreateMode, *args, **kwargs
):
# TODO: After introducing bundle map, we should introduce file copying part here
return self._sql_executor.execute_query(self.get_deploy_sql(mode))
return self._execute_query(self.get_deploy_sql(mode))

def action_drop(self, action_ctx: ActionContext, *args, **kwargs):
return self._sql_executor.execute_query(self.get_drop_sql())
return self._execute_query(self.get_drop_sql())

def action_describe(self, action_ctx: ActionContext, *args, **kwargs):
return self._sql_executor.execute_query(self.get_describe_sql())
return self._execute_query(self.get_describe_sql())

def action_execute(
self,
Expand All @@ -108,9 +79,7 @@ def action_execute(
*args,
**kwargs,
):
return self._sql_executor.execute_query(
self.get_execute_sql(execution_arguments)
)
return self._execute_query(self.get_execute_sql(execution_arguments))

def bundle(
self,
Expand Down Expand Up @@ -174,7 +143,7 @@ def check_if_exists(
except ProgrammingError:
return False

def get_deploy_sql(self, mode: DeployMode):
def get_deploy_sql(self, mode: CreateMode):
query = [
f"{mode.value} {self.model.type.upper()} {self.identifier}",
"COPY GRANTS",
Expand All @@ -196,18 +165,9 @@ def get_deploy_sql(self, mode: DeployMode):

return "\n".join(query)

def get_describe_sql(self):
return f"DESCRIBE {self.model.type.upper()} {self.identifier}"

def get_drop_sql(self):
return f"DROP {self.model.type.upper()} {self.identifier}"

def get_execute_sql(self, execution_arguments: List[str] | None = None):
raise NotImplementedError

def get_usage_grant_sql(self, app_role: str):
return f"GRANT USAGE ON {self.model.type.upper()} {self.identifier} TO ROLE {app_role}"

def _process_requirements( # TODO: maybe leave all the logic with requirements here - so download, write requirements file etc.
self,
bundle_dir: Path,
Expand Down
7 changes: 2 additions & 5 deletions src/snowflake/cli/_plugins/streamlit/streamlit_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,12 @@ def get_deploy_sql(

return query + ";"

def get_drop_sql(self):
return f"DROP STREAMLIT {self._entity_model.fqn};"
def get_share_sql(self, to_role: str) -> str:
return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};"

def get_execute_sql(self):
return f"EXECUTE STREAMLIT {self._entity_model.fqn}();"

def get_share_sql(self, to_role: str) -> str:
return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};"

def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None) -> str:
entity_id = self.entity_id
streamlit_name = f"{schema}.{entity_id}" if schema else entity_id
Expand Down
47 changes: 45 additions & 2 deletions src/snowflake/cli/api/entities/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import functools
from enum import Enum
from pathlib import Path
from typing import Generic, Type, TypeVar, get_args

from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext
from snowflake.cli.api.cli_global_context import span
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.sql_execution import SqlExecutor
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor


class EntityActions(str, Enum):
Expand Down Expand Up @@ -66,8 +71,8 @@ def __init__(self, entity_model: T, workspace_ctx: WorkspaceContext):
self._workspace_ctx = workspace_ctx

@property
def entity_id(self):
return self._entity_model.entity_id
def entity_id(self) -> str:
return self._entity_model.entity_id # type: ignore

@classmethod
def get_entity_model_type(cls) -> Type[T]:
Expand All @@ -92,6 +97,44 @@ def perform(
"""
return getattr(self, action)(action_ctx, *args, **kwargs)

@property
def root(self) -> Path:
return self._workspace_ctx.project_root

@property
def identifier(self) -> str:
return self.model.fqn.sql_identifier

@property
def fqn(self) -> FQN:
return self._entity_model.fqn # type: ignore[attr-defined]

@functools.cached_property
def _sql_executor(
self,
) -> SqlExecutor:
return get_sql_executor()

def _execute_query(self, sql: str) -> SnowflakeCursor:
return self._sql_executor.execute_query(sql)

@functools.cached_property
def _conn(self) -> SnowflakeConnection:
return self._sql_executor._conn # noqa

@property
def model(self):
return self._entity_model

def get_usage_grant_sql(self, app_role: str) -> str:
return f"GRANT USAGE ON {self.model.type.upper()} {self.identifier} TO ROLE {app_role};"

def get_describe_sql(self) -> str:
return f"DESCRIBE {self.model.type.upper()} {self.identifier};"

def get_drop_sql(self) -> str:
return f"DROP {self.model.type.upper()} {self.identifier};"


def get_sql_executor() -> SqlExecutor:
"""Returns an SQL Executor that uses the connection from the current CLI context"""
Expand Down
11 changes: 11 additions & 0 deletions tests/snowpark/__snapshots__/test_snowpark_entity.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,14 @@
HANDLER='app.func1_handler'
'''
# ---
# name: test_nativeapp_children_interface
'''
CREATE FUNCTION IDENTIFIER('func1')
COPY GRANTS
RETURNS string
LANGUAGE PYTHON
RUNTIME_VERSION '3.10'
IMPORTS=
HANDLER='app.func1_handler'
'''
# ---
59 changes: 50 additions & 9 deletions tests/snowpark/test_snowpark_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
AnacondaPackages,
AvailablePackage,
)
from snowflake.cli._plugins.snowpark.snowpark_entity import DeployMode, FunctionEntity, ProcedureEntity
from snowflake.cli._plugins.snowpark.snowpark_entity import (
CreateMode,
FunctionEntity,
ProcedureEntity,
)
from snowflake.cli._plugins.snowpark.snowpark_entity_model import FunctionEntityModel
from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext

Expand Down Expand Up @@ -45,29 +49,62 @@ def example_function_workspace(
),
)


def test_cannot_instantiate_without_feature_flag():
with pytest.raises(NotImplementedError) as err:
FunctionEntity()
assert str(err.value) == "Snowpark entities are not implemented yet"
assert str(err.value) == "Snowpark entity is not implemented yet"

with pytest.raises(NotImplementedError) as err:
ProcedureEntity()
assert str(err.value) == "Snowpark entities are not implemented yet"
assert str(err.value) == "Snowpark entity is not implemented yet"


@mock.patch(ANACONDA_PACKAGES)
def test_nativeapp_children_interface(
mock_anaconda, example_function_workspace, snapshot
):
mock_anaconda.return_value = AnacondaPackages(
{
"pandas": AvailablePackage("pandas", "1.2.3"),
"numpy": AvailablePackage("numpy", "1.2.3"),
"snowflake_snowpark_python": AvailablePackage(
"snowflake_snowpark_python", "1.2.3"
),
}
)

sl, action_context = example_function_workspace

sl.bundle(None, False, False, None, False)
bundle_artifact = (
sl.root / "output" / sl.model.stage / "my_snowpark_project" / "app.py"
)
deploy_sql_str = sl.get_deploy_sql(CreateMode.create)
grant_sql_str = sl.get_usage_grant_sql(app_role="app_role")

assert bundle_artifact.exists()
assert deploy_sql_str == snapshot
assert (
grant_sql_str
== f"GRANT USAGE ON FUNCTION IDENTIFIER('func1') TO ROLE app_role;"
)


@mock.patch(EXECUTE_QUERY)
def test_action_describe(mock_execute, example_function_workspace):
entity, action_context = example_function_workspace
result = entity.action_describe(action_context)

mock_execute.assert_called_with("DESCRIBE FUNCTION IDENTIFIER('func1')")
mock_execute.assert_called_with("DESCRIBE FUNCTION IDENTIFIER('func1');")


@mock.patch(EXECUTE_QUERY)
def test_action_drop(mock_execute, example_function_workspace):
entity, action_context = example_function_workspace
result = entity.action_drop(action_context)

mock_execute.assert_called_with("DROP FUNCTION IDENTIFIER('func1')")
mock_execute.assert_called_with("DROP FUNCTION IDENTIFIER('func1');")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -104,12 +141,12 @@ def test_bundle(mock_anaconda, example_function_workspace):

def test_describe_function_sql(example_function_workspace):
entity, _ = example_function_workspace
assert entity.get_describe_sql() == "DESCRIBE FUNCTION IDENTIFIER('func1')"
assert entity.get_describe_sql() == "DESCRIBE FUNCTION IDENTIFIER('func1');"


def test_drop_function_sql(example_function_workspace):
entity, _ = example_function_workspace
assert entity.get_drop_sql() == "DROP FUNCTION IDENTIFIER('func1')"
assert entity.get_drop_sql() == "DROP FUNCTION IDENTIFIER('func1');"


@pytest.mark.parametrize(
Expand All @@ -124,12 +161,16 @@ def test_function_get_execute_sql(

@pytest.mark.parametrize(
"mode",
[DeployMode.create, DeployMode.create_or_replace, DeployMode.create_if_not_exists],
[CreateMode.create, CreateMode.create_or_replace, CreateMode.create_if_not_exists],
)
def test_get_deploy_sql(mode, example_function_workspace, snapshot):
entity, _ = example_function_workspace
assert entity.get_deploy_sql(mode) == snapshot


def test_get_usage_grant_sql(example_function_workspace):
entity, _ = example_function_workspace
assert entity.get_usage_grant_sql("test_role") == "GRANT USAGE ON FUNCTION IDENTIFIER('func1') TO ROLE test_role"
assert (
entity.get_usage_grant_sql("test_role")
== "GRANT USAGE ON FUNCTION IDENTIFIER('func1') TO ROLE test_role;"
)
Loading

0 comments on commit 801b67d

Please sign in to comment.