From 916238ba6b72945a53ff02f3812b9147ed355d20 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Mon, 22 Apr 2024 17:58:19 +0200 Subject: [PATCH] fixup! Add global context --- RELEASE-NOTES.md | 2 +- src/snowflake/cli/api/commands/flags.py | 23 ++++++++++++++++++ src/snowflake/cli/plugins/sql/commands.py | 5 ++-- src/snowflake/cli/plugins/stage/manager.py | 26 ++------------------- tests/__snapshots__/test_help_messages.ambr | 6 +++-- tests/api/utils/test_rendering.py | 2 ++ 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 9e62e9486b..0eb6ab70e5 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -48,7 +48,7 @@ * `snow snowpark build` * `snow snowpark package lookup` * `snow snowpark package create` -* `snow sql` command supports now templating of queries. +* `snow sql` command supports now client-side templating of queries. ## Fixes and improvements * Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands. diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index 7186a7b877..1bcef29fd8 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -1,6 +1,7 @@ from __future__ import annotations import tempfile +from dataclasses import dataclass from enum import Enum from inspect import signature from pathlib import Path @@ -532,3 +533,25 @@ def _warning_callback(ctx: click.Context, param: click.Parameter, value: Any): return value.value return _warning_callback + + +@dataclass +class Variable: + key: str + value: str + + def __init__(self, key: str, value: str): + self.key = key + self.value = value + + +def parse_key_value_variables(variables: List[str]) -> List[Variable]: + """Util for parsing key=value input. Useful for commands accepting multiple input options.""" + result = [] + for p in variables: + if "=" not in p: + raise ClickException(f"Invalid variable: '{p}'") + + key, value = p.split("=", 1) + result.append(Variable(key.strip(), value.strip())) + return result diff --git a/src/snowflake/cli/plugins/sql/commands.py b/src/snowflake/cli/plugins/sql/commands.py index 3494a0ecad..0310d02e09 100644 --- a/src/snowflake/cli/plugins/sql/commands.py +++ b/src/snowflake/cli/plugins/sql/commands.py @@ -2,6 +2,7 @@ from typing import List, Optional import typer +from snowflake.cli.api.commands.flags import parse_key_value_variables from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import CommandResult, MultipleResults, QueryResult from snowflake.cli.plugins.sql.manager import SqlManager @@ -62,9 +63,7 @@ def execute_sql( """ data = {} if data_override: - for key_value_str in data_override: - key, value = _parse_key_value(key_value_str) - data[key] = value + data = {v.key: v.value for v in parse_key_value_variables(data_override)} single_statement, cursors = SqlManager().execute(query, file, std_in, data=data) if single_statement: diff --git a/src/snowflake/cli/plugins/stage/manager.py b/src/snowflake/cli/plugins/stage/manager.py index 11ffc0e6ff..dad58f43f3 100644 --- a/src/snowflake/cli/plugins/stage/manager.py +++ b/src/snowflake/cli/plugins/stage/manager.py @@ -5,13 +5,12 @@ import logging import re from contextlib import nullcontext -from dataclasses import dataclass from os import path from pathlib import Path from typing import Dict, List, Optional, Union from click import ClickException -from snowflake.cli.api.commands.flags import OnErrorType +from snowflake.cli.api.commands.flags import OnErrorType, parse_key_value_variables from snowflake.cli.api.console import cli_console from snowflake.cli.api.project.util import to_string_literal from snowflake.cli.api.secure_path import SecurePath @@ -27,16 +26,6 @@ EXECUTE_SUPPORTED_FILES_FORMATS = {".sql"} -@dataclass -class Variable: - key: str - value: str - - def __init__(self, key: str, value: str): - self.key = key - self.value = value - - class StageManager(SqlExecutionMixin): @staticmethod def get_standard_stage_prefix(name: str) -> str: @@ -262,21 +251,10 @@ def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]: if not variables: return None - parsed_variables = StageManager._parse_variables(variables) + parsed_variables = parse_key_value_variables(variables) query_parameters = [f"{v.key}=>{v.value}" for v in parsed_variables] return f" using ({', '.join(query_parameters)})" - @staticmethod - def _parse_variables(variables: List[str]) -> List[Variable]: - result = [] - for p in variables: - if "=" not in p: - raise ClickException(f"Invalid variable: '{p}'") - - key, value = p.split("=", 1) - result.append(Variable(key.strip(), value.strip())) - return result - def _call_execute_immediate( self, file: str, variables: Optional[str], on_error: OnErrorType ) -> Dict: diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index e3513bf350..9d542d3c5c 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -4820,6 +4820,8 @@ Query to execute can be specified using query option, filename option (all queries from file will be executed) or via stdin by piping output from other command. For example `cat my.sql | snow sql -i`. + The command supports variable substitution that happens on client-side. Both + $VARIABLE or ${ VARIABLE } syntax are supported. ╭─ Options ────────────────────────────────────────────────────────────────────╮ │ --query -q TEXT Query to execute. [default: None] │ @@ -4827,8 +4829,8 @@ │ --stdin -i Read the query from standard input. Use it when │ │ piping input to this command. │ │ --data -D TEXT String in format of key=value. If provided the SQL │ - │ content will be treated and rendered using │ - │ provided data. │ + │ content will be treated as template and rendered │ + │ using provided data. │ │ [default: None] │ │ --help -h Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────╯ diff --git a/tests/api/utils/test_rendering.py b/tests/api/utils/test_rendering.py index b8450552a3..d16e10e1a8 100644 --- a/tests/api/utils/test_rendering.py +++ b/tests/api/utils/test_rendering.py @@ -47,7 +47,9 @@ def test_that_common_logic_block_are_ignored(text): def test_that_common_comments_are_respected(): + # Make sure comment are ignored assert snowflake_cli_jinja_render("{# note a comment &{ foo } #}") == "" + # Make sure comment's work together with templates assert ( snowflake_cli_jinja_render("{# note a comment #}&{ foo }", data={"foo": "bar"}) == "bar"