Skip to content

Commit

Permalink
Snow 1625040 unify sql template syntax (#1458)
Browse files Browse the repository at this point in the history
* Add <% ... %> syntax to SQL rendering

* add unit tests

* fix nativeapp rendering

* update release notes

* Add integration tests

* update nativeapp unit tests

* Fix Windows paths

* self-review

* Change SQL rendering to choose syntax depending on template

* revert nativeapp changes

* refactor nativeapp usage

* use SecurePath to open files
  • Loading branch information
sfc-gh-pczajka authored and sfc-gh-jvasquezrojas committed Aug 26, 2024
1 parent 56af21a commit ae6de77
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 60 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
* Added `snow spcs service execute-job` command, which supports creating and executing a job service in the current schema.
* Added `snow app events` command to fetch logs and traces from local and customer app installations.
* Added support for external access (api integrations and secrets) in Streamlit.
* Added support for `<% ... %>` syntax in SQL templating.
* Support multiple Streamlit application in single snowflake.yml project definition file.

## Fixes and improvements
Expand Down
56 changes: 30 additions & 26 deletions src/snowflake/cli/_plugins/nativeapp/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from functools import cached_property
from pathlib import Path
from textwrap import dedent
from typing import Any, Generator, List, NoReturn, Optional, TypedDict
from typing import Any, Callable, Dict, Generator, List, NoReturn, Optional, TypedDict

import jinja2
from click import ClickException
Expand Down Expand Up @@ -67,7 +67,6 @@
)
from snowflake.cli._plugins.stage.manager import StageManager
from snowflake.cli._plugins.stage.utils import print_diff_to_console
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.errno import (
DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED,
Expand All @@ -84,9 +83,13 @@
identifier_for_url,
unquote_identifier,
)
from snowflake.cli.api.rendering.jinja import (
jinja_render_from_str,
)
from snowflake.cli.api.rendering.sql_templates import (
get_sql_cli_jinja_env,
snowflake_sql_jinja_render,
)
from snowflake.cli.api.secure_path import UNLIMITED, SecurePath
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.connector import DictCursor, ProgrammingError

Expand Down Expand Up @@ -576,30 +579,36 @@ def create_app_package(self) -> None:
)
)

def _expand_script_templates(
self, env: jinja2.Environment, jinja_context: dict[str, Any], scripts: List[str]
def _render_script_templates(
self,
render_from_str: Callable[[str, Dict[str, Any]], str],
jinja_context: dict[str, Any],
scripts: List[str],
) -> List[str]:
"""
Input:
- env: Jinja2 environment
- render_from_str: function which renders a jinja template from a string and jinja context
- jinja_context: a dictionary with the jinja context
- scripts: list of scripts that need to be expanded with Jinja
- scripts: list of script paths relative to the project root
Returns:
- List of expanded scripts content.
- List of rendered scripts content
Size of the return list is the same as the size of the input scripts list.
"""
scripts_contents = []
for relpath in scripts:
script_full_path = SecurePath(self.project_root) / relpath
try:
template = env.get_template(relpath)
result = template.render(**jinja_context)
template_content = script_full_path.read_text(
file_size_limit_mb=UNLIMITED
)
result = render_from_str(template_content, jinja_context)
scripts_contents.append(result)

except jinja2.TemplateNotFound as e:
raise MissingScriptError(e.name) from e
except FileNotFoundError as e:
raise MissingScriptError(relpath) from e

except jinja2.TemplateSyntaxError as e:
raise InvalidScriptError(e.name, e, e.lineno) from e
raise InvalidScriptError(relpath, e, e.lineno) from e

except jinja2.UndefinedError as e:
raise InvalidScriptError(relpath, e) from e
Expand All @@ -617,14 +626,10 @@ def _apply_package_scripts(self) -> None:
"WARNING: native_app.package.scripts is deprecated. Please migrate to using native_app.package.post_deploy."
)

env = jinja2.Environment(
loader=jinja2.loaders.FileSystemLoader(self.project_root),
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)

queued_queries = self._expand_script_templates(
env, dict(package_name=self.package_name), self.package_scripts
queued_queries = self._render_script_templates(
jinja_render_from_str,
dict(package_name=self.package_name),
self.package_scripts,
)

# once we're sure all the templates expanded correctly, execute all of them
Expand Down Expand Up @@ -678,11 +683,10 @@ def _execute_post_deploy_hooks(
f"Unsupported {deployed_object_type} post-deploy hook type: {hook}"
)

env = get_sql_cli_jinja_env(
loader=jinja2.loaders.FileSystemLoader(self.project_root)
)
scripts_content_list = self._expand_script_templates(
env, get_cli_context().template_context, sql_scripts_paths
scripts_content_list = self._render_script_templates(
snowflake_sql_jinja_render,
{},
sql_scripts_paths,
)

for index, sql_script_path in enumerate(sql_scripts_paths):
Expand Down
36 changes: 28 additions & 8 deletions src/snowflake/cli/api/rendering/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pathlib import Path
from textwrap import dedent
from typing import Dict, Optional
from typing import Any, Dict, Optional

import jinja2
from jinja2 import Environment, StrictUndefined, loaders
Expand Down Expand Up @@ -82,8 +82,32 @@ def getitem(self, obj, argument):
return self.undefined(obj=obj, name=argument)


def _get_jinja_env(loader: Optional[loaders.BaseLoader] = None) -> Environment:
return env_bootstrap(
IgnoreAttrEnvironment(
loader=loader or loaders.BaseLoader(),
keep_trailing_newline=True,
undefined=StrictUndefined,
)
)


def jinja_render_from_str(template_content: str, data: Dict[str, Any]) -> str:
"""
Renders a jinja template and outputs either the rendered contents as string or writes to a file.
Args:
template_content (str): template contents
data (dict): A dictionary of jinja variables and their actual values
Returns:
None if file path is provided, else returns the rendered string.
"""
return _get_jinja_env().from_string(template_content).render(data)


def jinja_render_from_file(
template_path: Path, data: Dict, output_file_path: Optional[Path] = None
template_path: Path, data: Dict[str, Any], output_file_path: Optional[Path] = None
) -> Optional[str]:
"""
Renders a jinja template and outputs either the rendered contents as string or writes to a file.
Expand All @@ -96,12 +120,8 @@ def jinja_render_from_file(
Returns:
None if file path is provided, else returns the rendered string.
"""
env = env_bootstrap(
IgnoreAttrEnvironment(
loader=loaders.FileSystemLoader(template_path.parent),
keep_trailing_newline=True,
undefined=StrictUndefined,
)
env = _get_jinja_env(
loader=loaders.FileSystemLoader(template_path.parent.as_posix())
)
loaded_template = env.get_template(template_path.name)
rendered_result = loaded_template.render(**data)
Expand Down
49 changes: 39 additions & 10 deletions src/snowflake/cli/api/rendering/sql_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,66 @@

from __future__ import annotations

from typing import Dict, Optional
from typing import Dict

from click import ClickException
from jinja2 import StrictUndefined, loaders
from jinja2 import Environment, StrictUndefined, loaders, meta
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.console.console import cli_console
from snowflake.cli.api.exceptions import InvalidTemplate
from snowflake.cli.api.rendering.jinja import (
CONTEXT_KEY,
FUNCTION_KEY,
IgnoreAttrEnvironment,
env_bootstrap,
)

_SQL_TEMPLATE_START = "&{"
_SQL_TEMPLATE_END = "}"
_SQL_TEMPLATE_START = "<%"
_SQL_TEMPLATE_END = "%>"
_OLD_SQL_TEMPLATE_START = "&{"
_OLD_SQL_TEMPLATE_END = "}"
RESERVED_KEYS = [CONTEXT_KEY, FUNCTION_KEY]


def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None):
def _get_sql_jinja_env(template_start: str, template_end: str) -> Environment:
_random_block = "___very___unique___block___to___disable___logic___blocks___"
return env_bootstrap(
IgnoreAttrEnvironment(
loader=loader or loaders.BaseLoader(),
keep_trailing_newline=True,
variable_start_string=_SQL_TEMPLATE_START,
variable_end_string=_SQL_TEMPLATE_END,
variable_start_string=template_start,
variable_end_string=template_end,
loader=loaders.BaseLoader(),
block_start_string=_random_block,
block_end_string=_random_block,
keep_trailing_newline=True,
undefined=StrictUndefined,
)
)


def _does_template_have_env_syntax(env: Environment, template_content: str) -> bool:
template = env.parse(template_content)
return bool(meta.find_undeclared_variables(template))


def choose_sql_jinja_env_based_on_template_syntax(template_content: str) -> Environment:
old_syntax_env = _get_sql_jinja_env(_OLD_SQL_TEMPLATE_START, _OLD_SQL_TEMPLATE_END)
new_syntax_env = _get_sql_jinja_env(_SQL_TEMPLATE_START, _SQL_TEMPLATE_END)
has_old_syntax = _does_template_have_env_syntax(old_syntax_env, template_content)
has_new_syntax = _does_template_have_env_syntax(new_syntax_env, template_content)
if has_old_syntax and has_new_syntax:
raise InvalidTemplate(
f"The SQL query mixes {_OLD_SQL_TEMPLATE_START} ... {_OLD_SQL_TEMPLATE_END} syntax"
f" and {_SQL_TEMPLATE_START} ... {_SQL_TEMPLATE_END} syntax."
)
if has_old_syntax:
cli_console.warning(
f"Warning: {_OLD_SQL_TEMPLATE_START} ... {_OLD_SQL_TEMPLATE_END} syntax is deprecated."
f" Use {_SQL_TEMPLATE_START} ... {_SQL_TEMPLATE_END} syntax instead."
)
return old_syntax_env
return new_syntax_env


def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str:
data = data or {}

Expand All @@ -57,4 +85,5 @@ def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str:

context_data = get_cli_context().template_context
context_data.update(data)
return get_sql_cli_jinja_env().from_string(content).render(**context_data)
env = choose_sql_jinja_env_based_on_template_syntax(content)
return env.from_string(content).render(context_data)
49 changes: 41 additions & 8 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest import mock
Expand Down Expand Up @@ -321,6 +320,7 @@ def test_use_command(mock_execute_query, _object):
"select &{ aaa }.&{ bbb }",
"select &aaa.&bbb",
"select &aaa.&{ bbb }",
"select <% aaa %>.<% bbb %>",
],
)
@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string")
Expand All @@ -332,7 +332,29 @@ def test_rendering_of_sql(mock_execute_query, query, runner):
)


@pytest.mark.parametrize("query", ["select &{ foo }", "select &foo"])
@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string")
def test_old_template_syntax_causes_warning(mock_execute_query, runner):
result = runner.invoke(["sql", "-q", "select &{ aaa }", "-D", "aaa=foo"])
assert result.exit_code == 0
assert (
"Warning: &{ ... } syntax is deprecated. Use <% ... %> syntax instead."
in result.output
)
mock_execute_query.assert_called_once_with("select foo", cursor_class=VerboseCursor)


@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string")
def test_mixed_template_syntax_error(mock_execute_query, runner):
result = runner.invoke(
["sql", "-q", "select <% aaa %>.&{ bbb }", "-D", "aaa=foo", "-D", "bbb=bar"]
)
assert result.exit_code == 1
assert "The SQL query mixes &{ ... } syntax and <% ... %> syntax." in result.output


@pytest.mark.parametrize(
"query", ["select &{ foo }", "select &foo", "select <% foo %>"]
)
def test_execution_fails_if_unknown_variable(runner, query):
result = runner.invoke(["sql", "-q", query, "-D", "bbb=1"])
assert "SQL template rendering error: 'foo' is undefined" in result.output
Expand All @@ -356,42 +378,53 @@ def test_snowsql_compatibility(text, expected):
assert transpile_snowsql_templates(text) == expected


@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")])
@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string")
def test_uses_variables_from_snowflake_yml(
mock_execute_query, project_directory, runner
mock_execute_query, project_directory, runner, template_start, template_end
):
with project_directory("sql_templating"):
result = runner.invoke(["sql", "-q", "select &{ ctx.env.sf_var }"])
result = runner.invoke(
["sql", "-q", f"select {template_start} ctx.env.sf_var {template_end}"]
)

assert result.exit_code == 0
mock_execute_query.assert_called_once_with(
"select foo_value", cursor_class=VerboseCursor
)


@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")])
@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string")
def test_uses_variables_from_snowflake_local_yml(
mock_execute_query, project_directory, runner
mock_execute_query, project_directory, runner, template_start, template_end
):
with project_directory("sql_templating"):
result = runner.invoke(["sql", "-q", "select &{ ctx.env.sf_var_override }"])
result = runner.invoke(
[
"sql",
"-q",
f"select {template_start} ctx.env.sf_var_override {template_end}",
]
)

assert result.exit_code == 0
mock_execute_query.assert_called_once_with(
"select foo_value_override", cursor_class=VerboseCursor
)


@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")])
@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string")
def test_uses_variables_from_cli_are_added_outside_context(
mock_execute_query, project_directory, runner
mock_execute_query, project_directory, runner, template_start, template_end
):
with project_directory("sql_templating"):
result = runner.invoke(
[
"sql",
"-q",
"select &{ ctx.env.sf_var } &{ other }",
f"select {template_start} ctx.env.sf_var {template_end} {template_start} other {template_end}",
"-D",
"other=other_value",
]
Expand Down
Loading

0 comments on commit ae6de77

Please sign in to comment.