diff --git a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py index 6def772525..7aaf10d6f0 100644 --- a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py +++ b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py @@ -1,7 +1,13 @@ +import functools +import shutil + +from snowflake.cli._plugins.connection.util import make_snowsight_url from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, ) -from snowflake.cli.api.entities.common import EntityBase +from snowflake.cli._plugins.workspace.context import ActionContext +from snowflake.cli.api.entities.common import EntityBase, get_sql_executor +from snowflake.connector.cursor import SnowflakeCursor class StreamlitEntity(EntityBase[StreamlitEntityModel]): @@ -9,4 +15,83 @@ class StreamlitEntity(EntityBase[StreamlitEntityModel]): A Streamlit app. """ - pass + @property + def root(self): + return self._workspace_ctx.project_root + + @property + def artifacts(self): + return self._entity_model.artifacts + + @functools.cached_property + def _sql_executor(self): + return get_sql_executor() + + @functools.cached_property + def _conn(self): + return self._sql_executor._conn # noqa + + def action_bundle(self, ctx: ActionContext, *args, **kwargs): + # get all files from the model + artifacts = self._entity_model.artifacts + # get root + output_folder = self.root / "output" / self._entity_model.stage + output_folder.mkdir(parents=True, exist_ok=True) + + output_files = [] + + # This is far from , but will be replaced by bundlemap mappings. + for file in artifacts: + output_file = output_folder / file.name + + if file.is_file(): + shutil.copy(file, output_file) + elif file.is_dir(): + output_file.mkdir(parents=True, exist_ok=True) + shutil.copytree(file, output_file, dirs_exist_ok=True) + + output_files.append(output_file) + + return output_files + + def action_deploy(self, action_ctx: ActionContext, *args, **kwargs): + # What about file copying? + + query = self.action_get_deploy_sql(action_ctx, *args, **kwargs) + result = self._sql_executor.execute_query(query) + return result + + def action_drop(self, action_ctx: ActionContext, *args, **kwargs): + return self._sql_executor.execute_query(self.action_get_drop_sql(action_ctx)) + + def action_execute( + self, action_ctx: ActionContext, *args, **kwargs + ) -> SnowflakeCursor: + return self._sql_executor.execute_query(self.action_get_execute_sql(action_ctx)) + + def action_get_url( + self, action_ctx: ActionContext, *args, **kwargs + ): # maybe this should be a property + name = self._entity_model.fqn.using_connection(self._conn) + return make_snowsight_url( + self._conn, f"/#/streamlit-apps/{name.url_identifier}" + ) + + def action_get_deploy_sql(self, action_ctx: ActionContext, *args, **kwargs): + pass + + def action_share( + self, action_ctx: ActionContext, to_role: str, *args, **kwargs + ) -> SnowflakeCursor: + return self._sql_executor.execute_query(self.get_share_sql(action_ctx, to_role)) + + def action_get_drop_sql(self, action_ctx: ActionContext, *args, **kwargs): + return f"DROP STREAMLIT {self._entity_model.fqn}" + + def action_get_execute_sql(self, action_ctx: ActionContext, *args, **kwargs): + return f"EXECUTE STREAMLIT {self._entity_model.fqn}()" + + def get_share_sql( + self, action_ctx: ActionContext, to_role: str, *args, **kwargs + ) -> str: + return f"GRANT USAGE ON STREAMLIT {{self._entity_model.fqn}} to role {to_role}" diff --git a/tests/streamlit/test_actions.py b/tests/streamlit/test_actions.py new file mode 100644 index 0000000000..d0d234f394 --- /dev/null +++ b/tests/streamlit/test_actions.py @@ -0,0 +1,87 @@ +from pathlib import Path +from unittest import mock + +import pytest +import yaml +from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity +from snowflake.cli._plugins.streamlit.streamlit_entity_model import StreamlitEntityModel +from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext + +CONNECTOR = "snowflake.connector.connect" +CONTEXT = "" +EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query" + + +@pytest.fixture +def example_streamlit_workspace(project_directory): + with project_directory("example_streamlit_v2") as pdir: + with Path(pdir / "snowflake.yml").open() as definition_file: + definition = yaml.safe_load(definition_file) + model = StreamlitEntityModel( + **definition.get("entities", {}).get("my_streamlit") + ) + + workspace_context = WorkspaceContext( + console=mock.MagicMock(), + project_root=pdir, + get_default_role=lambda: "test_role", + get_default_warehouse=lambda: "test_warehouse", + ) + + return ( + StreamlitEntity(workspace_ctx=workspace_context, entity_model=model), + ActionContext( + get_entity=lambda *args: None, + ), + ) + + +def test_bundle(example_streamlit_workspace): + + entity, action_ctx = example_streamlit_workspace + entity.action_bundle(action_ctx) + + output = entity.root / "output" / entity._entity_model.stage # noqa + assert output.exists() + assert (output / "streamlit_app.py").exists() + assert (output / "environment.yml").exists() + assert (output / "pages" / "my_page.py").exists() + + +@mock.patch(EXECUTE_QUERY) +def test_drop(mock_execute, example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_drop(action_ctx) + + mock_execute.assert_called_with("DROP STREAMLIT test_streamlit_deploy_snowcli") + + +@mock.patch(CONNECTOR) +def test_get_url(mock_connect, example_streamlit_workspace, mock_ctx): + entity, action_ctx = example_streamlit_workspace + cli_c + result = entity.action_get_url(action_ctx) + + mock_connect.assert_called() + + +@mock.patch(EXECUTE_QUERY) +def test_execute(mock_execute, example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + entity.action_execute(action_ctx) + + mock_execute.assert_called_with("EXECUTE STREAMLIT test_streamlit_deploy_snowcli()") + + +def test_get_execute_sql(example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + execute_sql = entity.action_get_execute_sql() + + assert execute_sql == "EXECUTE STREAMLIT test_streamlit_deploy_snowcli()" + + +def test_get_drop_sql(example_streamlit_workspace): + entity, action_ctx = example_streamlit_workspace + drop_sql = entity.action_get_drop_sql() + + assert drop_sql == "DROP STREAMLIT test_streamlit_deploy_snowcli" diff --git a/tests/test_data/projects/example_streamlit_v2/utils/utils.py b/tests/test_data/projects/example_streamlit_v2/utils/utils.py new file mode 100644 index 0000000000..e69de29bb2