From 66b87e8022343267b76147bfce1b0b3291bd483d Mon Sep 17 00:00:00 2001 From: David Hagen Date: Thu, 13 Jun 2024 16:09:03 -0400 Subject: [PATCH] Add Module to IR (#67) --- src/tensora/codegen/__init__.py | 2 +- src/tensora/codegen/_ir_to_c.py | 15 ++-- src/tensora/generate/_tensora.py | 10 +-- src/tensora/ir/__init__.py | 2 +- src/tensora/ir/_peephole.py | 73 +++++++++------- src/tensora/ir/ast.py | 13 ++- src/tensora/iteration_graph/_generate_ir.py | 95 ++++++++++++--------- tests/codegen/test_ast_to_c.py | 54 ++++++++---- tests/ir/test_peephole.py | 66 ++++++++------ 9 files changed, 202 insertions(+), 128 deletions(-) diff --git a/src/tensora/codegen/__init__.py b/src/tensora/codegen/__init__.py index 4560949..8b546ef 100644 --- a/src/tensora/codegen/__init__.py +++ b/src/tensora/codegen/__init__.py @@ -1 +1 @@ -from ._ir_to_c import ir_to_c +from ._ir_to_c import ir_to_c, ir_to_c_function_definition, ir_to_c_statement diff --git a/src/tensora/codegen/_ir_to_c.py b/src/tensora/codegen/_ir_to_c.py index c96de59..c0e6356 100644 --- a/src/tensora/codegen/_ir_to_c.py +++ b/src/tensora/codegen/_ir_to_c.py @@ -1,4 +1,4 @@ -__all__ = ["ir_to_c"] +__all__ = ["ir_to_c_statement", "ir_to_c_function_definition", "ir_to_c"] from functools import singledispatch @@ -36,6 +36,7 @@ Max, Min, ModeLiteral, + Module, Multiply, NotEqual, Or, @@ -323,17 +324,19 @@ def ir_to_c_return(self: Return) -> list[str]: return [f"return {ir_to_c_expression(self.value)};"] -@ir_to_c_statement.register(FunctionDefinition) -def ir_to_c_function_definition(self: FunctionDefinition) -> list[str]: +def ir_to_c_function_definition(self: FunctionDefinition) -> str: return_type_string = type_to_c(self.return_type) name_string = ir_to_c_expression(self.name) parameters_string = ", ".join(map(ir_to_c_declaration, self.parameters)) - return [ + + lines = [ f"{return_type_string} {name_string}({parameters_string}) {{", *indent_lines(ir_to_c_statement(self.body)), "}", ] + return "\n".join(lines) + -def ir_to_c(self: Statement) -> str: - return "\n".join(ir_to_c_statement(self)) +def ir_to_c(self: Module) -> str: + return "\n\n".join((ir_to_c_function_definition(function) for function in self.definitions)) diff --git a/src/tensora/generate/_tensora.py b/src/tensora/generate/_tensora.py index 6054384..c789234 100644 --- a/src/tensora/generate/_tensora.py +++ b/src/tensora/generate/_tensora.py @@ -11,7 +11,8 @@ index_dimensions, to_identifiable, ) -from ..ir import SourceBuilder, peephole +from ..ir import peephole +from ..ir.ast import Module from ..iteration_graph import Definition, generate_ir from ..kernel_type import KernelType from ..problem import Problem @@ -36,8 +37,7 @@ def generate_c_code_tensora( case _: raise NotImplementedError() - ir = SourceBuilder() - for kernel_type in kernel_types: - ir.append(generate_ir(definition, graph, kernel_type).finalize()) + functions = [generate_ir(definition, graph, kernel_type) for kernel_type in kernel_types] + module = Module(functions) - return Success(ir_to_c(peephole(ir.finalize()))) + return Success(ir_to_c(peephole(module))) diff --git a/src/tensora/ir/__init__.py b/src/tensora/ir/__init__.py index d1af023..e4488b6 100644 --- a/src/tensora/ir/__init__.py +++ b/src/tensora/ir/__init__.py @@ -1,2 +1,2 @@ from ._builder import SourceBuilder -from ._peephole import peephole +from ._peephole import peephole, peephole_function_definition, peephole_statement diff --git a/src/tensora/ir/_peephole.py b/src/tensora/ir/_peephole.py index 0f92872..9e4083d 100644 --- a/src/tensora/ir/_peephole.py +++ b/src/tensora/ir/_peephole.py @@ -19,7 +19,7 @@ * redundant_assignment: a = a => {} """ -__all__ = ["peephole"] +__all__ = ["peephole_function_definition", "peephole_statement", "peephole"] from dataclasses import replace from functools import singledispatch @@ -58,6 +58,7 @@ Max, Min, ModeLiteral, + Module, Multiply, NotEqual, Or, @@ -68,18 +69,6 @@ ) -@singledispatch -def peephole(self: Statement) -> Statement: - raise NotImplementedError(f"peephole not implemented for {type(self)}: {self}") - - -@peephole.register(Expression) -@singledispatch -def peephole_expression(self: Expression) -> Expression: - raise NotImplementedError(f"peephole_expression not implemented for {type(self)}: {self}") - - -@peephole_expression.register(Assignable) @singledispatch def peephole_assignable(self: Assignable) -> Assignable: raise NotImplementedError(f"peephole_assignable not implemented for {type(self)}: {self}") @@ -100,6 +89,17 @@ def peephole_array_index(self: ArrayIndex) -> Assignable: return ArrayIndex(peephole_assignable(self.target), peephole_expression(self.index)) +@singledispatch +def peephole_expression(self: Expression) -> Expression: + raise NotImplementedError(f"peephole_expression not implemented for {type(self)}: {self}") + + +@peephole_expression.register(Assignable) +def peephole_expression_assignable(self: Assignable) -> Assignable: + # Assignables are expressions in their own right + return peephole_assignable(self) + + @peephole_expression.register(IntegerLiteral) @peephole_expression.register(FloatLiteral) @peephole_expression.register(BooleanLiteral) @@ -255,17 +255,28 @@ def peephole_array_reallocate(self: ArrayReallocate) -> Expression: return replace(self, old=old, n_elements=n_elements) -@peephole.register(Declaration) +@singledispatch +def peephole_statement(self: Statement) -> Statement: + raise NotImplementedError(f"peephole not implemented for {type(self)}: {self}") + + +@peephole_statement.register(Expression) +def peephole_expression_statement(self: Expression) -> Expression: + # Expressions are statements in their own right + return peephole_expression(self) + + +@peephole_statement.register(Declaration) def peephole_declaration(self: Declaration) -> Declaration: return self -@peephole.register(Free) +@peephole_statement.register(Free) def peephole_free(self: Free) -> Statement: return Free(peephole_assignable(self.target)) -@peephole.register(Assignment) +@peephole_statement.register(Assignment) def peephole_assignment(self: Assignment) -> Statement: target = peephole_assignable(self.target) value = peephole_expression(self.value) @@ -276,17 +287,17 @@ def peephole_assignment(self: Assignment) -> Statement: return Assignment(target, value) -@peephole.register(DeclarationAssignment) +@peephole_statement.register(DeclarationAssignment) def peephole_declaration_assignment(self: DeclarationAssignment) -> Statement: value = peephole_expression(self.value) return replace(self, value=value) -@peephole.register(Block) +@peephole_statement.register(Block) def peephole_block(self: Block) -> Statement: statements = [] for old_statement in self.statements: - statement = peephole(old_statement) + statement = peephole_statement(old_statement) if isinstance(statement, Block) and statement.is_empty(): pass else: @@ -295,11 +306,11 @@ def peephole_block(self: Block) -> Statement: return replace(self, statements=statements) -@peephole.register(Branch) +@peephole_statement.register(Branch) def peephole_branch(self: Branch) -> Statement: condition = peephole_expression(self.condition) - if_true = peephole(self.if_true) - if_false = peephole(self.if_false) + if_true = peephole_statement(self.if_true) + if_false = peephole_statement(self.if_false) if condition == BooleanLiteral(True): return if_true @@ -316,10 +327,10 @@ def peephole_branch(self: Branch) -> Statement: return Branch(condition, if_true, if_false) -@peephole.register(Loop) +@peephole_statement.register(Loop) def peephole_loop(self: Loop) -> Statement: condition = peephole_expression(self.condition) - body = peephole(self.body) + body = peephole_statement(self.body) if condition == BooleanLiteral(False): return Block([]) @@ -329,18 +340,22 @@ def peephole_loop(self: Loop) -> Statement: return Loop(condition, body) -@peephole.register(Break) +@peephole_statement.register(Break) def peephole_break(self: Break) -> Statement: return self -@peephole.register(Return) +@peephole_statement.register(Return) def peephole_return(self: Return) -> Statement: value = peephole_expression(self.value) return Return(value) -@peephole.register(FunctionDefinition) -def peephole_function_definition(self: FunctionDefinition) -> Statement: - body = peephole(self.body) +def peephole_function_definition(self: FunctionDefinition) -> FunctionDefinition: + body = peephole_statement(self.body) return replace(self, body=body) + + +def peephole(self: Module) -> Module: + functions = [peephole_function_definition(function) for function in self.definitions] + return Module(functions) diff --git a/src/tensora/ir/ast.py b/src/tensora/ir/ast.py index 7d06026..f60574d 100644 --- a/src/tensora/ir/ast.py +++ b/src/tensora/ir/ast.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Sequence - __all__ = [ "Statement", "Expression", @@ -43,10 +41,12 @@ "Break", "Return", "FunctionDefinition", + "Module", ] from dataclasses import dataclass from functools import reduce +from typing import Sequence from ..format import Mode from .types import Type @@ -262,7 +262,7 @@ def join(operands: Sequence[Expression | int | str]) -> Expression: @dataclass(frozen=True, slots=True) class Address(Expression): - target: Variable + target: Assignable @dataclass(frozen=True, slots=True) @@ -357,8 +357,13 @@ class Return(Statement): @dataclass(frozen=True, slots=True) -class FunctionDefinition(Statement): +class FunctionDefinition: name: Variable parameters: list[Declaration] return_type: Type body: Statement + + +@dataclass(frozen=True, slots=True) +class Module: + definitions: list[FunctionDefinition] diff --git a/src/tensora/iteration_graph/_generate_ir.py b/src/tensora/iteration_graph/_generate_ir.py index 3d91d01..8985d12 100644 --- a/src/tensora/iteration_graph/_generate_ir.py +++ b/src/tensora/iteration_graph/_generate_ir.py @@ -11,8 +11,10 @@ Block, BooleanToInteger, Branch, + Declaration, Equal, Expression, + FunctionDefinition, GreaterThan, IntegerLiteral, LessThan, @@ -342,49 +344,58 @@ def to_ir_sum(self: SumNode, output: Output, kernel_type: KernelType): return source -def generate_ir(definition: Definition, graph: IterationGraph, kernel_type: KernelType): +def generate_ir( + definition: Definition, graph: IterationGraph, kernel_type: KernelType +) -> FunctionDefinition: + # Function body source = SourceBuilder() + # Dimensions of all index variables + with source.block("Extract dimensions"): + for index_name, tensor_layer in definition.indexes.items(): + declaration = dimension_name(index_name).declare(types.integer) + value = Variable(tensor_layer.name).attr("dimensions").idx(tensor_layer.dimension) + source.append(declaration.assign(value)) + + # Unpack tensors + with source.block("Unpack tensors"): + for tensor_name, format in definition.formats.items(): + for i, mode in enumerate(format.modes): + match mode: + case Mode.dense: + pass + case Mode.compressed: + pos_declaration = pos_name(tensor_name, i).declare( + types.Pointer(types.integer) + ) + pos_value = Variable(tensor_name).attr("indices").idx(i).idx(0) + source.append(pos_declaration.assign(pos_value)) + crd_declaration = crd_name(tensor_name, i).declare( + types.Pointer(types.integer) + ) + crd_value = Variable(tensor_name).attr("indices").idx(i).idx(1) + source.append(crd_declaration.assign(crd_value)) + + vals_declaration = vals_name(tensor_name).declare(types.Pointer(types.float)) + vals_value = Variable(tensor_name).attr("vals") + source.append(vals_declaration.assign(vals_value)) + + output = AppendOutput(definition.output_variable, 0) + source.append(output.write_declarations(kernel_type)) + + source.append(to_ir_iteration_graph(graph, output, kernel_type)) + + source.append(output.write_cleanup(kernel_type)) + + source.append(Return(IntegerLiteral(0))) + # Function declaration - parameters = {name: types.Pointer(types.tensor) for name in definition.formats.keys()} - with source.function_definition(kernel_type.name, parameters, types.integer): - # Dimensions of all index variables - with source.block("Extract dimensions"): - for index_name, tensor_layer in definition.indexes.items(): - declaration = dimension_name(index_name).declare(types.integer) - value = Variable(tensor_layer.name).attr("dimensions").idx(tensor_layer.dimension) - source.append(declaration.assign(value)) - - # Unpack tensors - with source.block("Unpack tensors"): - for tensor_name, format in definition.formats.items(): - for i, mode in enumerate(format.modes): - match mode: - case Mode.dense: - pass - case Mode.compressed: - pos_declaration = pos_name(tensor_name, i).declare( - types.Pointer(types.integer) - ) - pos_value = Variable(tensor_name).attr("indices").idx(i).idx(0) - source.append(pos_declaration.assign(pos_value)) - crd_declaration = crd_name(tensor_name, i).declare( - types.Pointer(types.integer) - ) - crd_value = Variable(tensor_name).attr("indices").idx(i).idx(1) - source.append(crd_declaration.assign(crd_value)) - - vals_declaration = vals_name(tensor_name).declare(types.Pointer(types.float)) - vals_value = Variable(tensor_name).attr("vals") - source.append(vals_declaration.assign(vals_value)) - - output = AppendOutput(definition.output_variable, 0) - source.append(output.write_declarations(kernel_type)) - - source.append(to_ir_iteration_graph(graph, output, kernel_type)) - - source.append(output.write_cleanup(kernel_type)) - - source.append(Return(IntegerLiteral(0))) + parameters = [ + Declaration(Variable(name), types.Pointer(types.tensor)) + for name in definition.formats.keys() + ] + function_definition = FunctionDefinition( + Variable(kernel_type.name), parameters, types.integer, source.finalize() + ) - return source + return function_definition diff --git a/tests/codegen/test_ast_to_c.py b/tests/codegen/test_ast_to_c.py index 185a78b..5bd493f 100644 --- a/tests/codegen/test_ast_to_c.py +++ b/tests/codegen/test_ast_to_c.py @@ -2,7 +2,7 @@ import pytest -from tensora.codegen import ir_to_c +from tensora.codegen import ir_to_c, ir_to_c_function_definition, ir_to_c_statement from tensora.ir.ast import * from tensora.ir.types import * @@ -114,7 +114,7 @@ def clean(string: str) -> str: @pytest.mark.parametrize(("ast", "code"), single_lines) def test_single_lines(ast: Expression, code: str): - assert ir_to_c(ast) == code + ";" + assert ir_to_c_statement(ast) == [code + ";"] multiple_lines = [ @@ -306,22 +306,44 @@ def test_single_lines(ast: Expression, code: str): } """, ), - ( - FunctionDefinition( - Variable("double_x"), - [Declaration(Variable("x"), integer)], - integer, - Return(Multiply(Variable("x"), IntegerLiteral(2))), - ), - """ +] + + +@pytest.mark.parametrize(("ast", "code"), multiple_lines) +def test_multiple_lines(ast: Statement, code: str): + assert "\n".join(ir_to_c_statement(ast)) == clean(code) + + +def test_function_definition(): + function = FunctionDefinition( + Variable("double_x"), + [Declaration(Variable("x"), integer)], + integer, + Return(Multiply(Variable("x"), IntegerLiteral(2))), + ) + expected = """ int32_t double_x(int32_t x) { return x * 2; } - """, - ), -] + """ + assert ir_to_c_function_definition(function) == clean(expected) -@pytest.mark.parametrize(("ast", "code"), multiple_lines) -def test_multiple_lines(ast: Expression, code: str): - assert ir_to_c(ast) == clean(code) +def test_module(): + function = FunctionDefinition( + Variable("double_x"), + [Declaration(Variable("x"), integer)], + integer, + Return(Multiply(Variable("x"), IntegerLiteral(2))), + ) + module = Module([function, function]) + expected = """ + int32_t double_x(int32_t x) { + return x * 2; + } + + int32_t double_x(int32_t x) { + return x * 2; + } + """ + assert ir_to_c(module) == clean(expected) diff --git a/tests/ir/test_peephole.py b/tests/ir/test_peephole.py index 631148b..1545d1b 100644 --- a/tests/ir/test_peephole.py +++ b/tests/ir/test_peephole.py @@ -1,6 +1,6 @@ import pytest -from tensora.ir import peephole +from tensora.ir import peephole, peephole_function_definition, peephole_statement from tensora.ir.ast import * from tensora.ir.types import * @@ -163,26 +163,12 @@ Return(Add(IntegerLiteral(0), Variable("x"))), Return(Variable("x")), ), - ( - FunctionDefinition( - Variable("f"), - [Declaration(Variable("x"), tensor)], - integer, - Return(Multiply(IntegerLiteral(0), IntegerLiteral(1))), - ), - FunctionDefinition( - Variable("f"), - [Declaration(Variable("x"), tensor)], - integer, - Return(IntegerLiteral(0)), - ), - ), ] @pytest.mark.parametrize(("before", "after"), changed) -def test_peephole(before: Statement, after: Statement): - assert peephole(before) == after +def test_peephole_statement(before: Statement, after: Statement): + assert peephole_statement(before) == after left_right_classes = [ @@ -207,8 +193,8 @@ def test_pass_through_left_right(cls): left = Add(IntegerLiteral(0), Variable("x")) right = Add(IntegerLiteral(0), Variable("y")) expected = cls(Variable("x"), Variable("y")) - assert peephole(cls(left, Variable("y"))) == expected - assert peephole(cls(Variable("x"), right)) == expected + assert peephole_statement(cls(left, Variable("y"))) == expected + assert peephole_statement(cls(Variable("x"), right)) == expected unchanged = [ @@ -223,14 +209,46 @@ def test_pass_through_left_right(cls): FunctionCall(Variable("f"), [Variable("x")]), Loop(BooleanLiteral(True), Variable("x")), Assignment(Variable("x"), ArrayIndex(ArrayIndex(Variable("y"), Variable("i")), Variable("j"))), - FunctionDefinition( - Variable("f"), [Declaration(Variable("x"), tensor)], integer, Return(IntegerLiteral(0)) - ), Break(), Declaration(Variable("x"), float), ] @pytest.mark.parametrize("input", unchanged) -def test_peephole_noop(input: Statement): - assert peephole(input) == input +def test_peephole_statement_noop(input: Statement): + assert peephole_statement(input) == input + + +def test_peephole_function_definition(): + function = FunctionDefinition( + Variable("f"), + [Declaration(Variable("x"), tensor)], + integer, + Return(Multiply(IntegerLiteral(0), IntegerLiteral(1))), + ) + expected = FunctionDefinition( + Variable("f"), + [Declaration(Variable("x"), tensor)], + integer, + Return(IntegerLiteral(0)), + ) + assert peephole_function_definition(function) == expected + + +def test_peephole_module(): + input_function = FunctionDefinition( + Variable("f"), + [Declaration(Variable("x"), tensor)], + integer, + Return(Multiply(IntegerLiteral(0), IntegerLiteral(1))), + ) + expected_function = FunctionDefinition( + Variable("f"), + [Declaration(Variable("x"), tensor)], + integer, + Return(IntegerLiteral(0)), + ) + + module = Module([input_function, input_function]) + expected = Module([expected_function, expected_function]) + assert peephole(module) == expected