Name all iteration graph nodes as Node
drhagen committed Nov 30, 2023
1 parent 715ee8c commit ebf7195
Showing 3 changed files with 46 additions and 46 deletions.
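For orientation: this commit is a pure rename of the three iteration graph node classes. TerminalExpression becomes TerminalNode, IterationVariable becomes IterationNode, and Sum becomes SumNode, with no behavioral change. Below is a minimal sketch of the renamed dataclasses; the fields are taken from the iteration_graph.py diff further down, the abstract IterationGraph base and all methods are omitted, and the placeholder annotations are illustrative only.

# Minimal sketch of the renamed node classes (fields from the iteration_graph.py
# diff below); the real classes subclass IterationGraph and define several methods.
from dataclasses import dataclass


@dataclass(frozen=True)
class TerminalNode:  # previously TerminalExpression
    expression: object  # Expression in the real code


@dataclass(frozen=True)
class IterationNode:  # previously IterationVariable
    index_variable: str
    output: object | None  # TensorLayer | None in the real code
    next: object  # IterationGraph in the real code


@dataclass(frozen=True)
class SumNode:  # previously Sum
    name: str
    terms: list  # list[IterationGraph] in the real code
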
64 changes: 32 additions & 32 deletions src/tensora/desugar/to_iteration_graphs.py
@@ -53,7 +53,7 @@ def to_iteration_graphs_integer(
     formats: dict[str, Format],
     counter: Iterator[int],
 ) -> Iterator[ig.IterationGraph]:
-    yield ig.TerminalExpression(id.Integer(self.value))
+    yield ig.TerminalNode(id.Integer(self.value))
 
 
 @to_iteration_graphs_expression.register(ast.Float)
@@ -62,7 +62,7 @@ def to_iteration_graphs_float(
     formats: dict[str, Format],
     counter: Iterator[int],
 ) -> Iterator[ig.IterationGraph]:
-    yield ig.TerminalExpression(id.Float(self.value))
+    yield ig.TerminalNode(id.Float(self.value))
 
 
 @to_iteration_graphs_expression.register(ast.Scalar)
@@ -71,7 +71,7 @@ def to_iteration_graphs_scalar(
     formats: dict[str, Format],
     counter: Iterator[int],
 ) -> Iterator[ig.IterationGraph]:
-    yield ig.TerminalExpression(id.Scalar(id.TensorLeaf(self.name, self.id)))
+    yield ig.TerminalNode(id.Scalar(id.TensorLeaf(self.name, self.id)))
 
 
 @to_iteration_graphs_expression.register(ast.Tensor)
@@ -90,21 +90,21 @@ def to_iteration_graphs_tensor(
         raise DiagonalAccessError(self)
 
     for index_order in legal_iteration_orders(format):
-        graph = ig.TerminalExpression(
+        graph = ig.TerminalNode(
             id.Tensor(id.TensorLeaf(self.name, self.id), index_variables, modes)
         )
         # Build iteration order bottom up
         for i_index in reversed(index_order):
             index_variable = index_variables[i_index]
-            graph = ig.IterationVariable(
+            graph = ig.IterationNode(
                 index_variable,
                 None,
                 next=graph,
             )
         yield graph
 
 
-def simplify_add(graph: ig.Sum) -> Iterator[ig.IterationGraph]:
+def simplify_add(graph: ig.SumNode) -> Iterator[ig.IterationGraph]:
     """Simplify an Add by combining terms with the same index variable.
 
     An Add node can be simplified if all its terms are `IterationVariable`s with
@@ -116,26 +116,26 @@ def simplify_add(graph: ig.Sum) -> Iterator[ig.IterationGraph]:
     # This could yield all the intermediate graphs, but the last one might
     # always be the most efficient.
 
-    if all(isinstance(term, ig.IterationVariable) for term in graph.terms):
+    if all(isinstance(term, ig.IterationNode) for term in graph.terms):
         unique_terms = {term.index_variable for term in graph.terms}
         if len(unique_terms) == 1:
             head = graph.terms[0]
 
             # Flatten nested Adds
             terms = []
             for term in graph.terms:
-                if isinstance(term.next, ig.Sum):
+                if isinstance(term.next, ig.SumNode):
                     terms.extend(term.next.terms)
                 else:
                     terms.append(term.next)
 
-            for next_graph in simplify_add(ig.Sum(graph.name, terms)):
+            for next_graph in simplify_add(ig.SumNode(graph.name, terms)):
                 yield replace(head, next=next_graph)
         else:
             yield graph
-    elif all(isinstance(term, ig.TerminalExpression) for term in graph.terms):
+    elif all(isinstance(term, ig.TerminalNode) for term in graph.terms):
         expression = reduce(id.Add, [term.expression for term in graph.terms])
-        yield ig.TerminalExpression(expression)
+        yield ig.TerminalNode(expression)
     else:
         if not any(term.has_output() for term in graph.terms):
             # Output iterations are not permitted downstream of an Add
@@ -153,14 +153,14 @@ def to_iteration_graphs_add(
         for right in to_iteration_graphs_expression(self.right, formats, counter):
             # Always simplify Add within Add
             match (left, right):
-                case (ig.Sum(), ig.Sum()):
-                    graph = ig.Sum(name, [*left.terms, *right.terms])
-                case (ig.Sum(), _):
-                    graph = ig.Sum(name, [*left.terms, right])
-                case (_, ig.Sum()):
-                    graph = ig.Sum(name, [left, *right.terms])
+                case (ig.SumNode(), ig.SumNode()):
+                    graph = ig.SumNode(name, [*left.terms, *right.terms])
+                case (ig.SumNode(), _):
+                    graph = ig.SumNode(name, [*left.terms, right])
+                case (_, ig.SumNode()):
+                    graph = ig.SumNode(name, [left, *right.terms])
                 case (_, _):
-                    graph = ig.Sum(name, [left, right])
+                    graph = ig.SumNode(name, [left, right])
 
             yield from simplify_add(graph)
 
@@ -169,15 +169,15 @@ def merge_multiply(
     left: ig.IterationGraph, right: ig.IterationGraph
 ) -> Iterator[ig.IterationGraph]:
     match (left, right):
-        case (ig.TerminalExpression(), ig.TerminalExpression()):
-            yield ig.TerminalExpression(id.Multiply(left.expression, right.expression))
-        case (ig.IterationVariable(), ig.TerminalExpression()):
+        case (ig.TerminalNode(), ig.TerminalNode()):
+            yield ig.TerminalNode(id.Multiply(left.expression, right.expression))
+        case (ig.IterationNode(), ig.TerminalNode()):
             for tail in merge_multiply(left.next, right):
                 yield replace(left, next=tail)
-        case (ig.TerminalExpression(), ig.IterationVariable()):
+        case (ig.TerminalNode(), ig.IterationNode()):
             for tail in merge_multiply(left, right.next):
                 yield replace(right, next=tail)
-        case (ig.IterationVariable(), ig.IterationVariable()):
+        case (ig.IterationNode(), ig.IterationNode()):
             if left.index_variable == right.index_variable:
                 for tail in merge_multiply(left.next, right.next):
                     yield replace(left, next=tail)
@@ -189,12 +189,12 @@ def merge_multiply(
             if right.index_variable not in left.next.later_indexes():
                 for tail in merge_multiply(left, right.next):
                     yield replace(right, next=tail)
-        case (ig.Sum(), _):
+        case (ig.SumNode(), _):
             for terms in product(*[merge_multiply(term, right) for term in left.terms]):
-                yield ig.Sum(left.name, list(terms))
-        case (_, ig.Sum()):
+                yield ig.SumNode(left.name, list(terms))
+        case (_, ig.SumNode()):
             for terms in product(*[merge_multiply(left, term) for term in right.terms]):
-                yield ig.Sum(right.name, list(terms))
+                yield ig.SumNode(right.name, list(terms))
 
 
 @to_iteration_graphs_expression.register(ast.Multiply)
@@ -221,13 +221,13 @@ def merge_assignment(
     target: ig.IterationGraph, expression: ig.IterationGraph, output_layers: dict[str, TensorLayer]
 ) -> Iterator[ig.IterationGraph]:
     match (target, expression):
-        case (ig.TerminalExpression(), _):
+        case (ig.TerminalNode(), _):
             yield expression
-        case (ig.IterationVariable(), ig.TerminalExpression()):
+        case (ig.IterationNode(), ig.TerminalNode()):
             output_leaf = output_layers[target.index_variable]
             for tail in merge_assignment(target.next, expression, output_layers):
                 yield replace(target, output=output_leaf, next=tail)
-        case (ig.IterationVariable(), ig.IterationVariable()):
+        case (ig.IterationNode(), ig.IterationNode()):
             if target.index_variable == expression.index_variable:
                 output_leaf = output_layers[target.index_variable]
                 for tail in merge_assignment(target.next, expression.next, output_layers):
@@ -241,11 +241,11 @@ def merge_assignment(
             if expression.index_variable not in target.next.later_indexes():
                 for tail in merge_assignment(target, expression.next, output_layers):
                     yield replace(expression, next=tail)
-        case (ig.IterationVariable(), ig.Sum(name=name, terms=terms)):
+        case (ig.IterationNode(), ig.SumNode(name=name, terms=terms)):
             for merged_terms in product(
                 *(merge_assignment(target, term, output_layers) for term in terms)
             ):
-                yield from simplify_add(ig.Sum(name, list(merged_terms)))
+                yield from simplify_add(ig.SumNode(name, list(merged_terms)))
 
 
 def to_iteration_graphs(
16 changes: 8 additions & 8 deletions src/tensora/iteration_graph/generate_ir.py
@@ -24,7 +24,7 @@
 from .definition import Definition
 from .identifiable_expression import to_ir
 from .identifiable_expression.tensor_layer import TensorLayer
-from .iteration_graph import IterationGraph, IterationVariable, Sum, TerminalExpression
+from .iteration_graph import IterationGraph, IterationNode, SumNode, TerminalNode
 from .names import crd_name, dimension_name, pos_name, vals_name
 from .outputs import AppendOutput, Output
 from .write_sparse_ir import (
@@ -44,8 +44,8 @@ def to_ir_iteration_graph(
     )
 
 
-@to_ir_iteration_graph.register(TerminalExpression)
-def to_ir_terminal_expression(self: TerminalExpression, output: Output, kernel_type: KernelType):
+@to_ir_iteration_graph.register(TerminalNode)
+def to_ir_terminal_expression(self: TerminalNode, output: Output, kernel_type: KernelType):
     source = SourceBuilder("*** Computation of expression ***")
 
     if kernel_type.is_compute():
@@ -54,7 +54,7 @@ def to_ir_terminal_expression(self: TerminalExpression, output: Output, kernel_t
     return source
 
 
-def generate_subgraphs(graph: IterationVariable) -> list[IterationVariable]:
+def generate_subgraphs(graph: IterationNode) -> list[IterationNode]:
     # The 0th element is just the full graph
     # Each element is derived from a previous element by zeroing a tensor
     # Zeroing a tensor always results in a strictly simpler graph
@@ -82,8 +82,8 @@ def generate_subgraphs(graph: IterationVariable) -> list[IterationVariable]:
     return list(all_subgraphs.values())
 
 
-@to_ir_iteration_graph.register(IterationVariable)
-def to_ir_iteration_variable(self: IterationVariable, output: Output, kernel_type: KernelType):
+@to_ir_iteration_graph.register(IterationNode)
+def to_ir_iteration_variable(self: IterationNode, output: Output, kernel_type: KernelType):
     source = SourceBuilder(f"*** Iteration over {self.index_variable} ***")
 
     if not kernel_type.is_compute() and not self.has_assemble():
@@ -325,8 +325,8 @@ def to_ir_iteration_variable(self: IterationVariable, output: Output, kernel_typ
     return source
 
 
-@to_ir_iteration_graph.register(Sum)
-def to_ir_sum(self: Sum, output: Output, kernel_type: KernelType):
+@to_ir_iteration_graph.register(SumNode)
+def to_ir_sum(self: SumNode, output: Output, kernel_type: KernelType):
     source = SourceBuilder("*** Sum ***")
 
     if kernel_type.is_compute():
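The handlers in generate_ir.py above are registered per node class, so the rename only had to touch the registered types and annotations, not the dispatch mechanism itself. The sketch below is a hypothetical, self-contained illustration of that pattern and not Tensora's real IR generation: the actual to_ir_iteration_graph registry may be a custom dispatcher rather than functools.singledispatch, and the toy node classes, simplified signatures, and string output are assumptions for illustration only.

# Hypothetical sketch of per-node-class dispatch after the rename.
from dataclasses import dataclass
from functools import singledispatch


@dataclass(frozen=True)
class TerminalNode:
    expression: str


@dataclass(frozen=True)
class IterationNode:
    index_variable: str
    next: object


@dataclass(frozen=True)
class SumNode:
    name: str
    terms: list


@singledispatch
def to_ir_iteration_graph(node):
    raise NotImplementedError(f"No handler registered for {type(node).__name__}")


@to_ir_iteration_graph.register
def _(node: TerminalNode):
    return node.expression


@to_ir_iteration_graph.register
def _(node: IterationNode):
    return f"for {node.index_variable}: ({to_ir_iteration_graph(node.next)})"


@to_ir_iteration_graph.register
def _(node: SumNode):
    return " + ".join(to_ir_iteration_graph(term) for term in node.terms)


# Usage: a toy graph resembling "for i: a_i + b_i"
graph = IterationNode("i", SumNode("t0", [TerminalNode("a_i"), TerminalNode("b_i")]))
print(to_ir_iteration_graph(graph))  # for i: (a_i + b_i)
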
12 changes: 6 additions & 6 deletions src/tensora/iteration_graph/iteration_graph.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-__all__ = ["IterationGraph", "Sum", "IterationVariable", "TerminalExpression"]
+__all__ = ["IterationGraph", "TerminalNode", "IterationNode", "SumNode"]
 
 from abc import abstractmethod
 from dataclasses import dataclass, replace
@@ -46,14 +46,14 @@ def has_output(self) -> bool:
 
 
 @dataclass(frozen=True)
-class TerminalExpression(IterationGraph):
+class TerminalNode(IterationGraph):
     expression: Expression
 
     def extract_context(self, index: str) -> Context:
         return extract_context(self.expression, index)
 
     def exhaust_tensor(self, tensor: TensorLeaf) -> IterationGraph:
-        return TerminalExpression(exhaust_tensor(self.expression, tensor))
+        return TerminalNode(exhaust_tensor(self.expression, tensor))
 
     def is_sparse_output(self) -> bool:
         return False
@@ -70,7 +70,7 @@ def has_output(self) -> bool:
 
 
 @dataclass(frozen=True)
-class IterationVariable(IterationGraph):
+class IterationNode(IterationGraph):
     index_variable: str
     output: TensorLayer | None
     next: IterationGraph
@@ -122,7 +122,7 @@ def has_assemble(self) -> bool:
 
 
 @dataclass(frozen=True)
-class Sum(IterationGraph):
+class SumNode(IterationGraph):
     name: str
     terms: list[IterationGraph]
 
@@ -140,7 +140,7 @@ def exhaust_tensor(self, tensor: TensorLeaf) -> IterationGraph:
                 new_terms.append(new_term)
 
         if len(new_terms) == 0:
-            return TerminalExpression(Integer(0))
+            return TerminalNode(Integer(0))
         elif len(new_terms) == 1:
             return new_terms[0]
         else:
