Add Module to IR (#67)
drhagen authored Jun 13, 2024
1 parent c3ec481 commit 66b87e8
Showing 9 changed files with 202 additions and 128 deletions.
2 changes: 1 addition & 1 deletion src/tensora/codegen/__init__.py
@@ -1 +1 @@
from ._ir_to_c import ir_to_c
from ._ir_to_c import ir_to_c, ir_to_c_function_definition, ir_to_c_statement
15 changes: 9 additions & 6 deletions src/tensora/codegen/_ir_to_c.py
@@ -1,4 +1,4 @@
__all__ = ["ir_to_c"]
__all__ = ["ir_to_c_statement", "ir_to_c_function_definition", "ir_to_c"]

from functools import singledispatch

@@ -36,6 +36,7 @@
Max,
Min,
ModeLiteral,
Module,
Multiply,
NotEqual,
Or,
@@ -323,17 +324,19 @@ def ir_to_c_return(self: Return) -> list[str]:
return [f"return {ir_to_c_expression(self.value)};"]


@ir_to_c_statement.register(FunctionDefinition)
def ir_to_c_function_definition(self: FunctionDefinition) -> list[str]:
def ir_to_c_function_definition(self: FunctionDefinition) -> str:
return_type_string = type_to_c(self.return_type)
name_string = ir_to_c_expression(self.name)
parameters_string = ", ".join(map(ir_to_c_declaration, self.parameters))
return [

lines = [
f"{return_type_string} {name_string}({parameters_string}) {{",
*indent_lines(ir_to_c_statement(self.body)),
"}",
]

return "\n".join(lines)


def ir_to_c(self: Statement) -> str:
return "\n".join(ir_to_c_statement(self))
def ir_to_c(self: Module) -> str:
return "\n\n".join((ir_to_c_function_definition(function) for function in self.definitions))
10 changes: 5 additions & 5 deletions src/tensora/generate/_tensora.py
@@ -11,7 +11,8 @@
index_dimensions,
to_identifiable,
)
from ..ir import SourceBuilder, peephole
from ..ir import peephole
from ..ir.ast import Module
from ..iteration_graph import Definition, generate_ir
from ..kernel_type import KernelType
from ..problem import Problem
@@ -36,8 +37,7 @@ def generate_c_code_tensora(
case _:
raise NotImplementedError()

ir = SourceBuilder()
for kernel_type in kernel_types:
ir.append(generate_ir(definition, graph, kernel_type).finalize())
functions = [generate_ir(definition, graph, kernel_type) for kernel_type in kernel_types]
module = Module(functions)

return Success(ir_to_c(peephole(ir.finalize())))
return Success(ir_to_c(peephole(module)))
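
The shape of the new pipeline in this file, as a hedged sketch: one FunctionDefinition per requested kernel, all collected into a single Module that is peephole-optimized and printed in one pass. build_source is a hypothetical helper, not part of the codebase, and definition, graph, and kernel_types stand for the values generate_c_code_tensora already computes.

from tensora.codegen import ir_to_c
from tensora.ir import peephole
from tensora.ir.ast import Module
from tensora.iteration_graph import Definition, generate_ir
from tensora.kernel_type import KernelType


def build_source(definition: Definition, graph, kernel_types: list[KernelType]) -> str:
    # One FunctionDefinition per kernel type, emitted into a single Module.
    functions = [generate_ir(definition, graph, kernel_type) for kernel_type in kernel_types]
    return ir_to_c(peephole(Module(functions)))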
2 changes: 1 addition & 1 deletion src/tensora/ir/__init__.py
@@ -1,2 +1,2 @@
from ._builder import SourceBuilder
from ._peephole import peephole
from ._peephole import peephole, peephole_function_definition, peephole_statement
73 changes: 44 additions & 29 deletions src/tensora/ir/_peephole.py
@@ -19,7 +19,7 @@
* redundant_assignment: a = a => {}
"""

__all__ = ["peephole"]
__all__ = ["peephole_function_definition", "peephole_statement", "peephole"]

from dataclasses import replace
from functools import singledispatch
@@ -58,6 +58,7 @@
Max,
Min,
ModeLiteral,
Module,
Multiply,
NotEqual,
Or,
@@ -68,18 +69,6 @@
)


@singledispatch
def peephole(self: Statement) -> Statement:
raise NotImplementedError(f"peephole not implemented for {type(self)}: {self}")


@peephole.register(Expression)
@singledispatch
def peephole_expression(self: Expression) -> Expression:
raise NotImplementedError(f"peephole_expression not implemented for {type(self)}: {self}")


@peephole_expression.register(Assignable)
@singledispatch
def peephole_assignable(self: Assignable) -> Assignable:
raise NotImplementedError(f"peephole_assignable not implemented for {type(self)}: {self}")
@@ -100,6 +89,17 @@ def peephole_array_index(self: ArrayIndex) -> Assignable:
return ArrayIndex(peephole_assignable(self.target), peephole_expression(self.index))


@singledispatch
def peephole_expression(self: Expression) -> Expression:
raise NotImplementedError(f"peephole_expression not implemented for {type(self)}: {self}")


@peephole_expression.register(Assignable)
def peephole_expression_assignable(self: Assignable) -> Assignable:
# Assignables are expressions in their own right
return peephole_assignable(self)


@peephole_expression.register(IntegerLiteral)
@peephole_expression.register(FloatLiteral)
@peephole_expression.register(BooleanLiteral)
@@ -255,17 +255,28 @@ def peephole_array_reallocate(self: ArrayReallocate) -> Expression:
return replace(self, old=old, n_elements=n_elements)


@peephole.register(Declaration)
@singledispatch
def peephole_statement(self: Statement) -> Statement:
raise NotImplementedError(f"peephole not implemented for {type(self)}: {self}")


@peephole_statement.register(Expression)
def peephole_expression_statement(self: Expression) -> Expression:
# Expressions are statements in their own right
return peephole_expression(self)


@peephole_statement.register(Declaration)
def peephole_declaration(self: Declaration) -> Declaration:
return self


@peephole.register(Free)
@peephole_statement.register(Free)
def peephole_free(self: Free) -> Statement:
return Free(peephole_assignable(self.target))


@peephole.register(Assignment)
@peephole_statement.register(Assignment)
def peephole_assignment(self: Assignment) -> Statement:
target = peephole_assignable(self.target)
value = peephole_expression(self.value)
@@ -276,17 +287,17 @@ def peephole_assignment(self: Assignment) -> Statement:
return Assignment(target, value)


@peephole.register(DeclarationAssignment)
@peephole_statement.register(DeclarationAssignment)
def peephole_declaration_assignment(self: DeclarationAssignment) -> Statement:
value = peephole_expression(self.value)
return replace(self, value=value)


@peephole.register(Block)
@peephole_statement.register(Block)
def peephole_block(self: Block) -> Statement:
statements = []
for old_statement in self.statements:
statement = peephole(old_statement)
statement = peephole_statement(old_statement)
if isinstance(statement, Block) and statement.is_empty():
pass
else:
@@ -295,11 +306,11 @@ def peephole_block(self: Block) -> Statement:
return replace(self, statements=statements)


@peephole.register(Branch)
@peephole_statement.register(Branch)
def peephole_branch(self: Branch) -> Statement:
condition = peephole_expression(self.condition)
if_true = peephole(self.if_true)
if_false = peephole(self.if_false)
if_true = peephole_statement(self.if_true)
if_false = peephole_statement(self.if_false)

if condition == BooleanLiteral(True):
return if_true
@@ -316,10 +327,10 @@ def peephole_branch(self: Branch) -> Statement:
return Branch(condition, if_true, if_false)


@peephole.register(Loop)
@peephole_statement.register(Loop)
def peephole_loop(self: Loop) -> Statement:
condition = peephole_expression(self.condition)
body = peephole(self.body)
body = peephole_statement(self.body)

if condition == BooleanLiteral(False):
return Block([])
@@ -329,18 +340,22 @@ def peephole_loop(self: Loop) -> Statement:
return Loop(condition, body)


@peephole.register(Break)
@peephole_statement.register(Break)
def peephole_break(self: Break) -> Statement:
return self


@peephole.register(Return)
@peephole_statement.register(Return)
def peephole_return(self: Return) -> Statement:
value = peephole_expression(self.value)
return Return(value)


@peephole.register(FunctionDefinition)
def peephole_function_definition(self: FunctionDefinition) -> Statement:
body = peephole(self.body)
def peephole_function_definition(self: FunctionDefinition) -> FunctionDefinition:
body = peephole_statement(self.body)
return replace(self, body=body)


def peephole(self: Module) -> Module:
functions = [peephole_function_definition(function) for function in self.definitions]
return Module(functions)
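
The peephole pass is now layered the same way as codegen: peephole_assignable, peephole_expression, and peephole_statement are singledispatch functions that delegate to one another (an Assignable is an Expression, an Expression is a Statement), and peephole itself is a plain function over Module. A small sketch of the statement-level rewriting, using only classes that appear in this commit; the printed repr is approximate.

from tensora.ir import peephole_statement
from tensora.ir.ast import Assignment, Block, BooleanLiteral, Branch, IntegerLiteral, Variable

# A branch whose condition is literally true collapses to its if_true arm.
branch = Branch(
    BooleanLiteral(True),
    Assignment(Variable("x"), IntegerLiteral(1)),
    Block([]),
)
print(peephole_statement(branch))
# Roughly: Assignment(target=Variable('x'), value=IntegerLiteral(1))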
13 changes: 9 additions & 4 deletions src/tensora/ir/ast.py
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Sequence

__all__ = [
"Statement",
"Expression",
@@ -43,10 +41,12 @@
"Break",
"Return",
"FunctionDefinition",
"Module",
]

from dataclasses import dataclass
from functools import reduce
from typing import Sequence

from ..format import Mode
from .types import Type
@@ -262,7 +262,7 @@ def join(operands: Sequence[Expression | int | str]) -> Expression:

@dataclass(frozen=True, slots=True)
class Address(Expression):
target: Variable
target: Assignable


@dataclass(frozen=True, slots=True)
@@ -357,8 +357,13 @@ class Return(Statement):


@dataclass(frozen=True, slots=True)
class FunctionDefinition(Statement):
class FunctionDefinition:
name: Variable
parameters: list[Declaration]
return_type: Type
body: Statement


@dataclass(frozen=True, slots=True)
class Module:
definitions: list[FunctionDefinition]
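
Structurally, FunctionDefinition is no longer a Statement subclass, so it can only appear inside the new top-level Module node that peephole and ir_to_c now accept. A tiny sketch with illustrative field values:

from tensora.ir import types
from tensora.ir.ast import Declaration, FunctionDefinition, IntegerLiteral, Module, Return, Statement, Variable

function = FunctionDefinition(
    Variable("evaluate"),
    [Declaration(Variable("output"), types.Pointer(types.tensor))],
    types.integer,
    Return(IntegerLiteral(0)),
)
module = Module([function])

# FunctionDefinition no longer participates in the Statement hierarchy.
assert not isinstance(function, Statement)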
95 changes: 53 additions & 42 deletions src/tensora/iteration_graph/_generate_ir.py
@@ -11,8 +11,10 @@
Block,
BooleanToInteger,
Branch,
Declaration,
Equal,
Expression,
FunctionDefinition,
GreaterThan,
IntegerLiteral,
LessThan,
@@ -342,49 +344,58 @@ def to_ir_sum(self: SumNode, output: Output, kernel_type: KernelType):
return source


def generate_ir(definition: Definition, graph: IterationGraph, kernel_type: KernelType):
def generate_ir(
definition: Definition, graph: IterationGraph, kernel_type: KernelType
) -> FunctionDefinition:
# Function body
source = SourceBuilder()

# Dimensions of all index variables
with source.block("Extract dimensions"):
for index_name, tensor_layer in definition.indexes.items():
declaration = dimension_name(index_name).declare(types.integer)
value = Variable(tensor_layer.name).attr("dimensions").idx(tensor_layer.dimension)
source.append(declaration.assign(value))

# Unpack tensors
with source.block("Unpack tensors"):
for tensor_name, format in definition.formats.items():
for i, mode in enumerate(format.modes):
match mode:
case Mode.dense:
pass
case Mode.compressed:
pos_declaration = pos_name(tensor_name, i).declare(
types.Pointer(types.integer)
)
pos_value = Variable(tensor_name).attr("indices").idx(i).idx(0)
source.append(pos_declaration.assign(pos_value))
crd_declaration = crd_name(tensor_name, i).declare(
types.Pointer(types.integer)
)
crd_value = Variable(tensor_name).attr("indices").idx(i).idx(1)
source.append(crd_declaration.assign(crd_value))

vals_declaration = vals_name(tensor_name).declare(types.Pointer(types.float))
vals_value = Variable(tensor_name).attr("vals")
source.append(vals_declaration.assign(vals_value))

output = AppendOutput(definition.output_variable, 0)
source.append(output.write_declarations(kernel_type))

source.append(to_ir_iteration_graph(graph, output, kernel_type))

source.append(output.write_cleanup(kernel_type))

source.append(Return(IntegerLiteral(0)))

# Function declaration
parameters = {name: types.Pointer(types.tensor) for name in definition.formats.keys()}
with source.function_definition(kernel_type.name, parameters, types.integer):
# Dimensions of all index variables
with source.block("Extract dimensions"):
for index_name, tensor_layer in definition.indexes.items():
declaration = dimension_name(index_name).declare(types.integer)
value = Variable(tensor_layer.name).attr("dimensions").idx(tensor_layer.dimension)
source.append(declaration.assign(value))

# Unpack tensors
with source.block("Unpack tensors"):
for tensor_name, format in definition.formats.items():
for i, mode in enumerate(format.modes):
match mode:
case Mode.dense:
pass
case Mode.compressed:
pos_declaration = pos_name(tensor_name, i).declare(
types.Pointer(types.integer)
)
pos_value = Variable(tensor_name).attr("indices").idx(i).idx(0)
source.append(pos_declaration.assign(pos_value))
crd_declaration = crd_name(tensor_name, i).declare(
types.Pointer(types.integer)
)
crd_value = Variable(tensor_name).attr("indices").idx(i).idx(1)
source.append(crd_declaration.assign(crd_value))

vals_declaration = vals_name(tensor_name).declare(types.Pointer(types.float))
vals_value = Variable(tensor_name).attr("vals")
source.append(vals_declaration.assign(vals_value))

output = AppendOutput(definition.output_variable, 0)
source.append(output.write_declarations(kernel_type))

source.append(to_ir_iteration_graph(graph, output, kernel_type))

source.append(output.write_cleanup(kernel_type))

source.append(Return(IntegerLiteral(0)))
parameters = [
Declaration(Variable(name), types.Pointer(types.tensor))
for name in definition.formats.keys()
]
function_definition = FunctionDefinition(
Variable(kernel_type.name), parameters, types.integer, source.finalize()
)

return source
return function_definition