diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index ae21cdb3b3..2e87439d6e 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -8,6 +8,9 @@ * Added `snow notebook` commands: * `snow notebook execute` enabling head-less execution of a notebook. * `snow notebook create` proving an option to create a Snowflake Notebook from a file on stage. +* Added templating support for project definition file. + * Templates can now be used within the main section of the project definition file. + * Resolved values of the project definition file are available to all modules. ## Fixes and improvements diff --git a/src/snowflake/cli/api/cli_global_context.py b/src/snowflake/cli/api/cli_global_context.py index 82c6782aca..4e28d2253c 100644 --- a/src/snowflake/cli/api/cli_global_context.py +++ b/src/snowflake/cli/api/cli_global_context.py @@ -212,6 +212,7 @@ def __init__(self): self._experimental = False self._project_definition = None self._project_root = None + self._template_context = None self._silent: bool = False def reset(self): @@ -259,6 +260,13 @@ def project_root(self): def set_project_root(self, project_root: Path): self._project_root = project_root + @property + def template_context(self): + return self._template_context + + def set_template_context(self, template_context: dict): + self._template_context = template_context + @property def connection_context(self) -> _ConnectionContext: return self._connection_context @@ -311,6 +319,10 @@ def project_definition(self) -> ProjectDefinition | None: def project_root(self): return self._manager.project_root + @property + def template_context(self): + return self._manager.template_context + @property def silent(self) -> bool: if self._should_force_mute_intermediate_output: diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index 909a282673..e07b35b394 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import tempfile from dataclasses import dataclass from enum import Enum @@ -14,6 +15,7 @@ from snowflake.cli.api.console import cli_console from snowflake.cli.api.exceptions import MissingConfiguration from snowflake.cli.api.output.formats import OutputFormat +from snowflake.cli.api.utils.rendering import CONTEXT_KEY DEFAULT_CONTEXT_SETTINGS = {"help_option_names": ["--help", "-h"]} @@ -499,6 +501,7 @@ def _callback(project_path: Optional[str]): cli_context_manager.set_project_definition(project_definition) cli_context_manager.set_project_root(project_root) + cli_context_manager.set_template_context(dm.template_context) return project_definition if project_name == "native_app": @@ -527,14 +530,17 @@ def _callback(project_path: Optional[str]): dm = DefinitionManager(project_path) project_definition = dm.project_definition project_root = dm.project_root + template_context = dm.template_context except MissingConfiguration: if optional: project_definition = None project_root = None + template_context = {CONTEXT_KEY: {"env": os.environ}} else: raise cli_context_manager.set_project_definition(project_definition) cli_context_manager.set_project_root(project_root) + cli_context_manager.set_template_context(template_context) return project_definition return typer.Option( diff --git a/src/snowflake/cli/api/project/definition.py b/src/snowflake/cli/api/project/definition.py index f6dac5c1fe..09ba764a90 100644 --- a/src/snowflake/cli/api/project/definition.py +++ b/src/snowflake/cli/api/project/definition.py @@ -1,7 +1,8 @@ from __future__ import annotations +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List +from typing import List import yaml.loader from snowflake.cli.api.cli_global_context import cli_context @@ -15,33 +16,19 @@ ) from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.utils.definition_rendering import render_definition_template +from snowflake.cli.api.utils.dict_utils import deep_merge_dicts from yaml import load DEFAULT_USERNAME = "unknown_user" -def merge_two_dicts( - original_values: Dict[str, Any], update_values: Dict[str, Any] -): # TODO update name of function - if not isinstance(update_values, dict) or not isinstance(original_values, dict): - return +@dataclass +class ProjectProperties: + project_definition: ProjectDefinition + raw_project_definition: dict - for field, value in update_values.items(): - if ( - field in original_values - and isinstance(original_values[field], dict) - and isinstance(value, dict) - ): - merge_two_dicts(original_values[field], value) - else: - original_values[field] = value - -def load_project_definition(paths: List[Path]) -> ProjectDefinition: - """ - Loads project definition, optionally overriding values. Definition values - are merged in left-to-right order (increasing precedence). - """ +def _get_merged_project_files(paths: List[Path]) -> dict: spaths: List[SecurePath] = [SecurePath(p) for p in paths] if len(spaths) == 0: raise ValueError("Need at least one definition file.") @@ -54,11 +41,21 @@ def load_project_definition(paths: List[Path]) -> ProjectDefinition: "r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB ) as override_yml: overrides = load(override_yml.read(), Loader=yaml.loader.BaseLoader) - merge_two_dicts(definition, overrides) + deep_merge_dicts(definition, overrides) + + return definition - rendered_definition = render_definition_template(definition) - rendered_project = ProjectDefinition(**rendered_definition) - return rendered_project + +def load_project(paths: List[Path]) -> ProjectProperties: + """ + Loads project definition, optionally overriding values. Definition values + are merged in left-to-right order (increasing precedence). + """ + merged_files = _get_merged_project_files(paths) + rendered_definition = render_definition_template(merged_files) + return ProjectProperties( + ProjectDefinition(**rendered_definition), rendered_definition + ) def generate_local_override_yml( diff --git a/src/snowflake/cli/api/project/definition_manager.py b/src/snowflake/cli/api/project/definition_manager.py index dd90606163..e094a6e04c 100644 --- a/src/snowflake/cli/api/project/definition_manager.py +++ b/src/snowflake/cli/api/project/definition_manager.py @@ -6,8 +6,9 @@ from typing import List, Optional from snowflake.cli.api.exceptions import MissingConfiguration -from snowflake.cli.api.project.definition import load_project_definition +from snowflake.cli.api.project.definition import ProjectProperties, load_project from snowflake.cli.api.project.schemas.project_definition import ProjectDefinition +from snowflake.cli.api.utils.rendering import CONTEXT_KEY def _compat_is_mount(path: Path): @@ -100,6 +101,16 @@ def _user_definition_file_if_available(project_path: Path) -> Optional[Path]: DefinitionManager.USER_DEFINITION_FILENAME, project_path ) + @functools.cached_property + def _project_properties(self) -> ProjectProperties: + return load_project(self._project_config_paths) + @functools.cached_property def project_definition(self) -> ProjectDefinition: - return load_project_definition(self._project_config_paths) + return self._project_properties.project_definition + + @functools.cached_property + def template_context(self) -> dict: + definition = self._project_properties.raw_project_definition + + return {CONTEXT_KEY: definition} diff --git a/src/snowflake/cli/api/utils/definition_rendering.py b/src/snowflake/cli/api/utils/definition_rendering.py index cfafc8db05..74e3318b92 100644 --- a/src/snowflake/cli/api/utils/definition_rendering.py +++ b/src/snowflake/cli/api/utils/definition_rendering.py @@ -1,27 +1,83 @@ from __future__ import annotations import os -from collections import deque -from dataclasses import dataclass -from typing import Optional -import yaml from jinja2 import Environment, UndefinedError, nodes -from snowflake.cli.api.utils.graph import Graph -from snowflake.cli.api.utils.rendering import get_snowflake_cli_jinja_env +from packaging.version import Version +from snowflake.cli.api.utils.dict_utils import deep_merge_dicts, deep_traverse +from snowflake.cli.api.utils.graph import Graph, Node +from snowflake.cli.api.utils.rendering import CONTEXT_KEY, get_snowflake_cli_jinja_env class Variable: def __init__(self, vars_chain): self._vars_chain = list(vars_chain) - self.value: str | None = None - - def get_vars_hierarchy(self) -> deque[str]: - return deque(self._vars_chain) + self.templated_value = None + self.rendered_value = None def get_key(self): return ".".join(self._vars_chain) + def is_env_var(self): + return ( + len(self._vars_chain) == 3 + and self._vars_chain[0] == CONTEXT_KEY + and self._vars_chain[1] == "env" + ) + + def get_env_var_name(self) -> str: + if not self.is_env_var(): + raise KeyError( + f"Referenced variable {self.get_key()} is not an environment variable" + ) + return self._vars_chain[2] + + def store_in_context(self, context: dict, value): + """ + Takes a generic context dict to modify, and a value + + Traverse through the multi-level dictionary to the location where this variables goes to. + Sets this location to the content of value. + + Example: vars chain contains ['ctx', 'env', 'x'], and context is {}, and value is 'val'. + At the end of this call, context content will be: {'ctx': {'env': {'x': 'val'}}} + """ + current_dict_level = context + for i, var in enumerate(self._vars_chain): + if i == len(self._vars_chain) - 1: + current_dict_level[var] = value + else: + current_dict_level.setdefault(var, {}) + current_dict_level = current_dict_level[var] + + def read_from_context(self, context): + """ + Takes a context dict as input. + + Traverse through the multi-level dictionary to the location where this variable goes to. + Returns the value in that location. + + Raise UndefinedError if the variable is None or not found. + """ + current_dict_level = context + for key in self._vars_chain: + if ( + not isinstance(current_dict_level, dict) + or key not in current_dict_level + ): + raise UndefinedError( + f"Could not find template variable {self.get_key()}" + ) + current_dict_level = current_dict_level[key] + + value = current_dict_level + if value == None or isinstance(value, dict) or isinstance(value, list): + raise UndefinedError( + f"Template variable {self.get_key()} does not contain a valid value" + ) + + return value + def __hash__(self): return hash(self.get_key()) @@ -29,135 +85,112 @@ def __eq__(self, other): return self.get_key() == other.get_key() -@dataclass(eq=False) -class TemplateGraphNode(Graph.Node): - templated_value: str | None = None - rendered_value: str | None = None - variable: Variable | None = None - - def store_in_context(self, context): - # if graph node is ctx.env.test with value x, context will be built as {'ctx': {'env': {'test': 'x'}}} - vars_chain = self.variable.get_vars_hierarchy() - self._store_in_context_recursive(context, vars_chain) - - def _store_in_context_recursive(self, context, vars_chain: deque): - if len(vars_chain) == 0: - return self.rendered_value - - current_level_key = vars_chain.popleft() - current_value_in_key = ( - context[current_level_key] if current_level_key in context else {} - ) - context[current_level_key] = self._store_in_context_recursive( - current_value_in_key, vars_chain - ) - return context - - -def _get_referenced_vars( - ast_node, variable_attr_chain: deque = deque() -) -> set[Variable]: - # Var nodes end in Name node, and starts with Getattr node. - # Example: ctx.env.test will look like Getattr(test) -> Getattr(env) -> Name(ctx) +def _get_referenced_vars(ast_node, current_attr_chain: list[str] = []) -> set[Variable]: + """ + Traverse Jinja AST to find the variable chain referenced by the template. + A variable like ctx.env.test is internally represented in the AST tree as + Getattr Node (attr='test') -> Getattr Node (attr='env') -> Name Node (name='ctx') + """ all_referenced_vars = set() - - variable_appended = False if isinstance(ast_node, nodes.Getattr): - variable_attr_chain.appendleft(getattr(ast_node, "attr")) - variable_appended = True + current_attr_chain = [getattr(ast_node, "attr")] + current_attr_chain elif isinstance(ast_node, nodes.Name): - variable_attr_chain.appendleft(getattr(ast_node, "name")) - variable_appended = True - all_referenced_vars.add(Variable(variable_attr_chain)) + current_attr_chain = [getattr(ast_node, "name")] + current_attr_chain + all_referenced_vars.add(Variable(current_attr_chain)) for child_node in ast_node.iter_child_nodes(): - all_referenced_vars.update( - _get_referenced_vars(child_node, variable_attr_chain) - ) - - if variable_appended: - variable_attr_chain.popleft() + all_referenced_vars.update(_get_referenced_vars(child_node, current_attr_chain)) return all_referenced_vars -def _get_referenced_vars_from_str( - env: Environment, template_str: Optional[str] -) -> set[Variable]: - if template_str == None: - return set() - +def _get_referenced_vars_from_str(env: Environment, template_str: str) -> set[Variable]: ast = env.parse(template_str) return _get_referenced_vars(ast) -def _get_value_from_var_path(env: Environment, context: dict, var_path: str): - # given a variable path (e.g. ctx.env.test), return evaluated value based on context - # fall back to env variables and escape them so we stop the chain of variables - ref_str = "<% " + var_path + " %>" - template = env.from_string(ref_str) - try: - # check in env variables - return f"{env.block_start_string} raw {env.block_end_string}{template.render({'ctx': {'env': os.environ}})}{env.block_start_string} endraw {env.block_end_string}" - except UndefinedError as e: - try: - return template.render(context) - except: - raise UndefinedError("Could not find template variable " + var_path) - - -def _build_dependency_graph(env, all_vars: set[Variable], context_without_env) -> Graph: - dependencies_graph = Graph() +def _build_dependency_graph( + env: Environment, all_vars: set[Variable], context +) -> Graph[Variable]: + dependencies_graph = Graph[Variable]() for variable in all_vars: - dependencies_graph.add( - TemplateGraphNode(key=variable.get_key(), variable=variable) - ) + dependencies_graph.add(Node[Variable](key=variable.get_key(), data=variable)) for variable in all_vars: - node: TemplateGraphNode = dependencies_graph.get(key=variable.get_key()) - node.templated_value = _get_value_from_var_path( - env, context_without_env, variable.get_key() - ) - dependencies_vars = _get_referenced_vars_from_str(env, node.templated_value) - - for referenced_var in dependencies_vars: - dependencies_graph.add_dependency( - variable.get_key(), referenced_var.get_key() + if variable.is_env_var() and variable.get_env_var_name() in os.environ: + # If variable is found in os.environ, then use the value as is + # skip rendering by pre-setting the rendered_value attribute + env_value = os.environ.get(variable.get_env_var_name()) + variable.rendered_value = env_value + variable.templated_value = env_value + else: + variable.templated_value = str(variable.read_from_context(context)) + dependencies_vars = _get_referenced_vars_from_str( + env, variable.templated_value ) + for referenced_var in dependencies_vars: + dependencies_graph.add_directed_edge( + variable.get_key(), referenced_var.get_key() + ) return dependencies_graph -def render_definition_template(definition): +def _render_graph_node(jinja_env: Environment, node: Node[Variable]): + if node.data.rendered_value is not None: + # Do not re-evaluate resolved nodes like env variable nodes, + # which might contain template-like values + return + + current_context: dict = {} + for dep_node in node.neighbors: + dep_node.data.store_in_context(current_context, dep_node.data.rendered_value) + + template = jinja_env.from_string(node.data.templated_value) + node.data.rendered_value = template.render(current_context) + + +def _render_dict_element(jinja_env: Environment, context, element): + if _get_referenced_vars_from_str(jinja_env, element): + template = jinja_env.from_string(element) + return template.render(context) + return element + + +def render_definition_template(definition: dict): + if "definition_version" not in definition or Version( + definition["definition_version"] + ) < Version("1.1"): + return definition + jinja_env = get_snowflake_cli_jinja_env() - pdf_context = {"ctx": definition} - pdf_yaml_str = yaml.dump(definition) + pdf_context = {CONTEXT_KEY: definition} + + referenced_vars = set() - referenced_vars = _get_referenced_vars_from_str(jinja_env, pdf_yaml_str) + def find_any_template_vars(element): + referenced_vars.update(_get_referenced_vars_from_str(jinja_env, element)) + + deep_traverse(definition, visit_action=find_any_template_vars) dependencies_graph = _build_dependency_graph( jinja_env, referenced_vars, pdf_context ) - def evaluate_node(node: TemplateGraphNode): - current_context: dict = {} - dep_node: TemplateGraphNode - for dep_node in node.dependencies: - dep_node.store_in_context(current_context) - - template = jinja_env.from_string(node.templated_value) - node.rendered_value = template.render(current_context) - - dependencies_graph.dfs(visit_action=evaluate_node) + dependencies_graph.dfs( + visit_action=lambda node: _render_graph_node(jinja_env, node) + ) - # now that we determined value of all referenced vars, use these resolved values as a fresh context to resolve project file later - final_context = {} - node: TemplateGraphNode + # now that we determined the values of all tempalted vars, + # use these resolved values as a fresh context to resolve definition + final_context: dict = {} for node in dependencies_graph.get_all_nodes(): - node.store_in_context(final_context) + node.data.store_in_context(final_context, node.data.rendered_value) + + deep_traverse( + definition, + update_action=lambda val: _render_dict_element(jinja_env, final_context, val), + ) + deep_merge_dicts(definition, {"env": dict(os.environ)}) - # resolve combined project file based on final context - template = jinja_env.from_string(pdf_yaml_str) - rendered_template = template.render(final_context) - rendered_definition = yaml.load(rendered_template, Loader=yaml.loader.BaseLoader) - return rendered_definition + return definition diff --git a/src/snowflake/cli/api/utils/dict_utils.py b/src/snowflake/cli/api/utils/dict_utils.py new file mode 100644 index 0000000000..1277cfd072 --- /dev/null +++ b/src/snowflake/cli/api/utils/dict_utils.py @@ -0,0 +1,53 @@ +from __future__ import annotations + + +def deep_merge_dicts(original_values: dict, override_values: dict): + """ + Takes 2 dictionaries as input: original and override. + + For every key in the override dictionary, override the same key + in the original dictionary, or create a new one if it doesn't exist. + + If the override value and the original value are both dictionaries, + instead of overriding, recursively call this function to merge the keys of the sub-dictionaries. + """ + if not isinstance(override_values, dict) or not isinstance(original_values, dict): + return + + for field, value in override_values.items(): + if ( + field in original_values + and isinstance(original_values[field], dict) + and isinstance(value, dict) + ): + deep_merge_dicts(original_values[field], value) + else: + original_values[field] = value + + +def deep_traverse( + element, visit_action=lambda element: None, update_action=lambda element: element +): + """ + Traverse a nested structure (lists, dicts, scalars). + + On traversal, it allows for actions or updates on each visit. + + visit_action: caller can provide a function to execute on each scalar element in the structure (leaves of the tree). + visit_action accepts an element (scalar) as input. Return value is ignored. + + update_action: caller can provide a function to update each scalar element in the structure. + update_action accepts an element (scalar) as input, and returns the modified value. + + """ + if isinstance(element, dict): + for key, value in element.items(): + element[key] = deep_traverse(value, visit_action, update_action) + return element + elif isinstance(element, list): + for index, value in enumerate(element): + element[index] = deep_traverse(value, visit_action, update_action) + return element + else: + visit_action(element) + return update_action(element) diff --git a/src/snowflake/cli/api/utils/graph.py b/src/snowflake/cli/api/utils/graph.py index 6c704b7008..2135bf36b5 100644 --- a/src/snowflake/cli/api/utils/graph.py +++ b/src/snowflake/cli/api/utils/graph.py @@ -1,50 +1,57 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Generic, Optional, TypeVar +from click import ClickException -class Graph: - @dataclass - class Node: - key: str - status: str | None = None - dependencies: set[Graph.Node] = field(default_factory=set) +T = TypeVar("T") - def __eq__(self, other): - return self.key == other.key - def __hash__(self): - return hash(self.key) +@dataclass +class Node(Generic[T]): + key: str + data: T + neighbors: set[Node[T]] = field(default_factory=set) + status: Optional[str] = None + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.key == other.key + + +class Graph(Generic[T]): def __init__(self): - self._graph_nodes_map: dict[str, Graph.Node] = {} + self._graph_nodes_map: dict[str, Node[T]] = {} - def get(self, key: str): + def get(self, key: str) -> Node[T]: if key in self._graph_nodes_map: - return self._graph_nodes_map.get(key) + return self._graph_nodes_map[key] raise KeyError(f"Node with key {key} not found") - def get_all_nodes(self): - return self._graph_nodes_map.values() + def get_all_nodes(self) -> set[Node[T]]: + return set(self._graph_nodes_map.values()) - def add(self, node: Node): + def add(self, node: Node[T]): if node.key in self._graph_nodes_map: raise KeyError(f"Node key {node.key} already exists") self._graph_nodes_map[node.key] = node - def add_dependency(self, key1: str, key2: str): - node1 = self.get(key1) - node2 = self.get(key2) - node1.dependencies.add(node2) + def add_directed_edge(self, from_node_key: str, to_node_key: str): + from_node = self.get(from_node_key) + to_node = self.get(to_node_key) + from_node.neighbors.add(to_node) - def _dfs_visit(self, node: Node, visit_action): + def _dfs_visit(self, node: Node[T], visit_action): if node.status == "VISITED": return node.status = "VISITING" - for neighbour_node in node.dependencies: + for neighbour_node in node.neighbors: if neighbour_node.status == "VISITING": - raise RecursionError("Cycle detected") + raise ClickException("Cycle detected") self._dfs_visit(neighbour_node, visit_action) visit_action(node) diff --git a/src/snowflake/cli/api/utils/rendering.py b/src/snowflake/cli/api/utils/rendering.py index 15826d2ca8..90845eec75 100644 --- a/src/snowflake/cli/api/utils/rendering.py +++ b/src/snowflake/cli/api/utils/rendering.py @@ -1,22 +1,16 @@ from __future__ import annotations -import re -from collections import defaultdict, deque from pathlib import Path from textwrap import dedent -from typing import Dict, List, Optional, Set, cast +from typing import Dict, Optional import jinja2 from click import ClickException -from jinja2 import Environment, StrictUndefined, UndefinedError, loaders +from jinja2 import Environment, StrictUndefined, loaders from snowflake.cli.api.cli_global_context import cli_context -from snowflake.cli.api.project.schemas.project_definition import ( - ProjectDefinition, -) from snowflake.cli.api.secure_path import UNLIMITED, SecurePath -from snowflake.cli.api.utils.models import EnvironWithDefinedDictFallback -_CONTEXT_KEY = "ctx" +CONTEXT_KEY = "ctx" _YML_TEMPLATE_START = "<%" _YML_TEMPLATE_END = "%>" @@ -115,152 +109,13 @@ def jinja_render_from_file( return rendered_result -def _add_project_context(project_definition: ProjectDefinition) -> Dict: - """ - Updates the external data with variables from snowflake.yml definition file. - """ - context_data = _resolve_variables_in_project(project_definition) - return context_data - - -def _remove_ctx_env_prefix(text: str) -> str: - prefix = "ctx.env." - if text.startswith(prefix): - return text[len(prefix) :] - return text - - -def string_includes_template(text: str) -> bool: - return bool(re.search(rf"{_YML_TEMPLATE_START}.+{_YML_TEMPLATE_END}", text)) - - -def _resolve_variables_in_project(project_definition: ProjectDefinition): - # If there's project definition file then resolve variables from it - if not project_definition or not project_definition.meets_version_requirement( - "1.1" - ): - return {_CONTEXT_KEY: {"env": EnvironWithDefinedDictFallback({})}} - - variables_data: EnvironWithDefinedDictFallback = cast( - EnvironWithDefinedDictFallback, project_definition.env - ) - - env_with_unresolved_keys: List[str] = _get_variables_with_dependencies( - variables_data - ) - other_env = set(variables_data) - set(env_with_unresolved_keys) - env = get_snowflake_cli_jinja_env() - context_data = {_CONTEXT_KEY: project_definition} - - # Resolve env section dependencies - while env_with_unresolved_keys: - key = env_with_unresolved_keys.pop() - value = variables_data[key] - if not isinstance(value, str): - continue - try: # try to evaluate the template given current state of know variables - variables_data[key] = env.from_string(value).render(context_data) - if string_includes_template(variables_data[key]): - env_with_unresolved_keys.append(key) - except UndefinedError: - env_with_unresolved_keys.append(key) - - # Resolve templates in variables without references, for example - for key in other_env: - variable_value = variables_data[key] - if not isinstance(variable_value, str): - continue - variables_data[key] = env.from_string(variable_value).render(context_data) - - return context_data - - -def _check_for_cycles(nodes: defaultdict): - nodes = nodes.copy() - for key in list(nodes): - q = deque([key]) - visited: List[str] = [] - while q: - curr = q.popleft() - if curr in visited: - raise ClickException( - "Cycle detected between variables: {}".format(" -> ".join(visited)) - ) - # Only nodes that have references can cause cycles - if curr in nodes: - visited.append(curr) - q.extendleft(nodes[curr]) - - -def _get_variables_with_dependencies(variables_data: EnvironWithDefinedDictFallback): - """ - Checks consistency of provided dictionary by - 1. checking reference cycles - 2. checking for missing variables - """ - # Variables that are not specified in env section - missing_variables: Set[str] = set() - # Variables that require other variables - variables_with_dependencies = defaultdict(list) - - for key, value in variables_data.items(): - # Templates are reserved only to string variables - if not isinstance(value, str): - continue - - required_variables = _search_for_required_variables(value) - - if required_variables: - variables_with_dependencies[key] = required_variables - - for variable in required_variables: - if variable not in variables_data: - missing_variables.add(variable) - - # If there are unknown env variables then we raise an error - if missing_variables: - raise ClickException( - "The following variables are used in environment definition but are not defined: {}".format( - ", ".join(missing_variables) - ) - ) - - # Look for cycles between variables - _check_for_cycles(variables_with_dependencies) - - # Sort by number of dependencies - return sorted( - list(variables_with_dependencies.keys()), - key=lambda k: len(variables_with_dependencies[k]), - ) - - -def _search_for_required_variables(variable_value: str): - """ - Look for pattern in variable value. Returns a list of env variables required - to expand this template.` - """ - ctx_env_prefix = f"{_CONTEXT_KEY}.env." - found_variables = re.findall( - rf"({_YML_TEMPLATE_START}([\.\w ]+){_YML_TEMPLATE_END})+", variable_value - ) - required_variables = [] - for _, variable in found_variables: - var: str = variable.strip() - if var.startswith(ctx_env_prefix): - required_variables.append(var[len(ctx_env_prefix) :]) - return required_variables - - def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str: data = data or {} - if _CONTEXT_KEY in data: + if CONTEXT_KEY in data: raise ClickException( - f"{_CONTEXT_KEY} in user defined data. The `{_CONTEXT_KEY}` variable is reserved for CLI usage." + f"{CONTEXT_KEY} in user defined data. The `{CONTEXT_KEY}` variable is reserved for CLI usage." ) - context_data = _add_project_context( - project_definition=cli_context.project_definition - ) + context_data = cli_context.template_context context_data.update(data) return get_sql_cli_jinja_env().from_string(content).render(**context_data) diff --git a/tests/api/utils/test_definition_rendering.py b/tests/api/utils/test_definition_rendering.py new file mode 100644 index 0000000000..a17aaca4b2 --- /dev/null +++ b/tests/api/utils/test_definition_rendering.py @@ -0,0 +1,303 @@ +import os +from pathlib import Path +from tempfile import NamedTemporaryFile +from textwrap import dedent +from unittest import mock + +import pytest +from click import ClickException +from jinja2 import UndefinedError +from snowflake.cli.api.project.definition import load_project +from snowflake.cli.api.utils.definition_rendering import render_definition_template + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_resolve_variables_in_project_no_cross_variable_dependencies(): + definition = { + "definition_version": "1.1", + "env": { + "number": 1, + "string": "foo", + "boolean": True, + }, + } + + result = render_definition_template(definition) + + assert result == { + "definition_version": "1.1", + "env": {"number": 1, "string": "foo", "boolean": True}, + } + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_resolve_variables_in_project_cross_variable_dependencies(): + definition = { + "definition_version": "1.1", + "env": { + "A": 42, + "B": "b=<% ctx.env.A %>", + "C": "<% ctx.env.B %> and <% ctx.env.A %>", + }, + } + result = render_definition_template(definition) + + assert result == { + "definition_version": "1.1", + "env": {"A": 42, "B": "b=42", "C": "b=42 and 42"}, + } + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_no_resolve_in_version_1_0(): + definition = { + "definition_version": "1.0", + "env": { + "A": 42, + "B": "b=<% ctx.env.A %>", + "C": "<% ctx.env.B %> and <% ctx.env.A %>", + }, + } + result = render_definition_template(definition) + + assert result == { + "definition_version": "1.0", + "env": { + "A": 42, + "B": "b=<% ctx.env.A %>", + "C": "<% ctx.env.B %> and <% ctx.env.A %>", + }, + } + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_resolve_variables_in_project_cross_project_dependencies(): + definition = { + "definition_version": "1.1", + "streamlit": {"name": "my_app"}, + "env": {"app": "name of streamlit is <% ctx.streamlit.name %>"}, + } + result = render_definition_template(definition) + assert result == { + "definition_version": "1.1", + "streamlit": {"name": "my_app"}, + "env": { + "app": "name of streamlit is my_app", + }, + } + + +@mock.patch.dict( + os.environ, + { + "lowercase": "new_lowercase_value", + "UPPERCASE": "new_uppercase_value", + "should_be_replace_by_env": "test succeeded", + "value_from_env": "this comes from os.environ", + }, + clear=True, +) +def test_resolve_variables_in_project_environment_variables_precedence(): + definition = { + "definition_version": "1.1", + "env": { + "should_be_replace_by_env": "test failed", + "test_variable": "<% ctx.env.lowercase %> and <% ctx.env.UPPERCASE %>", + "test_variable_2": "<% ctx.env.value_from_env %>", + }, + } + result = render_definition_template(definition) + + assert result == { + "definition_version": "1.1", + "env": { + "UPPERCASE": "new_uppercase_value", + "lowercase": "new_lowercase_value", + "should_be_replace_by_env": "test succeeded", + "test_variable": "new_lowercase_value and new_uppercase_value", + "test_variable_2": "this comes from os.environ", + "value_from_env": "this comes from os.environ", + }, + } + + +@mock.patch.dict(os.environ, {"env_var": "<% ctx.definition_version %>"}, clear=True) +def test_env_variables_do_not_get_resolved(): + definition = { + "definition_version": "1.1", + "native_app": { + "name": "test_source_<% ctx.env.env_var %>", + }, + "env": { + "reference_to_name": "<% ctx.native_app.name %>", + }, + } + result = render_definition_template(definition) + + assert result == { + "definition_version": "1.1", + "native_app": { + "name": "test_source_<% ctx.definition_version %>", + }, + "env": { + "reference_to_name": "test_source_<% ctx.definition_version %>", + "env_var": "<% ctx.definition_version %>", + }, + } + + +@pytest.mark.parametrize( + "definition", + [ + {"definition_version": "1.1", "env": {"A": "<% ctx.env.A %>"}}, + { + "definition_version": "1.1", + "env": {"A": "<% ctx.env.B %>", "B": "<% ctx.env.A %>"}, + }, + { + "definition_version": "1.1", + "env": { + "A": "<% ctx.env.B %>", + "B": "<% ctx.env.C %>", + "C": "<% ctx.env.D %>", + "D": "<% ctx.env.A %>", + }, + }, + { + "definition_version": "1.1", + "native_app": {"name": "test_<% ctx.env.A %>"}, + "env": {"A": "<% ctx.native_app.name %>"}, + }, + { + "definition_version": "1.1", + "native_app": {"name": "test_<% ctx.native_app.name %>"}, + }, + { + "definition_version": "1.1", + "native_app": { + "name": "test_<% ctx.native_app.source_stage %>", + "source_stage": "stage <% ctx.native_app.name %>", + }, + }, + ], +) +def test_resolve_variables_error_on_cycle(definition): + with pytest.raises(ClickException) as err: + render_definition_template(definition) + + assert err.value.message == f"Cycle detected" + + +@pytest.mark.parametrize( + "definition, error_var", + [ + ( + { + "definition_version": "1.1", + "native_app": { + "name": "app_name", + "artifacts": [{"src": "src/add.py", "dest": "add.py"}], + }, + "env": {"A": "<% ctx.native_app.artifacts %>"}, + }, + "ctx.native_app.artifacts", + ), + ( + { + "definition_version": "1.1", + "native_app": { + "name": "app_name", + "artifacts": [{"src": "src/add.py", "dest": "add.py"}], + }, + "env": {"A": "<% ctx.native_app %>"}, + }, + "ctx.native_app", + ), + ], +) +def test_resolve_variables_reference_non_scalar(definition, error_var): + with pytest.raises(UndefinedError) as err: + render_definition_template(definition) + + assert ( + err.value.message + == f"Template variable {error_var} does not contain a valid value" + ) + + +@mock.patch.dict(os.environ, {"blank_env": ""}, clear=True) +def test_resolve_variables_blank_is_ok(): + definition = { + "definition_version": "1.1", + "native_app": { + "name": "<% ctx.env.blank_default_env %>", + "source_stage": "", + "deploy_root": "<% ctx.env.blank_env %>", + }, + "env": { + "blank_default_env": "", + "refers_to_blank": "<% ctx.native_app.source_stage %>", + }, + } + result = render_definition_template(definition) + + assert result == { + "definition_version": "1.1", + "native_app": {"name": "", "source_stage": "", "deploy_root": ""}, + "env": { + "blank_env": "", + "blank_default_env": "", + "refers_to_blank": "", + }, + } + + +@pytest.mark.parametrize( + "env, msg", + [ + ({"app": "<% bdbdbd %>"}, "Could not find template variable bdbdbd"), + ( + {"app": "<% ctx.streamlit.name %>"}, + "Could not find template variable ctx.streamlit.name", + ), + ({"app": "<% ctx.foo %>"}, "Could not find template variable ctx.foo"), + ], +) +def test_resolve_variables_fails_if_referencing_unknown_variable(env, msg): + definition = { + "definition_version": "1.1", + "env": env, + } + with pytest.raises(UndefinedError) as err: + render_definition_template(definition) + assert msg in str(err.value) + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_unquoted_template_usage_in_strings_yaml(): + text = """\ + definition_version: "1.1" + env: + value: "Snowflake is great!" + single_line: <% ctx.env.value %> + flow_multiline_quoted: "this is + multiline string with template <% ctx.env.value %>" + flow_multiline_not_quoted: this is + multiline string with template <% ctx.env.value %> + block_multiline: | + this is multiline string + with template <% ctx.env.value %> + """ + + with NamedTemporaryFile(suffix=".yml") as file: + p = Path(file.name) + p.write_text(dedent(text)) + project_definition = load_project([p]).project_definition + + assert project_definition.env == { + "block_multiline": "this is multiline string \nwith template Snowflake is great!\n", + "flow_multiline_not_quoted": "this is multiline string with template Snowflake is great!", + "flow_multiline_quoted": "this is multiline string with template Snowflake is great!", + "single_line": "Snowflake is great!", + "value": "Snowflake is great!", + } diff --git a/tests/api/utils/test_dict_utils.py b/tests/api/utils/test_dict_utils.py new file mode 100644 index 0000000000..0a8eba0c0f --- /dev/null +++ b/tests/api/utils/test_dict_utils.py @@ -0,0 +1,101 @@ +from snowflake.cli.api.utils.dict_utils import deep_merge_dicts, deep_traverse + + +def test_merge_dicts_empty(): + test_dict = {} + deep_merge_dicts(test_dict, {}) + assert test_dict == {} + + +def test_merge_dicts_recursive_map(): + test_dict = {"a": "a1", "b": "b1", "c": {"d": "d1", "e": "e1"}} + deep_merge_dicts(test_dict, {"c": {"d": "d2", "a": "a2"}}) + assert test_dict == {"a": "a1", "b": "b1", "c": {"d": "d2", "e": "e1", "a": "a2"}} + + +def test_merge_dicts_recursive_scalar_replace_map(): + test_dict = {"a": "a1", "b": "b1", "c": {"d": "d1", "e": "e1"}} + deep_merge_dicts(test_dict, {"c": "c2"}) + assert test_dict == {"a": "a1", "b": "b1", "c": "c2"} + + +def test_merge_dicts_recursive_two_arrays(): + test_dict = {"a": "a1", "b": {"c": ["c1", "c2", "c3"], "d": "d1"}} + deep_merge_dicts(test_dict, {"b": {"c": ["c4", "c5"]}}) + assert test_dict == {"a": "a1", "b": {"c": ["c4", "c5"], "d": "d1"}} + + +def test_deep_traverse_on_map(): + test_struct = { + "scalar_key": "hello", + "map_key": {"key1": "value1", "key2": "value2", "key3": True}, + "array_key": ["array1", "array2", 333, {"nestedKey1": "nestedVal1"}], + } + + visited_elements = [] + + def visit_action(element): + visited_elements.append(element) + + deep_traverse(test_struct, visit_action) + + assert visited_elements == [ + "hello", + "value1", + "value2", + True, + "array1", + "array2", + 333, + "nestedVal1", + ] + + +def test_deep_traverse_on_list(): + test_struct = ["val1", 123, False, {"mapKey1": "mapVal1"}] + + visited_elements = [] + + def visit_action(element): + visited_elements.append(element) + + deep_traverse(test_struct, visit_action) + + assert visited_elements == ["val1", 123, False, "mapVal1"] + + +def test_deep_traverse_on_scalar(): + test_struct = 444 + + visited_elements = [] + + def visit_action(element): + visited_elements.append(element) + + deep_traverse(test_struct, visit_action) + + assert visited_elements == [444] + + +def test_deep_traverse_with_updates(): + test_struct = { + "scalar_key": "hello", + "map_key": {"key1": "value1", "key3": True}, + "array_key": ["array1", 333, {"nestedKey1": "nestedVal1"}], + } + + visited_elements = [] + + def update_action(element): + if isinstance(element, str): + return element + "_" + else: + return None + + deep_traverse(test_struct, update_action=update_action) + + assert test_struct == { + "scalar_key": "hello_", + "map_key": {"key1": "value1_", "key3": None}, + "array_key": ["array1_", None, {"nestedKey1": "nestedVal1_"}], + } diff --git a/tests/api/utils/test_graph.py b/tests/api/utils/test_graph.py new file mode 100644 index 0000000000..9b4ac67e8e --- /dev/null +++ b/tests/api/utils/test_graph.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import pytest +from snowflake.cli.api.utils.graph import Graph, Node + + +@pytest.fixture +def nodes() -> list[Node]: + nodes = [] + for i in range(5): + nodes.append(Node(key=i, data=str(i))) + return nodes + + +def test_create_new_graph(nodes: list[Node]): + graph = Graph() + graph.add(nodes[0]) + graph.add(nodes[1]) + assert graph.get_all_nodes() == set([nodes[0], nodes[1]]) + + +def test_add_edges(nodes: list[Node]): + graph = Graph() + graph.add(nodes[0]) + graph.add(nodes[1]) + graph.add_directed_edge(nodes[0].key, nodes[1].key) + + assert nodes[0].neighbors == set([nodes[1]]) + + +def test_dfs(nodes: list[Node]): + graph = Graph() + for i in range(5): + graph.add(nodes[i]) + + graph.add_directed_edge(nodes[0].key, nodes[1].key) + graph.add_directed_edge(nodes[1].key, nodes[2].key) + + graph.add_directed_edge(nodes[0].key, nodes[3].key) + graph.add_directed_edge(nodes[3].key, nodes[4].key) + + visits: list[Node] = [] + + def track_visits_order(node: Node): + visits.append(node) + + graph.dfs(visit_action=track_visits_order) + + assert visits == [nodes[4], nodes[3], nodes[2], nodes[1], nodes[0]] or visits == [ + nodes[2], + nodes[1], + nodes[4], + nodes[3], + nodes[0], + ] diff --git a/tests/api/utils/test_rendering.py b/tests/api/utils/test_rendering.py index ab9b242c0c..2a6aa45652 100644 --- a/tests/api/utils/test_rendering.py +++ b/tests/api/utils/test_rendering.py @@ -1,24 +1,19 @@ import os -from pathlib import Path -from tempfile import NamedTemporaryFile -from textwrap import dedent from unittest import mock import pytest -from click import ClickException from jinja2 import UndefinedError -from snowflake.cli.api.project.definition import load_project_definition -from snowflake.cli.api.project.schemas.project_definition import ( - ProjectDefinition, -) -from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit -from snowflake.cli.api.utils.rendering import ( - _add_project_context, - snowflake_sql_jinja_render, -) +from snowflake.cli.api.utils.rendering import snowflake_sql_jinja_render + +@pytest.fixture +def cli_context(): + with mock.patch("snowflake.cli.api.utils.rendering.cli_context") as cli_context: + cli_context.template_context = {"ctx": {"env": os.environ}} + yield cli_context -def test_rendering_with_data(): + +def test_rendering_with_data(cli_context): assert snowflake_sql_jinja_render("&{ foo }", data={"foo": "bar"}) == "bar" @@ -35,7 +30,7 @@ def test_rendering_with_data(): ("$&{ foo }", "$bar"), ], ) -def test_rendering(text, output): +def test_rendering(text, output, cli_context): assert snowflake_sql_jinja_render(text, data={"foo": "bar"}) == output @@ -55,11 +50,11 @@ def test_rendering(text, output): """, ], ) -def test_that_common_logic_block_are_ignored(text): +def test_that_common_logic_block_are_ignored(text, cli_context): assert snowflake_sql_jinja_render(text) == text -def test_that_common_comments_are_respected(): +def test_that_common_comments_are_respected(cli_context): # Make sure comment are ignored assert snowflake_sql_jinja_render("{# note a comment &{ foo } #}") == "" # Make sure comment's work together with templates @@ -69,188 +64,10 @@ def test_that_common_comments_are_respected(): ) -def test_that_undefined_variables_raise_error(): +def test_that_undefined_variables_raise_error(cli_context): with pytest.raises(UndefinedError): snowflake_sql_jinja_render("&{ foo }") -def test_contex_can_access_environment_variable(): +def test_contex_can_access_environment_variable(cli_context): assert snowflake_sql_jinja_render("&{ ctx.env.USER }") == os.environ.get("USER") - - -def test_resolve_variables_in_project_no_cross_variable_dependencies(): - pdf = ProjectDefinition( - definition_version="1.1", - env={ - "number": 1, - "string": "foo", - "boolean": True, - }, - ) - result = _add_project_context(project_definition=pdf) - assert result == { - "ctx": ProjectDefinition( - definition_version="1.1", - native_app=None, - snowpark=None, - streamlit=None, - env={"number": 1, "string": "foo", "boolean": True}, - ) - } - - -def test_resolve_variables_in_project_cross_variable_dependencies(): - pdf = ProjectDefinition( - definition_version="1.1", - env={ - "A": 42, - "B": "b=<% ctx.env.A %>", - "C": "<% ctx.env.B %> and <% ctx.env.A %>", - }, - ) - result = _add_project_context(project_definition=pdf) - assert result == { - "ctx": ProjectDefinition( - definition_version="1.1", - native_app=None, - snowpark=None, - streamlit=None, - env={"A": 42, "B": "b=42", "C": "b=42 and 42"}, - ) - } - - -def test_resolve_variables_in_project_cross_project_dependencies(): - pdf = ProjectDefinition( - definition_version="1.1", - streamlit=Streamlit(name="my_app"), - env={"app": "name of streamlit is <% ctx.streamlit.name %>"}, - ) - result = _add_project_context(project_definition=pdf) - assert result == { - "ctx": ProjectDefinition( - definition_version="1.1", - native_app=None, - snowpark=None, - streamlit=Streamlit( - name="my_app", - stage="streamlit", - query_warehouse="streamlit", - main_file="streamlit_app.py", - env_file=None, - pages_dir=None, - additional_source_files=None, - ), - env={"app": "name of streamlit is my_app"}, - ) - } - - -@mock.patch.dict( - os.environ, - { - "lowercase": "new_lowercase_value", - "UPPERCASE": "new_uppercase_value", - "should_be_replace_by_env": "test succeeded", - "value_from_env": "this comes from os.environ", - }, -) -def test_resolve_variables_in_project_environment_variables_precedence(): - pdf = ProjectDefinition( - definition_version="1.1", - env={ - "should_be_replace_by_env": "test failed", - "test_variable": "<% ctx.env.lowercase %> and <% ctx.env.UPPERCASE %>", - "test_variable_2": "<% ctx.env.value_from_env %>", - }, - ) - result = _add_project_context(project_definition=pdf) - - assert result == { - "ctx": ProjectDefinition( - definition_version="1.1", - native_app=None, - snowpark=None, - streamlit=None, - env={ - "should_be_replace_by_env": "test succeeded", - "test_variable": "new_lowercase_value and new_uppercase_value", - "test_variable_2": "this comes from os.environ", - }, - ) - } - - -@pytest.mark.parametrize( - "env, cycle", - [ - ({"A": "<% ctx.env.A %>"}, "A"), - ({"A": "<% ctx.env.B %>", "B": "<% ctx.env.A %>"}, "A -> B"), - ( - { - "A": "<% ctx.env.B %>", - "B": "<% ctx.env.C %>", - "C": "<% ctx.env.D %>", - "D": "<% ctx.env.A %>", - }, - "A -> B -> C -> D", - ), - ], -) -def test_resolve_variables_error_on_cycle(env, cycle): - pdf = ProjectDefinition( - definition_version="1.1", - env=env, - ) - with pytest.raises(ClickException) as err: - _add_project_context(project_definition=pdf) - - assert err.value.message == f"Cycle detected between variables: {cycle}" - - -@pytest.mark.parametrize( - "env, msg", - [ - ({"app": "<% bdbdbd %>"}, "'bdbdbd' is undefined"), - ({"app": "<% ctx.streamlit.name %>"}, "'None' has no attribute 'name'"), - ({"app": "<% ctx.foo %>"}, "has no attribute 'foo'"), - ], -) -def test_resolve_variables_fails_if_referencing_unknown_variable(env, msg): - pdf = ProjectDefinition( - definition_version="1.1", - env=env, - ) - with pytest.raises(UndefinedError) as err: - _add_project_context(project_definition=pdf) - assert msg in str(err.value) - - -def tests_unquoted_template_usage_in_strings_yaml(): - text = """\ - definition_version: "1.1" - env: - value: "Snowflake is great!" - single_line: <% ctx.env.value %> - flow_multiline_quoted: "this is - multiline string with template <% ctx.env.value %>" - flow_multiline_not_quoted: this is - multiline string with template <% ctx.env.value %> - block_multiline: | - this is multiline string - with template <% ctx.env.value %> - """ - - with NamedTemporaryFile(suffix=".yml") as file: - p = Path(file.name) - p.write_text(dedent(text)) - definition = load_project_definition([p]) - - _add_project_context(project_definition=definition) - assert definition.env == { - "block_multiline": "this is multiline string \nwith template Snowflake is great!\n", - "flow_multiline_not_quoted": "this is multiline string with template Snowflake is great!", - "flow_multiline_quoted": "this is multiline string with template Snowflake is great!", - "single_line": "Snowflake is great!", - "value": "Snowflake is great!", - } diff --git a/tests/nativeapp/test_annotation_processor_config.py b/tests/nativeapp/test_annotation_processor_config.py index 2ef81c5ba7..5eaed5aa4a 100644 --- a/tests/nativeapp/test_annotation_processor_config.py +++ b/tests/nativeapp/test_annotation_processor_config.py @@ -1,7 +1,5 @@ import pytest -from snowflake.cli.api.project.definition import ( - load_project_definition, -) +from snowflake.cli.api.project.definition import load_project from snowflake.cli.api.project.schemas.native_app.path_mapping import ProcessorMapping @@ -9,7 +7,7 @@ "project_definition_files", ["napp_with_annotation_processor"], indirect=True ) def test_napp_project_with_annotation_processor(project_definition_files): - project = load_project_definition(project_definition_files) + project = load_project(project_definition_files).project_definition assert len(project.native_app.artifacts) == 3 result = project.native_app.artifacts[2] diff --git a/tests/nativeapp/test_artifacts.py b/tests/nativeapp/test_artifacts.py index 5b56af650e..bace344499 100644 --- a/tests/nativeapp/test_artifacts.py +++ b/tests/nativeapp/test_artifacts.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Union import pytest -from snowflake.cli.api.project.definition import load_project_definition +from snowflake.cli.api.project.definition import load_project from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.plugins.nativeapp.artifacts import ( ArtifactError, @@ -702,7 +702,7 @@ def test_bundle_map_ignores_sources_in_deploy_root(bundle_map): @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_napp_project_1_artifacts(project_definition_files, snapshot): project_root = project_definition_files[0].parent - native_app = load_project_definition(project_definition_files).native_app + native_app = load_project(project_definition_files).project_definition.native_app with pushd(project_root) as local_path: deploy_root = Path(local_path, native_app.deploy_root) diff --git a/tests/project/fixtures.py b/tests/project/fixtures.py index d0487dce54..3147edfba9 100644 --- a/tests/project/fixtures.py +++ b/tests/project/fixtures.py @@ -49,7 +49,7 @@ def project_definition_files(request) -> Generator[List[Path], None, None]: Expects indirect parameterization, e.g. @pytest.mark.parametrize("project_definition_files", ["project_1"], indirect=True) def test_my_project(project_definition_files): - project = load_project_definition(project_definition_files) + project = load_project(project_definition_files).project_definition """ dir_name = request.param with snowflake_ymls(dir_name) as ymls: diff --git a/tests/project/test_config.py b/tests/project/test_config.py index 79a43748d2..18783a48a3 100644 --- a/tests/project/test_config.py +++ b/tests/project/test_config.py @@ -8,7 +8,7 @@ import pytest from snowflake.cli.api.project.definition import ( generate_local_override_yml, - load_project_definition, + load_project, ) from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping @@ -17,7 +17,7 @@ @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_napp_project_1(project_definition_files): - project = load_project_definition(project_definition_files) + project = load_project(project_definition_files).project_definition assert project.native_app.name == "myapp" assert project.native_app.deploy_root == "output/deploy/" assert project.native_app.package.role == "accountadmin" @@ -28,7 +28,7 @@ def test_napp_project_1(project_definition_files): @pytest.mark.parametrize("project_definition_files", ["minimal"], indirect=True) def test_na_minimal_project(project_definition_files: List[Path]): - project = load_project_definition(project_definition_files) + project = load_project(project_definition_files).project_definition assert project.native_app.name == "minimal" assert project.native_app.artifacts == [ PathMapping(src="setup.sql"), @@ -64,7 +64,7 @@ def mock_getenv(key: str, default: Optional[str] = None) -> Optional[str]: @pytest.mark.parametrize("project_definition_files", ["underspecified"], indirect=True) def test_underspecified_project(project_definition_files): with pytest.raises(SchemaValidationError) as exc_info: - load_project_definition(project_definition_files) + load_project(project_definition_files).project_definition assert "NativeApp" in str(exc_info) assert "Your project definition is missing following fields: ('artifacts',)" in str( @@ -77,7 +77,7 @@ def test_underspecified_project(project_definition_files): ) def test_fails_without_definition_version(project_definition_files): with pytest.raises(SchemaValidationError) as exc_info: - load_project_definition(project_definition_files) + load_project(project_definition_files).project_definition assert "ProjectDefinition" in str(exc_info) assert ( @@ -89,7 +89,7 @@ def test_fails_without_definition_version(project_definition_files): @pytest.mark.parametrize("project_definition_files", ["unknown_fields"], indirect=True) def test_does_not_accept_unknown_fields(project_definition_files): with pytest.raises(SchemaValidationError) as exc_info: - load_project_definition(project_definition_files) + load_project(project_definition_files).project_definition assert "NativeApp" in str(exc_info) assert ( @@ -120,7 +120,7 @@ def test_does_not_accept_unknown_fields(project_definition_files): indirect=True, ) def test_fields_are_parsed_correctly(project_definition_files, snapshot): - result = load_project_definition(project_definition_files).model_dump() + result = load_project(project_definition_files).project_definition.model_dump() assert result == snapshot diff --git a/tests/streamlit/test_config.py b/tests/streamlit/test_config.py index 86d51c1b7a..724a00f731 100644 --- a/tests/streamlit/test_config.py +++ b/tests/streamlit/test_config.py @@ -2,7 +2,7 @@ import pytest -from src.snowflake.cli.api.project.definition import load_project_definition +from src.snowflake.cli.api.project.definition import load_project TEST_DATA = Path(__file__).parent.parent / "test_data" / "streamlit" FILE_WITH_LONG_LIST = TEST_DATA / "with_list_in_source_file.yml" @@ -28,6 +28,6 @@ ) def test_load_project_definition(test_files, expected): - result = load_project_definition(test_files) + result = load_project(test_files).project_definition assert expected in result.streamlit.additional_source_files