Skip to content

Commit

Permalink
Add basic templating to snow sql (#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek authored Apr 23, 2024
1 parent ab3e1ac commit c962461
Show file tree
Hide file tree
Showing 16 changed files with 279 additions and 305 deletions.
11 changes: 11 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Unreleased
## Backward incompatibility

## Deprecations

## New additions
* `snow sql` command supports now client-side templating of queries.

## Fixes and improvements


# v2.2.0

## Backward incompatibility
Expand Down
23 changes: 23 additions & 0 deletions src/snowflake/cli/api/commands/flags.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import tempfile
from dataclasses import dataclass
from enum import Enum
from inspect import signature
from pathlib import Path
Expand Down Expand Up @@ -532,3 +533,25 @@ def _warning_callback(ctx: click.Context, param: click.Parameter, value: Any):
return value.value

return _warning_callback


@dataclass
class Variable:
key: str
value: str

def __init__(self, key: str, value: str):
self.key = key
self.value = value


def parse_key_value_variables(variables: List[str]) -> List[Variable]:
"""Util for parsing key=value input. Useful for commands accepting multiple input options."""
result = []
for p in variables:
if "=" not in p:
raise ClickException(f"Invalid variable: '{p}'")

key, value = p.split("=", 1)
result.append(Variable(key.strip(), value.strip()))
return result
122 changes: 56 additions & 66 deletions src/snowflake/cli/api/utils/rendering.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from textwrap import dedent
from typing import Optional
from typing import Dict, Optional

import jinja2
from jinja2 import Environment, StrictUndefined, loaders
from snowflake.cli.api.secure_path import UNLIMITED, SecurePath


Expand Down Expand Up @@ -33,66 +34,33 @@ def procedure_from_js_file(env: jinja2.Environment, file_name: str):
)


PROCEDURE_TEMPLATE = dedent(
"""\
CREATE OR REPLACE {{ object_type | upper }} {{ name | upper }}(\
{% for arg in signature %}
{{ arg['name'] | upper }} {{ arg['type'] }}{{ "," if not loop.last -}}
{% endfor %}
)
RETURNS {{ returns }}
LANGUAGE {{ language }}
{% if runtime_version is defined -%}
RUNTIME_VERSION = '{{ runtime_version }}'
{% endif -%}
{% if packages is defined -%}
PACKAGES = ('{{ packages }}')
{% endif -%}
{% if imports is defined -%}
IMPORTS = ({% for import in imports %}'{{ import }}'{{ ", " if not loop.last }}{% endfor %})
{% endif -%}
{% if handler is defined -%}
HANDLER = '{{ handler }}'
{% endif -%}
{% if code is defined -%}
AS
$$
{{ code }}
$$
{%- endif -%}
;
{%- if grants is defined -%}
{%- for grant in grants %}
GRANT USAGE ON {{ object_type | upper }} {{ name | upper }}({% for arg in signature %}{{ arg['type'] }}{{ ", " if not loop.last }}{% endfor %})
TO DATABASE ROLE {{ grant['role'] }};
{% endfor -%}
{% endif -%}\
"""
)
_CUSTOM_FILTERS = [read_file_content, procedure_from_js_file]


@jinja2.pass_environment # type: ignore
def render_metadata(env: jinja2.Environment, file_name: str):
metadata = json.loads(
SecurePath(file_name).absolute().read_text(file_size_limit_mb=UNLIMITED)
def _env_bootstrap(env: Environment) -> Environment:
for custom_filter in _CUSTOM_FILTERS:
env.filters[custom_filter.__name__] = custom_filter

return env


def get_snowflake_cli_jinja_env():
_random_block = "___very___unique___block___to___disable___logic___blocks___"
return _env_bootstrap(
Environment(
loader=loaders.BaseLoader(),
keep_trailing_newline=True,
variable_start_string="&{",
variable_end_string="}",
block_start_string=_random_block,
block_end_string=_random_block,
undefined=StrictUndefined,
)
)
template = env.from_string(PROCEDURE_TEMPLATE)

rendered = []
known_objects = {
"procedures": "procedure",
"udfs": "function",
"udtfs": "function",
}
for object_key, object_type in known_objects.items():
for obj in metadata.get(object_key, []):
rendered.append(template.render(object_type=object_type, **obj))
return "\n".join(rendered)


def generic_render_template(
template_path: Path, data: dict, output_file_path: Optional[Path] = None


def jinja_render_from_file(
template_path: Path, data: Dict, output_file_path: Optional[Path] = None
):
"""
Create a file from a jinja template.
Expand All @@ -105,17 +73,39 @@ def generic_render_template(
Returns:
None
"""
env = jinja2.Environment(
loader=jinja2.loaders.FileSystemLoader(template_path.parent),
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
env = _env_bootstrap(
Environment(
loader=loaders.FileSystemLoader(template_path.parent),
keep_trailing_newline=True,
undefined=StrictUndefined,
)
)
filters = [render_metadata, read_file_content, procedure_from_js_file]
for custom_filter in filters:
env.filters[custom_filter.__name__] = custom_filter
loaded_template = env.get_template(template_path.name)
rendered_result = loaded_template.render(**data)
if output_file_path:
SecurePath(output_file_path).write_text(rendered_result)
else:
print(rendered_result)


class _AttrGetter:
def __init__(self, data_dict):
self._data_dict = data_dict

def __getattr__(self, item):
if item not in self._data_dict:
raise AttributeError(f"No attribute {item}")
return self._data_dict[item]


def _add_project_context(data: Dict):
context_key = "ctx"
if context_key in data:
raise ValueError(f"{context_key} in user defined data")
context_data = {context_key: {"env": _AttrGetter(os.environ)}}
return {**data, **context_data}


def snowflake_cli_jinja_render(content: str, data: Dict | None = None) -> str:
data = _add_project_context(data or dict())
return get_snowflake_cli_jinja_env().from_string(content).render(**data)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from snowflake.cli.plugins.nativeapp import plugin_spec as nativeapp_plugin_spec
from snowflake.cli.plugins.notebook import plugin_spec as notebook_plugin_spec
from snowflake.cli.plugins.object import plugin_spec as object_plugin_spec
from snowflake.cli.plugins.render import plugin_spec as render_plugin_spec
from snowflake.cli.plugins.snowpark import plugin_spec as snowpark_plugin_spec
from snowflake.cli.plugins.spcs import plugin_spec as spcs_plugin_spec
from snowflake.cli.plugins.sql import plugin_spec as sql_plugin_spec
Expand All @@ -18,7 +17,6 @@ def get_builtin_plugin_name_to_plugin_spec():
"spcs": spcs_plugin_spec,
"nativeapp": nativeapp_plugin_spec,
"object": object_plugin_spec,
"render": render_plugin_spec,
"snowpark": snowpark_plugin_spec,
"stage": stage_plugin_spec,
"sql": sql_plugin_spec,
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/cli/plugins/nativeapp/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
to_identifier,
)
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.api.utils.rendering import generic_render_template
from snowflake.cli.api.utils.rendering import jinja_render_from_file
from yaml import dump, safe_dump, safe_load

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -121,7 +121,7 @@ def _render_snowflake_yml(parent_to_snowflake_yml: Path, project_identifier: str
snowflake_yml_jinja = "snowflake.yml.jinja"

try:
generic_render_template(
jinja_render_from_file(
template_path=parent_to_snowflake_yml / snowflake_yml_jinja,
data={
# generic_render_template operates on text, not YAML, so escape before rendering
Expand Down
Empty file.
76 changes: 0 additions & 76 deletions src/snowflake/cli/plugins/render/commands.py

This file was deleted.

16 changes: 0 additions & 16 deletions src/snowflake/cli/plugins/render/plugin_spec.py

This file was deleted.

27 changes: 25 additions & 2 deletions src/snowflake/cli/plugins/sql/commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
from typing import Optional
from typing import List, Optional

import typer
from snowflake.cli.api.commands.flags import parse_key_value_variables
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.output.types import CommandResult, MultipleResults, QueryResult
from snowflake.cli.plugins.sql.manager import SqlManager
Expand All @@ -10,6 +11,14 @@
app = SnowTyper()


def _parse_key_value(key_value_str: str):
parts = key_value_str.split("=")
if len(parts) < 2:
raise ValueError("Passed key-value pair does not comform with key=value format")

return parts[0], "=".join(parts[1:])


@app.command(name="sql", requires_connection=True)
def execute_sql(
query: Optional[str] = typer.Option(
Expand All @@ -34,15 +43,29 @@ def execute_sql(
"-i",
help="Read the query from standard input. Use it when piping input to this command.",
),
data_override: List[str] = typer.Option(
None,
"--data",
"-D",
help="String in format of key=value. If provided the SQL content will "
"be treated as template and rendered using provided data.",
),
**options,
) -> CommandResult:
"""
Executes Snowflake query.
Query to execute can be specified using query option, filename option (all queries from file will be executed)
or via stdin by piping output from other command. For example `cat my.sql | snow sql -i`.
The command supports variable substitution that happens on client-side. Both $VARIABLE or ${ VARIABLE }
syntax are supported.
"""
single_statement, cursors = SqlManager().execute(query, file, std_in)
data = {}
if data_override:
data = {v.key: v.value for v in parse_key_value_variables(data_override)}

single_statement, cursors = SqlManager().execute(query, file, std_in, data=data)
if single_statement:
return QueryResult(next(cursors))
return MultipleResults((QueryResult(c) for c in cursors))
Loading

0 comments on commit c962461

Please sign in to comment.