Skip to content

Commit

Permalink
Refactored script execution logic and integ test
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-melnacouzi committed Jul 18, 2024
1 parent b0d718c commit a080e15
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 159 deletions.
12 changes: 6 additions & 6 deletions src/snowflake/cli/plugins/nativeapp/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,18 @@ def __init__(self, item: str, expected_owner: str, actual_owner: str):
)


class MissingPackageScriptError(ClickException):
"""A referenced package script was not found."""
class MissingScriptError(ClickException):
"""A referenced script was not found."""

def __init__(self, relpath: str):
super().__init__(f'Package script "{relpath}" does not exist')
super().__init__(f'Script "{relpath}" does not exist')


class InvalidPackageScriptError(ClickException):
"""A referenced package script had syntax error(s)."""
class InvalidScriptError(ClickException):
"""A referenced script had syntax error(s)."""

def __init__(self, relpath: str, err: jinja2.TemplateError):
super().__init__(f'Package script "{relpath}" is not a valid jinja2 template')
super().__init__(f'Script "{relpath}" does not contain a valid template')
self.err = err


Expand Down
43 changes: 26 additions & 17 deletions src/snowflake/cli/plugins/nativeapp/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
from snowflake.cli.plugins.nativeapp.exceptions import (
ApplicationPackageAlreadyExistsError,
ApplicationPackageDoesNotExistError,
InvalidPackageScriptError,
MissingPackageScriptError,
InvalidScriptError,
MissingScriptError,
SetupScriptFailedValidation,
UnexpectedOwnerError,
)
Expand Down Expand Up @@ -561,6 +561,27 @@ def create_app_package(self) -> None:
)
)

def _expand_script_templates(
self, env: jinja2.Environment, jinja_context, scripts: List[str]
):
queued_queries = []
for relpath in scripts:
try:
template = env.get_template(relpath)
result = template.render(**jinja_context)
queued_queries.append(result)

except jinja2.TemplateNotFound as e:
raise MissingScriptError(e.name) from e

except jinja2.TemplateSyntaxError as e:
raise InvalidScriptError(e.name, e) from e

except jinja2.UndefinedError as e:
raise InvalidScriptError(relpath, e) from e

return queued_queries

def _apply_package_scripts(self) -> None:
"""
Assuming the application package exists and we are using the correct role,
Expand All @@ -572,21 +593,9 @@ def _apply_package_scripts(self) -> None:
undefined=jinja2.StrictUndefined,
)

queued_queries = []
for relpath in self.package_scripts:
try:
template = env.get_template(relpath)
result = template.render(dict(package_name=self.package_name))
queued_queries.append(result)

except jinja2.TemplateNotFound as e:
raise MissingPackageScriptError(e.name)

except jinja2.TemplateSyntaxError as e:
raise InvalidPackageScriptError(e.name, e)

except jinja2.UndefinedError as e:
raise InvalidPackageScriptError(relpath, e)
queued_queries = self._expand_script_templates(
env, dict(package_name=self.package_name), self.package_scripts
)

# once we're sure all the templates expanded correctly, execute all of them
with self.use_package_warehouse():
Expand Down
39 changes: 23 additions & 16 deletions src/snowflake/cli/plugins/nativeapp/run_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,41 +139,48 @@ class NativeAppRunProcessor(NativeAppManager, NativeAppCommandProcessor):
def __init__(self, project_definition: NativeApp, project_root: Path):
super().__init__(project_definition, project_root)

def _execute_sql_script(self, sql_script_path):
def _execute_sql_script(
self, script_content: str, database_name: Optional[str] = None
):
"""
Executing the SQL script in the provided file path after expanding template variables.
Executing the provided SQL script content.
This assumes that a relevant warehouse is already active.
Consequently, "use database" will be executed first if it is set in definition file or in the current connection.
If database_name is passed in, it will be used first.
"""
env = get_sql_cli_jinja_env(
loader=jinja2.loaders.FileSystemLoader(self.project_root)
)

try:
if self._conn.database:
self._execute_query(f"use database {self._conn.database}")
if database_name:
self._execute_query(f"use database {database_name}")

context_data = cli_context.template_context
sql_script_template = env.get_template(sql_script_path)
sql_script = sql_script_template.render(**context_data)

self._execute_queries(sql_script)
self._execute_queries(script_content)
except ProgrammingError as err:
generic_sql_error_handler(err)

def _execute_post_deploy_hooks(self):
post_deploy_script_hooks = self.app_post_deploy_hooks
if post_deploy_script_hooks:
with cc.phase("Executing application post-deploy actions"):
sql_scripts_paths = []
for hook in post_deploy_script_hooks:
if hook.sql_script:
cc.step(f"Executing SQL script: {hook.sql_script}")
self._execute_sql_script(hook.sql_script)
sql_scripts_paths.append(hook.sql_script)
else:
raise ValueError(
f"Unsupported application post-deploy hook type: {hook}"
)

env = get_sql_cli_jinja_env(
loader=jinja2.loaders.FileSystemLoader(self.project_root)
)
scripts_content_list = self._expand_script_templates(
env, cli_context.template_context, sql_scripts_paths
)

for index, sql_script_path in enumerate(sql_scripts_paths):
cc.step(f"Executing SQL script: {sql_script_path}")
self._execute_sql_script(
scripts_content_list[index], self._conn.database
)

def get_all_existing_versions(self) -> SnowflakeCursor:
"""
Get all existing versions, if defined, for an application package.
Expand Down
10 changes: 5 additions & 5 deletions tests/nativeapp/test_package_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
)
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.cli.plugins.nativeapp.exceptions import (
InvalidPackageScriptError,
MissingPackageScriptError,
InvalidScriptError,
MissingScriptError,
)
from snowflake.cli.plugins.nativeapp.run_processor import NativeAppRunProcessor
from snowflake.connector import ProgrammingError
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_package_scripts_without_conn_info_succeeds(
def test_missing_package_script(mock_execute, project_definition_files):
working_dir: Path = project_definition_files[0].parent
native_app_manager = _get_na_manager(str(working_dir))
with pytest.raises(MissingPackageScriptError):
with pytest.raises(MissingScriptError):
(working_dir / "002-shared.sql").unlink()
native_app_manager._apply_package_scripts() # noqa: SLF001

Expand All @@ -211,7 +211,7 @@ def test_missing_package_script(mock_execute, project_definition_files):
def test_invalid_package_script(mock_execute, project_definition_files):
working_dir: Path = project_definition_files[0].parent
native_app_manager = _get_na_manager(str(working_dir))
with pytest.raises(InvalidPackageScriptError):
with pytest.raises(InvalidScriptError):
second_file = working_dir / "002-shared.sql"
second_file.unlink()
second_file.write_text("select * from {{ package_name")
Expand All @@ -226,7 +226,7 @@ def test_invalid_package_script(mock_execute, project_definition_files):
def test_undefined_var_package_script(mock_execute, project_definition_files):
working_dir: Path = project_definition_files[0].parent
native_app_manager = _get_na_manager(str(working_dir))
with pytest.raises(InvalidPackageScriptError):
with pytest.raises(InvalidScriptError):
second_file = working_dir / "001-shared.sql"
second_file.unlink()
second_file.write_text("select * from {{ abc }}")
Expand Down
3 changes: 2 additions & 1 deletion tests/nativeapp/test_post_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from snowflake.cli.api.project.schemas.native_app.application import (
ApplicationPostDeployHook,
)
from snowflake.cli.plugins.nativeapp.exceptions import MissingScriptError
from snowflake.cli.plugins.nativeapp.run_processor import NativeAppRunProcessor

from tests.nativeapp.patch_utils import mock_connection
Expand Down Expand Up @@ -139,7 +140,7 @@ def test_missing_sql_script(
with project_directory("napp_post_deploy_missing_file") as project_dir:
processor = _get_run_processor(str(project_dir))

with pytest.raises(FileNotFoundError) as err:
with pytest.raises(MissingScriptError) as err:
processor._execute_post_deploy_hooks() # noqa SLF001


Expand Down
91 changes: 24 additions & 67 deletions tests_integration/nativeapp/test_init_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
not_contains_row_with,
row_from_snowflake_session,
)
from tests_integration.testing_utils.working_directory_utils import (
WorkingDirectoryChanger,
)

USER_NAME = f"user_{uuid.uuid4().hex}"
TEST_ENV = generate_user_env(USER_NAME)
Expand Down Expand Up @@ -423,93 +426,39 @@ def test_nativeapp_init_from_repo_with_single_template(
# Tests that application post-deploy scripts are executed by creating a post_deploy_log table and having each post-deploy script add a record to it
@pytest.mark.integration
@pytest.mark.parametrize("is_versioned", [True, False])
@pytest.mark.parametrize("with_project_flag", [True, False])
def test_nativeapp_app_post_deploy(
runner, snowflake_session, project_directory, is_versioned
runner, snowflake_session, project_directory, is_versioned, with_project_flag
):
version = "v1"
project_name = "myapp"
app_name = f"{project_name}_{USER_NAME}"

def run():
"""(maybe) create a version, then snow app run"""
if is_versioned:
result = runner.invoke_with_connection_json(
["app", "version", "create", version],
env=TEST_ENV,
)
assert result.exit_code == 0

run_args = ["--version", version] if is_versioned else []
result = runner.invoke_with_connection_json(
["app", "run"] + run_args,
env=TEST_ENV,
)
assert result.exit_code == 0

with project_directory("napp_application_post_deploy") as tmp_dir:
try:
# First run, application is created (and maybe a version)
run()

# Verify both scripts were executed
assert row_from_snowflake_session(
snowflake_session.execute_string(
f"select * from {app_name}.public.post_deploy_log",
)
) == [
{"TEXT": "post-deploy-part-1"},
{"TEXT": "post-deploy-part-2"},
]

# Second run, application is upgraded
run()

# Verify both scripts were executed
assert row_from_snowflake_session(
snowflake_session.execute_string(
f"select * from {app_name}.public.post_deploy_log",
)
) == [
{"TEXT": "post-deploy-part-1"},
{"TEXT": "post-deploy-part-2"},
{"TEXT": "post-deploy-part-1"},
{"TEXT": "post-deploy-part-2"},
]
version_run_args = ["--version", version] if is_versioned else []
project_args = ["--project", f"{tmp_dir}"] if with_project_flag else []

finally:
# need to drop the version before we can teardown
def run():
"""(maybe) create a version, then snow app run"""
if is_versioned:
result = runner.invoke_with_connection_json(
["app", "version", "drop", version, "--force"],
["app", "version", "create", version] + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0

result = runner.invoke_with_connection_json(
["app", "teardown", "--force"],
["app", "run"] + version_run_args + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0

if with_project_flag:
working_directory_changer = WorkingDirectoryChanger()
working_directory_changer.change_working_directory_to("app")

# Tests that application post-deploy scripts are executed even when they are used with --project and project is in another directory
@pytest.mark.integration
def test_nativeapp_app_post_deploy_with_project_subdirectory(
runner, snowflake_session, project_directory
):
project_name = "myapp"
app_name = f"{project_name}_{USER_NAME}"

def run():
result = runner.invoke_with_connection_json(
["app", "run", "--project", "project_subdir"],
env=TEST_ENV,
)
assert result.exit_code == 0

with project_directory("napp_application_with_project_subdir") as tmp_dir:
try:
# First run, application is created
# First run, application is created (and maybe a version)
run()

# Verify both scripts were executed
Expand Down Expand Up @@ -538,8 +487,16 @@ def run():
]

finally:
# need to drop the version before we can teardown
if is_versioned:
result = runner.invoke_with_connection_json(
["app", "version", "drop", version, "--force"] + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0

result = runner.invoke_with_connection_json(
["app", "teardown", "--force", "--project", "project_subdir"],
["app", "teardown", "--force"] + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

0 comments on commit a080e15

Please sign in to comment.