Skip to content

Commit

Permalink
Add basic templating to snow sql
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Mar 14, 2024
1 parent 5ca1847 commit 60f42ca
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 18 deletions.
48 changes: 38 additions & 10 deletions src/snowflake/cli/api/utils/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import json
from pathlib import Path
from textwrap import dedent
from typing import Optional
from typing import Dict, Optional

import jinja2
from jinja2 import Environment, StrictUndefined, loaders
from snowflake.cli.api.secure_path import UNLIMITED, SecurePath


Expand Down Expand Up @@ -91,8 +92,31 @@ def render_metadata(env: jinja2.Environment, file_name: str):
return "\n".join(rendered)


def generic_render_template(
template_path: Path, data: dict, output_file_path: Optional[Path] = None
_CUSTOM_FILTERS = [render_metadata, read_file_content, procedure_from_js_file]


def _env_bootstrap(env: Environment) -> Environment:
for custom_filter in _CUSTOM_FILTERS:
env.filters[custom_filter.__name__] = custom_filter

return env


_RANDOM_BLOCK = "___very___unique___block___to___disable___logic___blocks___"
SNOWFLAKE_CLI_JINJA_ENV = _env_bootstrap(
Environment(
loader=loaders.BaseLoader(),
keep_trailing_newline=True,
variable_start_string="%{",
variable_end_string="}",
block_start_string=_RANDOM_BLOCK,
block_end_string=_RANDOM_BLOCK,
)
)


def jinja_render_from_file(
template_path: Path, data: Dict, output_file_path: Optional[Path] = None
):
"""
Create a file from a jinja template.
Expand All @@ -105,17 +129,21 @@ def generic_render_template(
Returns:
None
"""
env = jinja2.Environment(
loader=jinja2.loaders.FileSystemLoader(template_path.parent),
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
env = _env_bootstrap(
Environment(
loader=loaders.FileSystemLoader(template_path.parent),
keep_trailing_newline=True,
undefined=StrictUndefined,
)
)
filters = [render_metadata, read_file_content, procedure_from_js_file]
for custom_filter in filters:
env.filters[custom_filter.__name__] = custom_filter
loaded_template = env.get_template(template_path.name)
rendered_result = loaded_template.render(**data)
if output_file_path:
SecurePath(output_file_path).write_text(rendered_result)
else:
print(rendered_result)


def snowflake_cli_jinja_render(content: str, data: Dict | None = None) -> str:
data = data or dict()
return SNOWFLAKE_CLI_JINJA_ENV.from_string(content).render(**data)
4 changes: 2 additions & 2 deletions src/snowflake/cli/plugins/nativeapp/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
to_identifier,
)
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.api.utils.rendering import generic_render_template
from snowflake.cli.api.utils.rendering import jinja_render_from_file
from yaml import dump, safe_dump, safe_load

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -121,7 +121,7 @@ def _render_snowflake_yml(parent_to_snowflake_yml: Path, project_identifier: str
snowflake_yml_jinja = "snowflake.yml.jinja"

try:
generic_render_template(
jinja_render_from_file(
template_path=parent_to_snowflake_yml / snowflake_yml_jinja,
data={
# generic_render_template operates on text, not YAML, so escape before rendering
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/cli/plugins/render/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from snowflake.cli.api import secure_path
from snowflake.cli.api.commands.decorators import global_options
from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS
from snowflake.cli.api.utils.rendering import generic_render_template
from snowflake.cli.api.utils.rendering import jinja_render_from_file

app = typer.Typer(context_settings=DEFAULT_CONTEXT_SETTINGS, hidden=True, name="render")

Expand Down Expand Up @@ -71,6 +71,6 @@ def render_template(
key, value = _parse_key_value(key_value_str)
data[key] = value

generic_render_template(
jinja_render_from_file(
template_path=template_path, data=data, output_file_path=output_file_path
)
25 changes: 23 additions & 2 deletions src/snowflake/cli/plugins/sql/commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional
from typing import List, Optional

import typer
from snowflake.cli.api.commands.snow_typer import SnowTyper
Expand All @@ -10,6 +10,14 @@
app = SnowTyper()


def _parse_key_value(key_value_str: str):
parts = key_value_str.split("=")
if len(parts) < 2:
raise ValueError("Passed key-value pair does not comform with key=value format")

return parts[0], "=".join(parts[1:])


@app.command(name="sql", requires_connection=True)
def execute_sql(
query: Optional[str] = typer.Option(
Expand All @@ -34,6 +42,13 @@ def execute_sql(
"-i",
help="Read the query from standard input. Use it when piping input to this command.",
),
data_override: List[str] = typer.Option(
None,
"--data",
"-D",
help="String in format of key=value. If provided the SQL content will "
"be treated and rendered using provided data.",
),
**options
) -> CommandResult:
"""
Expand All @@ -42,7 +57,13 @@ def execute_sql(
Query to execute can be specified using query option, filename option (all queries from file will be executed)
or via stdin by piping output from other command. For example `cat my.sql | snow sql -i`.
"""
single_statement, cursors = SqlManager().execute(query, file, std_in)
data = {}
if data_override:
for key_value_str in data_override:
key, value = _parse_key_value(key_value_str)
data[key] = value

single_statement, cursors = SqlManager().execute(query, file, std_in, data=data)
if single_statement:
return QueryResult(next(cursors))
return MultipleResults((QueryResult(c) for c in cursors))
13 changes: 11 additions & 2 deletions src/snowflake/cli/plugins/sql/manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import sys
from io import StringIO
from pathlib import Path
from typing import Iterable, Optional, Tuple
from typing import Dict, Iterable, Optional, Tuple

from click import UsageError
from snowflake.cli.api.secure_path import UNLIMITED, SecurePath
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.api.utils.rendering import snowflake_cli_jinja_render
from snowflake.connector.cursor import SnowflakeCursor
from snowflake.connector.util_text import split_statements


class SqlManager(SqlExecutionMixin):
def execute(
self, query: Optional[str], file: Optional[Path], std_in: bool
self,
query: Optional[str],
file: Optional[Path],
std_in: bool,
data: Dict | None = None,
) -> Tuple[int, Iterable[SnowflakeCursor]]:
inputs = [query, file, std_in]
if not any(inputs):
Expand All @@ -29,6 +34,10 @@ def execute(
elif file:
query = SecurePath(file).read_text(file_size_limit_mb=UNLIMITED)

if data:
# Do rendering if any data was provided
query = snowflake_cli_jinja_render(content=query, data=data)

statements = tuple(
statement
for statement, _ in split_statements(StringIO(query), remove_comments=True)
Expand Down
23 changes: 23 additions & 0 deletions tests/api/utils/test_rendering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
from snowflake.cli.api.utils.rendering import snowflake_cli_jinja_render


def test_rendering_with_data():
assert snowflake_cli_jinja_render("%{ foo }", data={"foo": "bar"}) == "bar"


@pytest.mark.parametrize(
"text, output",
[
# Green path
("%{ foo }", "bar"),
# Using $ as sf variable and basic jinja for server side
("${{ foo }}", "${{ foo }}"),
("$%{ foo }{{ var }}", "$bar{{ var }}"),
("${{ %{ foo } }}", "${{ bar }}"),
# Using $ as sf variable and client side rendering
("$%{ foo }", "$bar"),
],
)
def test_rendering(text, output):
assert snowflake_cli_jinja_render(text, data={"foo": "bar"}) == output
16 changes: 16 additions & 0 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,19 @@ def test_show_specific_object_multiple_rows(mock_execute_query):
mock_execute_query.assert_called_once_with(
r"show objects like 'NAME'", cursor_class=DictCursor
)


@mock.patch("snowflake.cli.plugins.sql.commands.SqlManager._execute_string")
def test_rendering_of_sql(mock_execute_query, runner):
result = runner.invoke(
["sql", "-q", "select %{ aaa }.%{ bbb }", "-D", "aaa=foo", "-D", "bbb=bar"]
)
assert result.exit_code == 0, result.output
mock_execute_query.assert_called_once_with("select foo.bar")


@mock.patch("snowflake.cli.plugins.sql.commands.SqlManager._execute_string")
def test_no_rendering_of_sql_if_no_data(mock_execute_query, runner):
result = runner.invoke(["sql", "-q", "select %{ aaa }.%{ bbb }"])
assert result.exit_code == 0, result.output
mock_execute_query.assert_called_once_with("select %{ aaa }.%{ bbb }")

0 comments on commit 60f42ca

Please sign in to comment.