Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jsikorski committed Dec 12, 2024
1 parent 3d1a0d5 commit 6ccfdc8
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 62 deletions.
65 changes: 53 additions & 12 deletions src/snowflake/cli/_plugins/streamlit/streamlit_entity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import shutil

from typing import Optional

from snowflake.cli._plugins.connection.util import make_snowsight_url
from snowflake.cli._plugins.streamlit.streamlit_entity_model import (
Expand All @@ -9,6 +10,8 @@
from snowflake.cli.api.entities.common import EntityBase, get_sql_executor
from snowflake.connector.cursor import SnowflakeCursor

from snowflake.cli.api.secure_path import SecurePath


class StreamlitEntity(EntityBase[StreamlitEntityModel]):
"""
Expand All @@ -31,6 +34,10 @@ def _sql_executor(self):
def _conn(self):
return self._sql_executor._conn # noqa

@property
def model(self):
return self._entity_model # noqa

def action_bundle(self, ctx: ActionContext, *args, **kwargs):
# get all files from the model
artifacts = self._entity_model.artifacts
Expand All @@ -45,10 +52,10 @@ def action_bundle(self, ctx: ActionContext, *args, **kwargs):
output_file = output_folder / file.name

if file.is_file():
shutil.copy(file, output_file)
SecurePath(file).copy(output_file)
elif file.is_dir():
output_file.mkdir(parents=True, exist_ok=True)
shutil.copytree(file, output_file, dirs_exist_ok=True)
SecurePath(file).copy(output_file, dirs_exist_ok=True)

output_files.append(output_file)

Expand All @@ -57,17 +64,17 @@ def action_bundle(self, ctx: ActionContext, *args, **kwargs):
def action_deploy(self, action_ctx: ActionContext, *args, **kwargs):
# After adding bundle map- we should use it's mapping here

query = self.action_get_deploy_sql(action_ctx, *args, **kwargs)
query = self.get_deploy_sql(action_ctx, *args, **kwargs)
result = self._sql_executor.execute_query(query)
return result

def action_drop(self, action_ctx: ActionContext, *args, **kwargs):
return self._sql_executor.execute_query(self.action_get_drop_sql(action_ctx))
return self._sql_executor.execute_query(self.get_drop_sql(action_ctx))

def action_execute(
self, action_ctx: ActionContext, *args, **kwargs
) -> SnowflakeCursor:
return self._sql_executor.execute_query(self.action_get_execute_sql(action_ctx))
return self._sql_executor.execute_query(self.get_execute_sql(action_ctx))

def action_get_url(
self, action_ctx: ActionContext, *args, **kwargs
Expand All @@ -77,21 +84,55 @@ def action_get_url(
self._conn, f"/#/streamlit-apps/{name.url_identifier}"
)

def action_get_deploy_sql(self, action_ctx: ActionContext, *args, **kwargs):
pass
def get_deploy_sql(self, action_ctx: ActionContext, if_not_exists: bool = False, replace: bool = False, from_stage_name: Optional[str] = None, *args, **kwargs):

if replace:
query = "CREATE OR REPLACE "
elif if_not_exists:
query = "CREATE IF NOT EXISTS "
else:
query = "CREATE "

query += f"STREAMLIT {self._entity_model.fqn.sql_identifier} \n"

if from_stage_name:
query += f" ROOT_LOCATION = '{from_stage_name}' \n"

query += f" MAIN_FILE = '{self._entity_model.main_file}' \n"

if self.model.imports:
query += self.model.get_imports_sql() + "\n"

if self.model.query_warehouse:
query += f" QUERY_WAREHOUSE = '{self.model.query_warehouse}' \n"

if self.model.title:
query += f" TITLE = '{self.model.title}' \n"

if self.model.comment:
query += f" COMMENT = '{self.model.comment}' \n"

if self.model.external_access_integrations:
query += self.model.get_external_access_integrations_sql() + "\n"

if self.model.secrets:
query += self.model.get_secrets_sql() + "\n"

return query


def action_share(
self, action_ctx: ActionContext, to_role: str, *args, **kwargs
) -> SnowflakeCursor:
return self._sql_executor.execute_query(self.get_share_sql(action_ctx, to_role))
return self._sql_executor.execute_query(self.get_usage_grant_sql(action_ctx, to_role))

def action_get_drop_sql(self, action_ctx: ActionContext, *args, **kwargs):
def get_drop_sql(self, action_ctx: ActionContext, *args, **kwargs):
return f"DROP STREAMLIT {self._entity_model.fqn}"

def action_get_execute_sql(self, action_ctx: ActionContext, *args, **kwargs):
def get_execute_sql(self, action_ctx: ActionContext, *args, **kwargs):
return f"EXECUTE STREAMLIT {self._entity_model.fqn}()"

def get_share_sql(
def get_usage_grant_sql(
self, action_ctx: ActionContext, to_role: str, *args, **kwargs
) -> str:
return f"GRANT USAGE ON STREAMLIT {{self._entity_model.fqn}} to role {to_role}"
45 changes: 5 additions & 40 deletions tests/streamlit/__snapshots__/test_commands.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -7,59 +7,24 @@
| For field entities.my_streamlit you provided '{'artifacts': {'1': |
| 'foo_bar.py'}}'. This caused: Unable to extract tag using discriminator |
| 'type' |
| For field entities.my_streamlit you provided '{'artifacts': {'1': |
| 'foo_bar.py'}}'. This caused: Unable to extract tag using discriminator |
| 'type' |
+------------------------------------------------------------------------------+

'''
# ---
# name: test_deploy_put_files_on_stage[example_streamlit-merge_definition1]
list([
"create stage if not exists IDENTIFIER('MockDatabase.MockSchema.streamlit_stage')",
'put file://streamlit_app.py @MockDatabase.MockSchema.streamlit_stage/test_streamlit auto_compress=false parallel=4 overwrite=True',
'put file://environment.yml @MockDatabase.MockSchema.streamlit_stage/test_streamlit auto_compress=false parallel=4 overwrite=True',
'put file://pages/* @MockDatabase.MockSchema.streamlit_stage/test_streamlit/pages auto_compress=false parallel=4 overwrite=True',
'''
CREATE STREAMLIT IDENTIFIER('MockDatabase.MockSchema.test_streamlit')
ROOT_LOCATION = '@MockDatabase.MockSchema.streamlit_stage/test_streamlit'
MAIN_FILE = 'streamlit_app.py'
QUERY_WAREHOUSE = test_warehouse
TITLE = 'My Fancy Streamlit'
''',
'select system$get_snowsight_host()',
'select current_account_name()',
])
# ---
# name: test_deploy_put_files_on_stage[example_streamlit_v2-merge_definition0]
list([
"create stage if not exists IDENTIFIER('MockDatabase.MockSchema.streamlit_stage')",
'put file://streamlit_app.py @MockDatabase.MockSchema.streamlit_stage/test_streamlit auto_compress=false parallel=4 overwrite=True',
'''
CREATE STREAMLIT IDENTIFIER('MockDatabase.MockSchema.test_streamlit')
ROOT_LOCATION = '@MockDatabase.MockSchema.streamlit_stage/test_streamlit'
MAIN_FILE = 'streamlit_app.py'
QUERY_WAREHOUSE = test_warehouse
TITLE = 'My Fancy Streamlit'
''',
'select system$get_snowsight_host()',
'select current_account_name()',
])
# ---
# name: test_deploy_streamlit_nonexisting_file[example_streamlit-opts0]
'''
+- Error ----------------------------------------------------------------------+
| Provided file foo.bar does not exist |
+------------------------------------------------------------------------------+

'''
# ---
# name: test_deploy_streamlit_nonexisting_file[example_streamlit-opts1]
'''
+- Error ----------------------------------------------------------------------+
| Provided file foo.bar does not exist |
+------------------------------------------------------------------------------+

'''
# ---
# name: test_deploy_streamlit_nonexisting_file[example_streamlit_v2-opts2]
Expand All @@ -68,7 +33,7 @@
| Streamlit test_streamlit already exist. If you want to replace it use |
| --replace flag. |
+------------------------------------------------------------------------------+

'''
# ---
# name: test_deploy_streamlit_nonexisting_file[example_streamlit_v2-opts3]
Expand All @@ -77,7 +42,7 @@
| During evaluation of DefinitionV20 in project definition following errors |
| were encountered: |
| For field entities.test_streamlit.streamlit you provided '{'artifacts': |
| ['foo.bar'], 'identifier': {'name': 'test_streamlit'}, 'main_file': |
| ['foo.bar'], 'identifier': 'test_streamlit', 'main_file': |
| 'streamlit_app.py', 'query_warehouse': 'test_warehouse', 'stage': |
| 'streamlit', 'title': 'My Fancy Streamlit', 'type': 'streamlit'}'. This |
| caused: Value error, Specified artifact foo.bar does not exist locally. |
Expand Down
30 changes: 23 additions & 7 deletions tests/streamlit/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from snowflake.cli._plugins.streamlit.streamlit_entity_model import StreamlitEntityModel
from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext

STREAMLIT_NAME = "test_streamlit"
CONNECTOR = "snowflake.connector.connect"
CONTEXT = ""
EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query"
Expand All @@ -19,7 +20,6 @@ def example_streamlit_workspace(project_directory):
with project_directory("example_streamlit_v2") as pdir:
with Path(pdir / "snowflake.yml").open() as definition_file:
definition = yaml.safe_load(definition_file)
print(definition)
model = StreamlitEntityModel(
**definition.get("entities", {}).get("test_streamlit")
)
Expand Down Expand Up @@ -50,13 +50,19 @@ def test_bundle(example_streamlit_workspace):
assert (output / "environment.yml").exists()
assert (output / "pages" / "my_page.py").exists()

@mock.patch(EXECUTE_QUERY)
def test_deploy(mock_execute, example_streamlit_workspace):
entity, action_ctx = example_streamlit_workspace
entity.action_deploy(action_ctx)

mock_execute.assert_called_with(f"CREATE STREAMLIT IDENTIFIER('{STREAMLIT_NAME}') \n MAIN_FILE = 'streamlit_app.py' \n QUERY_WAREHOUSE = 'test_warehouse' \n TITLE = 'My Fancy Streamlit' \n")

@mock.patch(EXECUTE_QUERY)
def test_drop(mock_execute, example_streamlit_workspace):
entity, action_ctx = example_streamlit_workspace
entity.action_drop(action_ctx)

mock_execute.assert_called_with("DROP STREAMLIT test_streamlit_deploy_snowcli")
mock_execute.assert_called_with(f"DROP STREAMLIT {STREAMLIT_NAME}")


@mock.patch(CONNECTOR)
Expand Down Expand Up @@ -97,18 +103,28 @@ def test_execute(mock_execute, example_streamlit_workspace):
entity, action_ctx = example_streamlit_workspace
entity.action_execute(action_ctx)

mock_execute.assert_called_with("EXECUTE STREAMLIT test_streamlit_deploy_snowcli()")
mock_execute.assert_called_with(f"EXECUTE STREAMLIT {STREAMLIT_NAME}()")


def test_get_execute_sql(example_streamlit_workspace):
entity, action_ctx = example_streamlit_workspace
execute_sql = entity.action_get_execute_sql(action_ctx)
execute_sql = entity.get_execute_sql(action_ctx)

assert execute_sql == "EXECUTE STREAMLIT test_streamlit_deploy_snowcli()"
assert execute_sql == f"EXECUTE STREAMLIT {STREAMLIT_NAME}()"


def test_get_drop_sql(example_streamlit_workspace):
entity, action_ctx = example_streamlit_workspace
drop_sql = entity.action_get_drop_sql(action_ctx)
drop_sql = entity.get_drop_sql(action_ctx)

assert drop_sql == f"DROP STREAMLIT {STREAMLIT_NAME}"

def test_get_deploy_sql(example_streamlit_workspace):
entity, action_ctx = example_streamlit_workspace
deploy_sql = entity.get_deploy_sql(action_ctx)

assert drop_sql == "DROP STREAMLIT test_streamlit_deploy_snowcli"
assert deploy_sql == f"""CREATE STREAMLIT IDENTIFIER('{STREAMLIT_NAME}')
MAIN_FILE = 'streamlit_app.py'
QUERY_WAREHOUSE = 'test_warehouse'
TITLE = 'My Fancy Streamlit'
"""
8 changes: 5 additions & 3 deletions tests/test_data/projects/example_streamlit_v2/snowflake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ definition_version: '2'
entities:
test_streamlit:
type: "streamlit"
identifier: test_streamlit_deploy_snowcli
identifier: test_streamlit
title: "My Fancy Streamlit"
stage: streamlit
query_warehouse: xsmall
query_warehouse: test_warehouse
main_file: streamlit_app.py
stage: streamlit
artifacts:
- streamlit_app.py
- utils/utils.py
- pages/
- environment.yml

0 comments on commit 6ccfdc8

Please sign in to comment.