-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4e59e0d
commit cb798a6
Showing
3 changed files
with
174 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,97 @@ | ||
import functools | ||
import shutil | ||
|
||
from snowflake.cli._plugins.connection.util import make_snowsight_url | ||
from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( | ||
StreamlitEntityModel, | ||
) | ||
from snowflake.cli.api.entities.common import EntityBase | ||
from snowflake.cli._plugins.workspace.context import ActionContext | ||
from snowflake.cli.api.entities.common import EntityBase, get_sql_executor | ||
from snowflake.connector.cursor import SnowflakeCursor | ||
|
||
|
||
class StreamlitEntity(EntityBase[StreamlitEntityModel]): | ||
""" | ||
A Streamlit app. | ||
""" | ||
|
||
pass | ||
@property | ||
def root(self): | ||
return self._workspace_ctx.project_root | ||
|
||
@property | ||
def artifacts(self): | ||
return self._entity_model.artifacts | ||
|
||
@functools.cached_property | ||
def _sql_executor(self): | ||
return get_sql_executor() | ||
|
||
@functools.cached_property | ||
def _conn(self): | ||
return self._sql_executor._conn # noqa | ||
|
||
def action_bundle(self, ctx: ActionContext, *args, **kwargs): | ||
# get all files from the model | ||
artifacts = self._entity_model.artifacts | ||
# get root | ||
output_folder = self.root / "output" / self._entity_model.stage | ||
output_folder.mkdir(parents=True, exist_ok=True) | ||
|
||
output_files = [] | ||
|
||
# This is far from , but will be replaced by bundlemap mappings. | ||
for file in artifacts: | ||
output_file = output_folder / file.name | ||
|
||
if file.is_file(): | ||
shutil.copy(file, output_file) | ||
elif file.is_dir(): | ||
output_file.mkdir(parents=True, exist_ok=True) | ||
shutil.copytree(file, output_file, dirs_exist_ok=True) | ||
|
||
output_files.append(output_file) | ||
|
||
return output_files | ||
|
||
def action_deploy(self, action_ctx: ActionContext, *args, **kwargs): | ||
# What about file copying? | ||
|
||
query = self.action_get_deploy_sql(action_ctx, *args, **kwargs) | ||
result = self._sql_executor.execute_query(query) | ||
return result | ||
|
||
def action_drop(self, action_ctx: ActionContext, *args, **kwargs): | ||
return self._sql_executor.execute_query(self.action_get_drop_sql(action_ctx)) | ||
|
||
def action_execute( | ||
self, action_ctx: ActionContext, *args, **kwargs | ||
) -> SnowflakeCursor: | ||
return self._sql_executor.execute_query(self.action_get_execute_sql(action_ctx)) | ||
|
||
def action_get_url( | ||
self, action_ctx: ActionContext, *args, **kwargs | ||
): # maybe this should be a property | ||
name = self._entity_model.fqn.using_connection(self._conn) | ||
return make_snowsight_url( | ||
self._conn, f"/#/streamlit-apps/{name.url_identifier}" | ||
) | ||
|
||
def action_get_deploy_sql(self, action_ctx: ActionContext, *args, **kwargs): | ||
pass | ||
|
||
def action_share( | ||
self, action_ctx: ActionContext, to_role: str, *args, **kwargs | ||
) -> SnowflakeCursor: | ||
return self._sql_executor.execute_query(self.get_share_sql(action_ctx, to_role)) | ||
|
||
def action_get_drop_sql(self, action_ctx: ActionContext, *args, **kwargs): | ||
return f"DROP STREAMLIT {self._entity_model.fqn}" | ||
|
||
def action_get_execute_sql(self, action_ctx: ActionContext, *args, **kwargs): | ||
return f"EXECUTE STREAMLIT {self._entity_model.fqn}()" | ||
|
||
def get_share_sql( | ||
self, action_ctx: ActionContext, to_role: str, *args, **kwargs | ||
) -> str: | ||
return f"GRANT USAGE ON STREAMLIT {{self._entity_model.fqn}} to role {to_role}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from pathlib import Path | ||
from unittest import mock | ||
|
||
import pytest | ||
import yaml | ||
from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity | ||
from snowflake.cli._plugins.streamlit.streamlit_entity_model import StreamlitEntityModel | ||
from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext | ||
|
||
CONNECTOR = "snowflake.connector.connect" | ||
CONTEXT = "" | ||
EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query" | ||
|
||
|
||
@pytest.fixture | ||
def example_streamlit_workspace(project_directory): | ||
with project_directory("example_streamlit_v2") as pdir: | ||
with Path(pdir / "snowflake.yml").open() as definition_file: | ||
definition = yaml.safe_load(definition_file) | ||
model = StreamlitEntityModel( | ||
**definition.get("entities", {}).get("my_streamlit") | ||
) | ||
|
||
workspace_context = WorkspaceContext( | ||
console=mock.MagicMock(), | ||
project_root=pdir, | ||
get_default_role=lambda: "test_role", | ||
get_default_warehouse=lambda: "test_warehouse", | ||
) | ||
|
||
return ( | ||
StreamlitEntity(workspace_ctx=workspace_context, entity_model=model), | ||
ActionContext( | ||
get_entity=lambda *args: None, | ||
), | ||
) | ||
|
||
|
||
def test_bundle(example_streamlit_workspace): | ||
|
||
entity, action_ctx = example_streamlit_workspace | ||
entity.action_bundle(action_ctx) | ||
|
||
output = entity.root / "output" / entity._entity_model.stage # noqa | ||
assert output.exists() | ||
assert (output / "streamlit_app.py").exists() | ||
assert (output / "environment.yml").exists() | ||
assert (output / "pages" / "my_page.py").exists() | ||
|
||
|
||
@mock.patch(EXECUTE_QUERY) | ||
def test_drop(mock_execute, example_streamlit_workspace): | ||
entity, action_ctx = example_streamlit_workspace | ||
entity.action_drop(action_ctx) | ||
|
||
mock_execute.assert_called_with("DROP STREAMLIT test_streamlit_deploy_snowcli") | ||
|
||
|
||
@mock.patch(CONNECTOR) | ||
def test_get_url(mock_connect, example_streamlit_workspace, mock_ctx): | ||
entity, action_ctx = example_streamlit_workspace | ||
cli_c | ||
result = entity.action_get_url(action_ctx) | ||
|
||
mock_connect.assert_called() | ||
|
||
|
||
@mock.patch(EXECUTE_QUERY) | ||
def test_execute(mock_execute, example_streamlit_workspace): | ||
entity, action_ctx = example_streamlit_workspace | ||
entity.action_execute(action_ctx) | ||
|
||
mock_execute.assert_called_with("EXECUTE STREAMLIT test_streamlit_deploy_snowcli()") | ||
|
||
|
||
def test_get_execute_sql(example_streamlit_workspace): | ||
entity, action_ctx = example_streamlit_workspace | ||
execute_sql = entity.action_get_execute_sql() | ||
|
||
assert execute_sql == "EXECUTE STREAMLIT test_streamlit_deploy_snowcli()" | ||
|
||
|
||
def test_get_drop_sql(example_streamlit_workspace): | ||
entity, action_ctx = example_streamlit_workspace | ||
drop_sql = entity.action_get_drop_sql() | ||
|
||
assert drop_sql == "DROP STREAMLIT test_streamlit_deploy_snowcli" |
Empty file.