Skip to content

Commit

Permalink
Update algorithm of templating and refactor Graph traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-melnacouzi committed May 30, 2024
1 parent 9cdc5ea commit d60f3c7
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 141 deletions.
27 changes: 22 additions & 5 deletions src/snowflake/cli/api/project/definition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import List
from typing import Any, Dict, List

import yaml.loader
from snowflake.cli.api.cli_global_context import cli_context
Expand All @@ -14,12 +14,29 @@
to_identifier,
)
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.api.utils.definition_rendering import render_project_template
from snowflake.cli.api.utils.definition_rendering import render_definition_template
from yaml import load

DEFAULT_USERNAME = "unknown_user"


def merge_two_dicts(
original_values: Dict[str, Any], update_values: Dict[str, Any]
): # TODO update name of function
if not isinstance(update_values, dict) or not isinstance(original_values, dict):
return

for field, value in update_values.items():
if (
field in original_values
and isinstance(original_values[field], dict)
and isinstance(value, dict)
):
merge_two_dicts(original_values[field], value)
else:
original_values[field] = value


def load_project_definition(paths: List[Path]) -> ProjectDefinition:
"""
Loads project definition, optionally overriding values. Definition values
Expand All @@ -31,16 +48,16 @@ def load_project_definition(paths: List[Path]) -> ProjectDefinition:

with spaths[0].open("r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as base_yml:
definition = load(base_yml.read(), Loader=yaml.loader.BaseLoader)
project = ProjectDefinition(**definition)

for override_path in spaths[1:]:
with override_path.open(
"r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB
) as override_yml:
overrides = load(override_yml.read(), Loader=yaml.loader.BaseLoader)
project.update_from_dict(overrides)
merge_two_dicts(definition, overrides)

rendered_project = render_project_template(project)
rendered_definition = render_definition_template(definition)
rendered_project = ProjectDefinition(**rendered_definition)
return rendered_project


Expand Down
240 changes: 106 additions & 134 deletions src/snowflake/cli/api/utils/definition_rendering.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,94 @@
from __future__ import annotations

import copy
import os
from collections import deque
from dataclasses import dataclass
from typing import List
from typing import Optional

import yaml
from jinja2 import Environment, UndefinedError, nodes
from snowflake.cli.api.project.schemas.project_definition import ProjectDefinition
from snowflake.cli.api.utils.graph import Graph
from snowflake.cli.api.utils.rendering import get_snowflake_cli_jinja_env


@dataclass
class GraphNode:
path: str
value: str | None = None
class Variable:
def __init__(self, vars_chain):
self._vars_chain = list(vars_chain)
self.value: str | None = None

def get_vars_hierarchy(self) -> deque[str]:
return deque(self._vars_chain)

def _get_referenced_vars(ast_node, attr_chain: List[str] = []) -> List[List[str]]:
def get_key(self):
return ".".join(self._vars_chain)

def __hash__(self):
return hash(self.get_key())

def __eq__(self, other):
return self.get_key() == other.get_key()


@dataclass(eq=False)
class TemplateGraphNode(Graph.Node):
templated_value: str | None = None
rendered_value: str | None = None
variable: Variable | None = None

def store_in_context(self, context):
# if graph node is ctx.env.test with value x, context will be built as {'ctx': {'env': {'test': 'x'}}}
vars_chain = self.variable.get_vars_hierarchy()
self._store_in_context_recursive(context, vars_chain)

def _store_in_context_recursive(self, context, vars_chain: deque):
if len(vars_chain) == 0:
return self.rendered_value

current_level_key = vars_chain.popleft()
current_value_in_key = (
context[current_level_key] if current_level_key in context else {}
)
context[current_level_key] = self._store_in_context_recursive(
current_value_in_key, vars_chain
)
return context


def _get_referenced_vars(
ast_node, variable_attr_chain: deque = deque()
) -> set[Variable]:
# Var nodes end in Name node, and starts with Getattr node.
# Example: ctx.env.test will look like Getattr(test) -> Getattr(env) -> Name(ctx)
all_referenced_vars = []
all_referenced_vars = set()

variable_appended = False
if isinstance(ast_node, nodes.Getattr):
attr_chain = [getattr(ast_node, "attr")] + attr_chain
variable_attr_chain.appendleft(getattr(ast_node, "attr"))
variable_appended = True
elif isinstance(ast_node, nodes.Name):
all_referenced_vars = [[getattr(ast_node, "name")] + attr_chain]
variable_attr_chain.appendleft(getattr(ast_node, "name"))
variable_appended = True
all_referenced_vars.add(Variable(variable_attr_chain))

for child_node in ast_node.iter_child_nodes():
all_referenced_vars = all_referenced_vars + _get_referenced_vars(
child_node, attr_chain
all_referenced_vars.update(
_get_referenced_vars(child_node, variable_attr_chain)
)

if variable_appended:
variable_attr_chain.popleft()

return all_referenced_vars


def _get_referenced_vars_from_str(env: Environment, template_str: str) -> List[str]:
def _get_referenced_vars_from_str(
env: Environment, template_str: Optional[str]
) -> set[Variable]:
if template_str == None:
return set()

ast = env.parse(template_str)
referenced_vars = _get_referenced_vars(ast)
result = [".".join(vars_parts) for vars_parts in referenced_vars]
return result
return _get_referenced_vars(ast)


def _get_value_from_var_path(env: Environment, context: dict, var_path: str):
Expand All @@ -48,144 +97,67 @@ def _get_value_from_var_path(env: Environment, context: dict, var_path: str):
ref_str = "<% " + var_path + " %>"
template = env.from_string(ref_str)
try:
return template.render(context)
except UndefinedError as e:
# check in env variables
return f"{env.block_start_string} raw {env.block_end_string}{template.render({'ctx': {'env': os.environ}})}{env.block_start_string} endraw {env.block_end_string}"
except UndefinedError as e:
try:
return f"{env.block_start_string} raw {env.block_end_string}{template.render({'ctx': {'env': os.environ}})}{env.block_start_string} endraw {env.block_end_string}"
return template.render(context)
except:
raise UndefinedError("Could not find template variable " + var_path)


def _find_node_with_no_dependencies(graph):
for node, deps in graph.items():
if len(deps) == 0:
return node
return None


def _remove_node_from_graph(graph, node):
del graph[node]
for n, deps in graph.items():
if node in deps:
deps.remove(node)


def _check_for_cycles(original_graph):
graph = copy.deepcopy(original_graph)
while len(graph) > 0:
node = _find_node_with_no_dependencies(graph)
if node == None:
raise RecursionError("Cycle detected in project definition file template")
_remove_node_from_graph(graph, node)


def _build_dependency_graph(env, referenced_vars, context_without_env):
dependencies_graph = {}
for node in referenced_vars:
dependencies_graph[node] = []

for node in referenced_vars:
value = _get_value_from_var_path(env, context_without_env, node)
depends_on = _get_referenced_vars_from_str(env, value)
for dependency in depends_on:
if not dependency in dependencies_graph:
raise RuntimeError(
f"unexpected dependency {dependency} not in {dependencies_graph.keys()}"
)

dependencies_graph[node] = depends_on
return dependencies_graph


def _fill_context_recursive(context, attrs, value):
if len(attrs) == 0:
return

if len(attrs) == 1:
context[attrs[0]] = value
return

if attrs[0] not in context:
context[attrs[0]] = {}

_fill_context_recursive(context[attrs[0]], attrs[1:], value)


def _fill_context(context_to_be_filled, graph_node: GraphNode):
# if graph node is ctx.env.test with value x, context will be built as {'ctx': {'env': {'test': 'x'}}}
value = graph_node.value
attrs = graph_node.path.split(".")
_fill_context_recursive(context_to_be_filled, attrs, value)

def _build_dependency_graph(env, all_vars: set[Variable], context_without_env) -> Graph:
dependencies_graph = Graph()
for variable in all_vars:
dependencies_graph.add(
TemplateGraphNode(key=variable.get_key(), variable=variable)
)

def _resolve_node_values(
jinja_env, dependencies_graph, node: GraphNode, graph_nodes_map, context_with_env
):
if node.value:
return node.value
dependencies = dependencies_graph[node.path]
for variable in all_vars:
node: TemplateGraphNode = dependencies_graph.get(key=variable.get_key())
node.templated_value = _get_value_from_var_path(
env, context_without_env, variable.get_key()
)
dependencies_vars = _get_referenced_vars_from_str(env, node.templated_value)

my_context: dict = {}
for dep in dependencies:
dep_node: GraphNode = graph_nodes_map[dep]
if not dep_node.value:
dep_node.value = _resolve_node_values(
jinja_env,
dependencies_graph,
dep_node,
graph_nodes_map,
context_with_env,
for referenced_var in dependencies_vars:
dependencies_graph.add_dependency(
variable.get_key(), referenced_var.get_key()
)
_fill_context(my_context, dep_node)
if len(dependencies) == 0:
my_context = context_with_env

value = _get_value_from_var_path(jinja_env, context_with_env, node.path)
template = jinja_env.from_string(value)
node.value = template.render(my_context)
return node.value
return dependencies_graph


def render_project_template(project):
def render_definition_template(definition):
jinja_env = get_snowflake_cli_jinja_env()
context_without_env = {"ctx": project.model_dump(exclude_unset=True)}
context_with_env = {"ctx": project}
combined_pdf_as_str = yaml.dump(project.model_dump(exclude_unset=True))
pdf_context = {"ctx": definition}
pdf_yaml_str = yaml.dump(definition)

referenced_vars = _get_referenced_vars_from_str(jinja_env, combined_pdf_as_str)
referenced_vars = _get_referenced_vars_from_str(jinja_env, pdf_yaml_str)

# build dependency graph without env variables, because these cannot reference other vars.
# env vars are used as backup as leaf nodes (just for graph purposes)
dependencies_graph = _build_dependency_graph(
jinja_env, referenced_vars, context_without_env
jinja_env, referenced_vars, pdf_context
)
_check_for_cycles(dependencies_graph)

# store extra node information (var_path to GraphNode object map)
graph_nodes_map = {}
for path in dependencies_graph.keys():
graph_nodes_map[path] = GraphNode(path=path)

# recursively resolve node values of the graph
for referenced_var in graph_nodes_map.values():
_resolve_node_values(
jinja_env,
dependencies_graph,
referenced_var,
graph_nodes_map,
context_with_env,
)

def evaluate_node(node: TemplateGraphNode):
current_context: dict = {}
dep_node: TemplateGraphNode
for dep_node in node.dependencies:
dep_node.store_in_context(current_context)

template = jinja_env.from_string(node.templated_value)
node.rendered_value = template.render(current_context)

dependencies_graph.dfs(visit_action=evaluate_node)

# now that we determined value of all referenced vars, use these resolved values as a fresh context to resolve project file later
final_context = {}
for referenced_var in graph_nodes_map.values():
_fill_context(final_context, referenced_var)
node: TemplateGraphNode
for node in dependencies_graph.get_all_nodes():
node.store_in_context(final_context)

# resolve combined project file based on final context
template = jinja_env.from_string(combined_pdf_as_str)
template = jinja_env.from_string(pdf_yaml_str)
rendered_template = template.render(final_context)
rendered_definition = yaml.load(rendered_template, Loader=yaml.loader.BaseLoader)
rendered_project = ProjectDefinition(**rendered_definition)

return rendered_project
return rendered_definition
Loading

0 comments on commit d60f3c7

Please sign in to comment.