diff --git a/src/tensora/desugar/to_iteration_graphs.py b/src/tensora/desugar/to_iteration_graphs.py index ba06dcb..18b6a40 100644 --- a/src/tensora/desugar/to_iteration_graphs.py +++ b/src/tensora/desugar/to_iteration_graphs.py @@ -53,7 +53,7 @@ def to_iteration_graphs_integer( formats: dict[str, Format], counter: Iterator[int], ) -> Iterator[ig.IterationGraph]: - yield ig.TerminalExpression(id.Integer(self.value)) + yield ig.TerminalNode(id.Integer(self.value)) @to_iteration_graphs_expression.register(ast.Float) @@ -62,7 +62,7 @@ def to_iteration_graphs_float( formats: dict[str, Format], counter: Iterator[int], ) -> Iterator[ig.IterationGraph]: - yield ig.TerminalExpression(id.Float(self.value)) + yield ig.TerminalNode(id.Float(self.value)) @to_iteration_graphs_expression.register(ast.Scalar) @@ -71,7 +71,7 @@ def to_iteration_graphs_scalar( formats: dict[str, Format], counter: Iterator[int], ) -> Iterator[ig.IterationGraph]: - yield ig.TerminalExpression(id.Scalar(id.TensorLeaf(self.name, self.id))) + yield ig.TerminalNode(id.Scalar(id.TensorLeaf(self.name, self.id))) @to_iteration_graphs_expression.register(ast.Tensor) @@ -90,13 +90,13 @@ def to_iteration_graphs_tensor( raise DiagonalAccessError(self) for index_order in legal_iteration_orders(format): - graph = ig.TerminalExpression( + graph = ig.TerminalNode( id.Tensor(id.TensorLeaf(self.name, self.id), index_variables, modes) ) # Build iteration order bottom up for i_index in reversed(index_order): index_variable = index_variables[i_index] - graph = ig.IterationVariable( + graph = ig.IterationNode( index_variable, None, next=graph, @@ -104,7 +104,7 @@ def to_iteration_graphs_tensor( yield graph -def simplify_add(graph: ig.Sum) -> Iterator[ig.IterationGraph]: +def simplify_add(graph: ig.SumNode) -> Iterator[ig.IterationGraph]: """Simplify an Add by combining terms with the same index variable. An Add node can be simplified if all its terms are `IterationVariable`s with @@ -116,7 +116,7 @@ def simplify_add(graph: ig.Sum) -> Iterator[ig.IterationGraph]: # This could yield all the intermediate graphs, but the last one might # always be the most efficient. - if all(isinstance(term, ig.IterationVariable) for term in graph.terms): + if all(isinstance(term, ig.IterationNode) for term in graph.terms): unique_terms = {term.index_variable for term in graph.terms} if len(unique_terms) == 1: head = graph.terms[0] @@ -124,18 +124,18 @@ def simplify_add(graph: ig.Sum) -> Iterator[ig.IterationGraph]: # Flatten nested Adds terms = [] for term in graph.terms: - if isinstance(term.next, ig.Sum): + if isinstance(term.next, ig.SumNode): terms.extend(term.next.terms) else: terms.append(term.next) - for next_graph in simplify_add(ig.Sum(graph.name, terms)): + for next_graph in simplify_add(ig.SumNode(graph.name, terms)): yield replace(head, next=next_graph) else: yield graph - elif all(isinstance(term, ig.TerminalExpression) for term in graph.terms): + elif all(isinstance(term, ig.TerminalNode) for term in graph.terms): expression = reduce(id.Add, [term.expression for term in graph.terms]) - yield ig.TerminalExpression(expression) + yield ig.TerminalNode(expression) else: if not any(term.has_output() for term in graph.terms): # Output iterations are not permitted downstream of an Add @@ -153,14 +153,14 @@ def to_iteration_graphs_add( for right in to_iteration_graphs_expression(self.right, formats, counter): # Always simplify Add within Add match (left, right): - case (ig.Sum(), ig.Sum()): - graph = ig.Sum(name, [*left.terms, *right.terms]) - case (ig.Sum(), _): - graph = ig.Sum(name, [*left.terms, right]) - case (_, ig.Sum()): - graph = ig.Sum(name, [left, *right.terms]) + case (ig.SumNode(), ig.SumNode()): + graph = ig.SumNode(name, [*left.terms, *right.terms]) + case (ig.SumNode(), _): + graph = ig.SumNode(name, [*left.terms, right]) + case (_, ig.SumNode()): + graph = ig.SumNode(name, [left, *right.terms]) case (_, _): - graph = ig.Sum(name, [left, right]) + graph = ig.SumNode(name, [left, right]) yield from simplify_add(graph) @@ -169,15 +169,15 @@ def merge_multiply( left: ig.IterationGraph, right: ig.IterationGraph ) -> Iterator[ig.IterationGraph]: match (left, right): - case (ig.TerminalExpression(), ig.TerminalExpression()): - yield ig.TerminalExpression(id.Multiply(left.expression, right.expression)) - case (ig.IterationVariable(), ig.TerminalExpression()): + case (ig.TerminalNode(), ig.TerminalNode()): + yield ig.TerminalNode(id.Multiply(left.expression, right.expression)) + case (ig.IterationNode(), ig.TerminalNode()): for tail in merge_multiply(left.next, right): yield replace(left, next=tail) - case (ig.TerminalExpression(), ig.IterationVariable()): + case (ig.TerminalNode(), ig.IterationNode()): for tail in merge_multiply(left, right.next): yield replace(right, next=tail) - case (ig.IterationVariable(), ig.IterationVariable()): + case (ig.IterationNode(), ig.IterationNode()): if left.index_variable == right.index_variable: for tail in merge_multiply(left.next, right.next): yield replace(left, next=tail) @@ -189,12 +189,12 @@ def merge_multiply( if right.index_variable not in left.next.later_indexes(): for tail in merge_multiply(left, right.next): yield replace(right, next=tail) - case (ig.Sum(), _): + case (ig.SumNode(), _): for terms in product(*[merge_multiply(term, right) for term in left.terms]): - yield ig.Sum(left.name, list(terms)) - case (_, ig.Sum()): + yield ig.SumNode(left.name, list(terms)) + case (_, ig.SumNode()): for terms in product(*[merge_multiply(left, term) for term in right.terms]): - yield ig.Sum(right.name, list(terms)) + yield ig.SumNode(right.name, list(terms)) @to_iteration_graphs_expression.register(ast.Multiply) @@ -221,13 +221,13 @@ def merge_assignment( target: ig.IterationGraph, expression: ig.IterationGraph, output_layers: dict[str, TensorLayer] ) -> Iterator[ig.IterationGraph]: match (target, expression): - case (ig.TerminalExpression(), _): + case (ig.TerminalNode(), _): yield expression - case (ig.IterationVariable(), ig.TerminalExpression()): + case (ig.IterationNode(), ig.TerminalNode()): output_leaf = output_layers[target.index_variable] for tail in merge_assignment(target.next, expression, output_layers): yield replace(target, output=output_leaf, next=tail) - case (ig.IterationVariable(), ig.IterationVariable()): + case (ig.IterationNode(), ig.IterationNode()): if target.index_variable == expression.index_variable: output_leaf = output_layers[target.index_variable] for tail in merge_assignment(target.next, expression.next, output_layers): @@ -241,11 +241,11 @@ def merge_assignment( if expression.index_variable not in target.next.later_indexes(): for tail in merge_assignment(target, expression.next, output_layers): yield replace(expression, next=tail) - case (ig.IterationVariable(), ig.Sum(name=name, terms=terms)): + case (ig.IterationNode(), ig.SumNode(name=name, terms=terms)): for merged_terms in product( *(merge_assignment(target, term, output_layers) for term in terms) ): - yield from simplify_add(ig.Sum(name, list(merged_terms))) + yield from simplify_add(ig.SumNode(name, list(merged_terms))) def to_iteration_graphs( diff --git a/src/tensora/iteration_graph/generate_ir.py b/src/tensora/iteration_graph/generate_ir.py index 609fff9..89ed115 100644 --- a/src/tensora/iteration_graph/generate_ir.py +++ b/src/tensora/iteration_graph/generate_ir.py @@ -24,7 +24,7 @@ from .definition import Definition from .identifiable_expression import to_ir from .identifiable_expression.tensor_layer import TensorLayer -from .iteration_graph import IterationGraph, IterationVariable, Sum, TerminalExpression +from .iteration_graph import IterationGraph, IterationNode, SumNode, TerminalNode from .names import crd_name, dimension_name, pos_name, vals_name from .outputs import AppendOutput, Output from .write_sparse_ir import ( @@ -44,8 +44,8 @@ def to_ir_iteration_graph( ) -@to_ir_iteration_graph.register(TerminalExpression) -def to_ir_terminal_expression(self: TerminalExpression, output: Output, kernel_type: KernelType): +@to_ir_iteration_graph.register(TerminalNode) +def to_ir_terminal_expression(self: TerminalNode, output: Output, kernel_type: KernelType): source = SourceBuilder("*** Computation of expression ***") if kernel_type.is_compute(): @@ -54,7 +54,7 @@ def to_ir_terminal_expression(self: TerminalExpression, output: Output, kernel_t return source -def generate_subgraphs(graph: IterationVariable) -> list[IterationVariable]: +def generate_subgraphs(graph: IterationNode) -> list[IterationNode]: # The 0th element is just the full graph # Each element is derived from a previous element by zeroing a tensor # Zeroing a tensor always results in a strictly simpler graph @@ -82,8 +82,8 @@ def generate_subgraphs(graph: IterationVariable) -> list[IterationVariable]: return list(all_subgraphs.values()) -@to_ir_iteration_graph.register(IterationVariable) -def to_ir_iteration_variable(self: IterationVariable, output: Output, kernel_type: KernelType): +@to_ir_iteration_graph.register(IterationNode) +def to_ir_iteration_variable(self: IterationNode, output: Output, kernel_type: KernelType): source = SourceBuilder(f"*** Iteration over {self.index_variable} ***") if not kernel_type.is_compute() and not self.has_assemble(): @@ -325,8 +325,8 @@ def to_ir_iteration_variable(self: IterationVariable, output: Output, kernel_typ return source -@to_ir_iteration_graph.register(Sum) -def to_ir_sum(self: Sum, output: Output, kernel_type: KernelType): +@to_ir_iteration_graph.register(SumNode) +def to_ir_sum(self: SumNode, output: Output, kernel_type: KernelType): source = SourceBuilder("*** Sum ***") if kernel_type.is_compute(): diff --git a/src/tensora/iteration_graph/iteration_graph.py b/src/tensora/iteration_graph/iteration_graph.py index 70fd59a..b3961b0 100644 --- a/src/tensora/iteration_graph/iteration_graph.py +++ b/src/tensora/iteration_graph/iteration_graph.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["IterationGraph", "Sum", "IterationVariable", "TerminalExpression"] +__all__ = ["IterationGraph", "TerminalNode", "IterationNode", "SumNode"] from abc import abstractmethod from dataclasses import dataclass, replace @@ -46,14 +46,14 @@ def has_output(self) -> bool: @dataclass(frozen=True) -class TerminalExpression(IterationGraph): +class TerminalNode(IterationGraph): expression: Expression def extract_context(self, index: str) -> Context: return extract_context(self.expression, index) def exhaust_tensor(self, tensor: TensorLeaf) -> IterationGraph: - return TerminalExpression(exhaust_tensor(self.expression, tensor)) + return TerminalNode(exhaust_tensor(self.expression, tensor)) def is_sparse_output(self) -> bool: return False @@ -70,7 +70,7 @@ def has_output(self) -> bool: @dataclass(frozen=True) -class IterationVariable(IterationGraph): +class IterationNode(IterationGraph): index_variable: str output: TensorLayer | None next: IterationGraph @@ -122,7 +122,7 @@ def has_assemble(self) -> bool: @dataclass(frozen=True) -class Sum(IterationGraph): +class SumNode(IterationGraph): name: str terms: list[IterationGraph] @@ -140,7 +140,7 @@ def exhaust_tensor(self, tensor: TensorLeaf) -> IterationGraph: new_terms.append(new_term) if len(new_terms) == 0: - return TerminalExpression(Integer(0)) + return TerminalNode(Integer(0)) elif len(new_terms) == 1: return new_terms[0] else: