diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index 8f76c7ac4c..894e500a4c 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -833,6 +833,7 @@ def _bundle_children(self, action_ctx: ActionContext) -> List[str]: child_entity.get_deploy_sql( artifacts_dir=child_artifacts_dir.relative_to(self.deploy_root), schema=child_schema, + replace=True, ) ) if app_role: diff --git a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py index dc7e93bf51..498c430c2c 100644 --- a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py +++ b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py @@ -18,7 +18,9 @@ @unique -class FeatureFlag(FeatureFlagMixin): +class FeatureFlag( + FeatureFlagMixin +): # TODO move this to snowflake.cli.api.feature_flags ENABLE_NATIVE_APP_PYTHON_SETUP = BooleanFlag( "ENABLE_NATIVE_APP_PYTHON_SETUP", False ) diff --git a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py index 3a66cbfff2..81840cd17f 100644 --- a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py +++ b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py @@ -1,7 +1,10 @@ import functools +from pathlib import Path from typing import Optional +from click import ClickException from snowflake.cli._plugins.connection.util import make_snowsight_url +from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, ) @@ -16,6 +19,11 @@ class StreamlitEntity(EntityBase[StreamlitEntityModel]): A Streamlit app. """ + def __init__(self, *args, **kwargs): + if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled(): + raise NotImplementedError("Streamlit entity is not implemented yet") + super().__init__(*args, **kwargs) + @property def root(self): return self._workspace_ctx.project_root @@ -36,43 +44,23 @@ def _conn(self): def model(self): return self._entity_model # noqa - def action_bundle(self, ctx: ActionContext, *args, **kwargs): - # get all files from the model - artifacts = self._entity_model.artifacts - # get root - output_folder = self.root / "output" / self._entity_model.stage - output_folder.mkdir(parents=True, exist_ok=True) - - output_files = [] - - # This is far from , but will be replaced by bundlemap mappings. - for file in artifacts: - output_file = output_folder / file.name - - if file.is_file(): - SecurePath(file).copy(output_file) - elif file.is_dir(): - output_file.mkdir(parents=True, exist_ok=True) - SecurePath(file).copy(output_file, dirs_exist_ok=True) - - output_files.append(output_file) - - return output_files + def action_bundle(self, action_ctx: ActionContext, *args, **kwargs): + return self.bundle() def action_deploy(self, action_ctx: ActionContext, *args, **kwargs): # After adding bundle map- we should use it's mapping here - query = self.get_deploy_sql(action_ctx, *args, **kwargs) + query = self.get_deploy_sql() result = self._sql_executor.execute_query(query) return result def action_drop(self, action_ctx: ActionContext, *args, **kwargs): - return self._sql_executor.execute_query(self.get_drop_sql(action_ctx)) + return self._sql_executor.execute_query(self.get_drop_sql()) def action_execute( self, action_ctx: ActionContext, *args, **kwargs ) -> SnowflakeCursor: - return self._sql_executor.execute_query(self.get_execute_sql(action_ctx)) + return self._sql_executor.execute_query(self.get_execute_sql()) def action_get_url( self, action_ctx: ActionContext, *args, **kwargs @@ -82,15 +70,46 @@ def action_get_url( self._conn, f"/#/streamlit-apps/{name.url_identifier}" ) + def bundle(self, output_dir: Optional[Path] = None): + + if not output_dir: + output_dir = self.root / "output" / self._entity_model.stage + + artifacts = self._entity_model.artifacts + + output_dir.mkdir(parents=True, exist_ok=True) # type: ignore + + output_files = [] + + # This is far from , but will be replaced by bundlemap mappings. + for file in artifacts: + output_file = output_dir / file.name + + if file.is_file(): + SecurePath(file).copy(output_file) + elif file.is_dir(): + output_file.mkdir(parents=True, exist_ok=True) + SecurePath(file).copy(output_file, dirs_exist_ok=True) + + output_files.append(output_file) + + return output_files + + def action_share( + self, action_ctx: ActionContext, to_role: str, *args, **kwargs + ) -> SnowflakeCursor: + return self._sql_executor.execute_query(self.get_share_sql(to_role)) + def get_deploy_sql( self, - action_ctx: ActionContext, if_not_exists: bool = False, replace: bool = False, from_stage_name: Optional[str] = None, *args, **kwargs, ): + if replace and if_not_exists: + raise ClickException("Cannot specify both replace and if_not_exists") if replace: query = "CREATE OR REPLACE " @@ -124,22 +143,20 @@ def get_deploy_sql( if self.model.secrets: query += self.model.get_secrets_sql() + "\n" - return query + return query + ";" - def action_share( - self, action_ctx: ActionContext, to_role: str, *args, **kwargs - ) -> SnowflakeCursor: - return self._sql_executor.execute_query( - self.get_usage_grant_sql(action_ctx, to_role) - ) + def get_drop_sql(self): + return f"DROP STREAMLIT {self._entity_model.fqn};" - def get_drop_sql(self, action_ctx: ActionContext, *args, **kwargs): - return f"DROP STREAMLIT {self._entity_model.fqn}" + def get_execute_sql(self): + return f"EXECUTE STREAMLIT {self._entity_model.fqn}();" - def get_execute_sql(self, action_ctx: ActionContext, *args, **kwargs): - return f"EXECUTE STREAMLIT {self._entity_model.fqn}()" + def get_share_sql(self, to_role: str) -> str: + return f"grant usage on streamlit {self.model.fqn.sql_identifier} to role {to_role};" - def get_usage_grant_sql( - self, action_ctx: ActionContext, to_role: str, *args, **kwargs - ) -> str: - return f"GRANT USAGE ON STREAMLIT {{self._entity_model.fqn}} to role {to_role}" + def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None) -> str: + entity_id = self.entity_id + streamlit_name = f"{schema}.{entity_id}" if schema else entity_id + return ( + f"GRANT USAGE ON STREAMLIT {streamlit_name} TO APPLICATION ROLE {app_role};" + ) diff --git a/tests/streamlit/test_actions.py b/tests/streamlit/test_actions.py deleted file mode 100644 index 3a170ee187..0000000000 --- a/tests/streamlit/test_actions.py +++ /dev/null @@ -1,138 +0,0 @@ -from pathlib import Path -from unittest import mock - -import pytest -import yaml -from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity -from snowflake.cli._plugins.streamlit.streamlit_entity_model import StreamlitEntityModel -from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext - -STREAMLIT_NAME = "test_streamlit" -CONNECTOR = "snowflake.connector.connect" -CONTEXT = "" -EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query" - -GET_UI_PARAMETERS = "snowflake.cli._plugins.connection.util.get_ui_parameters" - - -@pytest.fixture -def example_streamlit_workspace(project_directory): - with project_directory("example_streamlit_v2") as pdir: - with Path(pdir / "snowflake.yml").open() as definition_file: - definition = yaml.safe_load(definition_file) - model = StreamlitEntityModel( - **definition.get("entities", {}).get("test_streamlit") - ) - - workspace_context = WorkspaceContext( - console=mock.MagicMock(), - project_root=pdir, - get_default_role=lambda: "test_role", - get_default_warehouse=lambda: "test_warehouse", - ) - - return ( - StreamlitEntity(workspace_ctx=workspace_context, entity_model=model), - ActionContext( - get_entity=lambda *args: None, - ), - ) - - -def test_bundle(example_streamlit_workspace): - - entity, action_ctx = example_streamlit_workspace - entity.action_bundle(action_ctx) - - output = entity.root / "output" / entity._entity_model.stage # noqa - assert output.exists() - assert (output / "streamlit_app.py").exists() - assert (output / "environment.yml").exists() - assert (output / "pages" / "my_page.py").exists() - - -@mock.patch(EXECUTE_QUERY) -def test_deploy(mock_execute, example_streamlit_workspace): - entity, action_ctx = example_streamlit_workspace - entity.action_deploy(action_ctx) - - mock_execute.assert_called_with( - f"CREATE STREAMLIT IDENTIFIER('{STREAMLIT_NAME}') \n MAIN_FILE = 'streamlit_app.py' \n QUERY_WAREHOUSE = 'test_warehouse' \n TITLE = 'My Fancy Streamlit' \n" - ) - - -@mock.patch(EXECUTE_QUERY) -def test_drop(mock_execute, example_streamlit_workspace): - entity, action_ctx = example_streamlit_workspace - entity.action_drop(action_ctx) - - mock_execute.assert_called_with(f"DROP STREAMLIT {STREAMLIT_NAME}") - - -@mock.patch(CONNECTOR) -@mock.patch( - GET_UI_PARAMETERS, - return_value={"UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT": "false"}, -) -@mock.patch("click.get_current_context") -def test_get_url( - mock_get_ctx, - mock_param, - mock_connect, - mock_cursor, - example_streamlit_workspace, - mock_ctx, -): - ctx = mock_ctx( - mock_cursor( - rows=[ - {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"SYSTEM$RETURN_CURRENT_ORG_NAME()": "FOOBARBAZ"}, - {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, - ], - columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], - ) - ) - mock_connect.return_value = ctx - mock_get_ctx.return_value = ctx - - entity, action_ctx = example_streamlit_workspace - result = entity.action_get_url(action_ctx) - - mock_connect.assert_called() - - -@mock.patch(EXECUTE_QUERY) -def test_execute(mock_execute, example_streamlit_workspace): - entity, action_ctx = example_streamlit_workspace - entity.action_execute(action_ctx) - - mock_execute.assert_called_with(f"EXECUTE STREAMLIT {STREAMLIT_NAME}()") - - -def test_get_execute_sql(example_streamlit_workspace): - entity, action_ctx = example_streamlit_workspace - execute_sql = entity.get_execute_sql(action_ctx) - - assert execute_sql == f"EXECUTE STREAMLIT {STREAMLIT_NAME}()" - - -def test_get_drop_sql(example_streamlit_workspace): - entity, action_ctx = example_streamlit_workspace - drop_sql = entity.get_drop_sql(action_ctx) - - assert drop_sql == f"DROP STREAMLIT {STREAMLIT_NAME}" - - -def test_get_deploy_sql(example_streamlit_workspace): - entity, action_ctx = example_streamlit_workspace - deploy_sql = entity.get_deploy_sql(action_ctx) - - assert ( - deploy_sql - == f"""CREATE STREAMLIT IDENTIFIER('{STREAMLIT_NAME}') - MAIN_FILE = 'streamlit_app.py' - QUERY_WAREHOUSE = 'test_warehouse' - TITLE = 'My Fancy Streamlit' -""" - ) diff --git a/tests/streamlit/test_streamlit_entity.py b/tests/streamlit/test_streamlit_entity.py index 315e34b8e5..dbbe3d6c22 100644 --- a/tests/streamlit/test_streamlit_entity.py +++ b/tests/streamlit/test_streamlit_entity.py @@ -1,20 +1,54 @@ from __future__ import annotations from pathlib import Path +from unittest import mock import pytest +import yaml from snowflake.cli._plugins.streamlit.streamlit_entity import ( StreamlitEntity, ) from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, ) -from snowflake.cli._plugins.workspace.context import WorkspaceContext -from snowflake.cli.api.console import cli_console as cc -from snowflake.cli.api.project.definition_manager import DefinitionManager +from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext from tests.testing_utils.mock_config import mock_config_key +STREAMLIT_NAME = "test_streamlit" +CONNECTOR = "snowflake.connector.connect" +CONTEXT = "" +EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query" + +GET_UI_PARAMETERS = "snowflake.cli._plugins.connection.util.get_ui_parameters" + + +@pytest.fixture +def example_streamlit_workspace(project_directory): + with mock_config_key("enable_native_app_children", True): + with project_directory("example_streamlit_v2") as pdir: + with Path(pdir / "snowflake.yml").open() as definition_file: + definition = yaml.safe_load(definition_file) + model = StreamlitEntityModel( + **definition.get("entities", {}).get("test_streamlit") + ) + + workspace_context = WorkspaceContext( + console=mock.MagicMock(), + project_root=pdir, + get_default_role=lambda: "test_role", + get_default_warehouse=lambda: "test_warehouse", + ) + + return ( + StreamlitEntity( + workspace_ctx=workspace_context, entity_model=model + ), + ActionContext( + get_entity=lambda *args: None, + ), + ) + def test_cannot_instantiate_without_feature_flag(): with pytest.raises(NotImplementedError) as err: @@ -22,32 +56,135 @@ def test_cannot_instantiate_without_feature_flag(): assert str(err.value) == "Streamlit entity is not implemented yet" -def test_nativeapp_children_interface(temp_dir): +def test_nativeapp_children_interface(example_streamlit_workspace, snapshot): + sl, action_context = example_streamlit_workspace + + sl.bundle() + bundle_artifact = sl.root / "output" / sl.model.stage / "streamlit_app.py" + deploy_sql_str = sl.get_deploy_sql() + grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") + + assert bundle_artifact.exists() + assert deploy_sql_str == snapshot + assert ( + grant_sql_str == f"GRANT USAGE ON STREAMLIT None TO APPLICATION ROLE app_role;" + ) + + +def test_bundle(example_streamlit_workspace): with mock_config_key("enable_native_app_children", True): - dm = DefinitionManager() - ctx = WorkspaceContext( - console=cc, - project_root=dm.project_root, - get_default_role=lambda: "mock_role", - get_default_warehouse=lambda: "mock_warehouse", - ) - main_file = "main.py" - (Path(temp_dir) / main_file).touch() - model = StreamlitEntityModel( - type="streamlit", - main_file=main_file, - artifacts=[main_file], - ) - sl = StreamlitEntity(model, ctx) - - sl.bundle() - bundle_artifact = Path(temp_dir) / "output" / "deploy" / main_file - deploy_sql_str = sl.get_deploy_sql() - grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") - - assert bundle_artifact.exists() - assert deploy_sql_str == "CREATE OR REPLACE STREAMLIT None MAIN_FILE='main.py';" - assert ( - grant_sql_str - == "GRANT USAGE ON STREAMLIT None TO APPLICATION ROLE app_role;" + entity, action_ctx = example_streamlit_workspace + entity.action_bundle(action_ctx) + + output = entity.root / "output" / entity._entity_model.stage # noqa + assert output.exists() + assert (output / "streamlit_app.py").exists() + assert (output / "environment.yml").exists() + assert (output / "pages" / "my_page.py").exists() + + +@mock.patch(EXECUTE_QUERY) +def test_deploy(mock_execute, example_streamlit_workspace): + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = example_streamlit_workspace + entity.action_deploy(action_ctx) + + mock_execute.assert_called_with( + f"CREATE STREAMLIT IDENTIFIER('{STREAMLIT_NAME}') \n MAIN_FILE = 'streamlit_app.py' \n QUERY_WAREHOUSE = 'test_warehouse' \n TITLE = 'My Fancy Streamlit' \n;" + ) + + +@mock.patch(EXECUTE_QUERY) +def test_drop(mock_execute, example_streamlit_workspace): + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = example_streamlit_workspace + entity.action_drop(action_ctx) + + mock_execute.assert_called_with(f"DROP STREAMLIT {STREAMLIT_NAME};") + + +@mock.patch(CONNECTOR) +@mock.patch( + GET_UI_PARAMETERS, + return_value={"UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT": "false"}, +) +@mock.patch("click.get_current_context") +def test_get_url( + mock_get_ctx, + mock_param, + mock_connect, + mock_cursor, + example_streamlit_workspace, + mock_ctx, +): + ctx = mock_ctx( + mock_cursor( + rows=[ + {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, + {"SYSTEM$RETURN_CURRENT_ORG_NAME()": "FOOBARBAZ"}, + {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, + ], + columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], ) + ) + mock_connect.return_value = ctx + mock_get_ctx.return_value = ctx + + entity, action_ctx = example_streamlit_workspace + with mock_config_key("enable_native_app_children", True): + result = entity.action_get_url(action_ctx) + + mock_connect.assert_called() + + +@mock.patch(EXECUTE_QUERY) +def test_share(mock_connect, example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_share(action_ctx, to_role="test_role") + + mock_connect.assert_called_with( + "grant usage on streamlit IDENTIFIER('test_streamlit') to role test_role;" + ) + + +@mock.patch(EXECUTE_QUERY) +def test_execute(mock_execute, example_streamlit_workspace): + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = example_streamlit_workspace + entity.action_execute(action_ctx) + + mock_execute.assert_called_with(f"EXECUTE STREAMLIT {STREAMLIT_NAME}();") + + +def test_get_execute_sql(example_streamlit_workspace): + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = example_streamlit_workspace + execute_sql = entity.get_execute_sql() + + assert execute_sql == f"EXECUTE STREAMLIT {STREAMLIT_NAME}();" + + +def test_get_drop_sql(example_streamlit_workspace): + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = example_streamlit_workspace + drop_sql = entity.get_drop_sql() + + assert drop_sql == f"DROP STREAMLIT {STREAMLIT_NAME};" + + +@pytest.mark.parametrize( + "kwargs", + [ + {"replace": True}, + {"if_not_exists": True}, + {"from_stage_name": "test_stage"}, + {"from_stage_name": "test_stage", "replace": True}, + {"from_stage_name": "test_stage", "if_not_exists": True}, + ], +) +def test_get_deploy_sql(kwargs, example_streamlit_workspace, snapshot): + with mock_config_key("enable_native_app_children", True): + entity, action_ctx = example_streamlit_workspace + deploy_sql = entity.get_deploy_sql(**kwargs) + + assert deploy_sql == snapshot