Skip to content

Commit

Permalink
Clean up native code generation
Browse files Browse the repository at this point in the history
  • Loading branch information
drhagen committed Nov 11, 2023
1 parent da8d223 commit 2655b99
Show file tree
Hide file tree
Showing 21 changed files with 82 additions and 559 deletions.
2 changes: 2 additions & 0 deletions src/tensora/codegen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .ast_to_c import ast_to_c
from .type_to_c import type_to_c
10 changes: 2 additions & 8 deletions src/tensora/codegen/type_to_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,12 @@ def type_to_c_pointer(type: Pointer, variable: Optional[str] = None):

@type_to_c.register(Array)
def type_to_c_array(type: Array, variable: Optional[str] = None):
if variable is None:
return f"{type_to_c(type.element)}[]"
else:
return f"{type_to_c(type.element)} {variable}[]"
return f"{type_to_c(type.element, variable)}[]"


@type_to_c.register(FixedArray)
def type_to_c_fixed_array(type: FixedArray, variable: Optional[str] = None):
if variable is None:
return f"{type_to_c(type.element)}[{type.n}]"
else:
return f"{type_to_c(type.element)} {variable}[{type.n}]"
return f"{type_to_c(type.element, variable)}[{type.n}]"


def space_variable(variable: Optional[str] = None):
Expand Down
6 changes: 6 additions & 0 deletions src/tensora/desugar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .ast import Assignment, Integer, Float, Scalar, Tensor, Add, Multiply, Contract
from .collect_lattices import collect_lattices
from .desugar_expression import desugar_assignment
from .id import Id
from .to_identifiable import to_identifiable
from .to_iteration_graph import to_iteration_graph
2 changes: 1 addition & 1 deletion src/tensora/desugar/collect_lattices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from functools import singledispatch
from tensora.format.format import Format

from tensora.iteration_graph.merge_lattice.merge_lattice import LatticeConjuction, LatticeDisjunction, LatticeLeaf, Lattice
from . import ast
from ..iteration_graph import LatticeConjuction, LatticeDisjunction, LatticeLeaf, Lattice
from ..iteration_graph.identifiable_expression import ast as id


Expand Down
6 changes: 3 additions & 3 deletions src/tensora/desugar/id.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass

__all__ = ["Id"]

from dataclasses import dataclass


@dataclass(frozen=True)
Expand All @@ -10,4 +10,4 @@ class Id:

def to_tensor_leaf(self):
from tensora.iteration_graph.identifiable_expression.tensor_leaf import TensorLeaf
return TensorLeaf(self.name, self.instance)
return TensorLeaf(self.name, self.instance)
4 changes: 2 additions & 2 deletions src/tensora/desugar/to_identifiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from functools import singledispatch

from tensora.format.format import Format
from ..iteration_graph.identifiable_expression import ast as id
from . import ast as desugar
from ..format import Format
from ..iteration_graph.identifiable_expression import ast as id


def to_identifiable(assignment: desugar.Assignment, input_formats: dict[str, Format], output_format: Format) -> id.Assignment:
Expand Down
11 changes: 4 additions & 7 deletions src/tensora/desugar/to_iteration_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
from dataclasses import replace
from functools import singledispatch
from itertools import count
from typing import Dict, Iterator, Tuple
from .collect_lattices import collect_lattices
from tensora.format.format import Format, Mode
from tensora.iteration_graph.identifiable_expression.tensor_leaf import TensorLeaf

from tensora.iteration_graph.merge_lattice.merge_lattice import Lattice, LatticeLeaf
from typing import Dict, Iterator

from . import ast
from ..iteration_graph import iteration_graph as graph
from .collect_lattices import collect_lattices
from ..iteration_graph import Lattice, LatticeLeaf, iteration_graph as graph
from ..iteration_graph.identifiable_expression import ast as id
from ..format import Format


def to_iteration_graph(assignment: ast.Assignment, formats: dict[str, Format], output_format: Format) -> graph.IterationGraph:
Expand Down
3 changes: 3 additions & 0 deletions src/tensora/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ast import Statement, Expression, Assignable, Variable, AttributeAccess, ArrayIndex, IntegerLiteral, FloatLiteral, BooleanLiteral, ModeLiteral, ArrayLiteral, Add, Subtract, Multiply, Equal, NotEqual, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, And, Or, FunctionCall, Max, Min, Address, BooleanToInteger, Allocate, ArrayAllocate, ArrayReallocate, Free, Declaration, Assignment, DeclarationAssignment, Block, Branch, Loop, Break, Return, FunctionDefinition
from .builder import SourceBuilder
from .peephole import peephole
4 changes: 2 additions & 2 deletions src/tensora/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from contextlib import contextmanager
from typing import List, Dict, Optional, Union

from tensora.ir.ast import Statement, FunctionDefinition, Variable, Declaration, Block, Expression, Branch, Loop
from tensora.ir.types import Type
from .ast import Statement, FunctionDefinition, Variable, Declaration, Block, Expression, Branch, Loop
from .types import Type


class Builder:
Expand Down
3 changes: 3 additions & 0 deletions src/tensora/iteration_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .iteration_graph_to_c_code import generate_c_code, KernelType
from .problem import Problem
from .merge_lattice import Lattice, LatticeLeaf, LatticeConjuction, LatticeDisjunction, IterationMode
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .ast import Assignment, Expression, Variable
from .exhaust_tensors import exhaust_tensors
from .index_dimension import index_dimension
from .index_dimensions import index_dimensions
from .tensor_leaf import TensorLeaf
from .to_ir import to_ir
from .variables import to_c_code
4 changes: 2 additions & 2 deletions src/tensora/iteration_graph/identifiable_expression/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from dataclasses import dataclass
from typing import Tuple

from tensora import Mode
from tensora.iteration_graph.identifiable_expression.tensor_leaf import TensorLeaf
from ...format import Mode
from .tensor_leaf import TensorLeaf


class Node:
Expand Down
2 changes: 1 addition & 1 deletion src/tensora/iteration_graph/merge_lattice/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .merge_lattice import *
from .merge_lattice import Lattice, LatticeLeaf, LatticeConjuction, LatticeDisjunction, IterationMode
8 changes: 4 additions & 4 deletions src/tensora/iteration_graph/problem.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
__all__ = ["Problem"]

from dataclasses import dataclass
from typing import Dict

from tensora import Format
from tensora.iteration_graph.identifiable_expression import Assignment
from tensora.iteration_graph.identifiable_expression.index_dimension import index_dimension
from tensora.iteration_graph.identifiable_expression.index_dimensions import index_dimensions
from ..format import Format
from .identifiable_expression import Assignment, index_dimension, index_dimensions


@dataclass(frozen=True)
Expand Down
52 changes: 0 additions & 52 deletions src/tensora/iteration_graph/problem_to_iteration_graph.py

This file was deleted.

26 changes: 26 additions & 0 deletions src/tensora/native.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
__all__ = ["generate_code", "KernelType"]

from typing import Dict
from tensora.expression import parse_assignment
from tensora.format import parse_format
from tensora.desugar import desugar_assignment, to_identifiable, to_iteration_graph
from tensora.iteration_graph import KernelType, generate_c_code, Problem
from tensora.codegen import ast_to_c
from tensora.ir import peephole


def generate_code(assignment: str, output_format: str, input_formats: Dict[str, str], kernel_type: KernelType) -> str:
assignment_parsed = parse_assignment(assignment).unwrap()
input_formats_parsed = {name: parse_format(format).unwrap() for name, format in input_formats.items()}
output_format_parsed = parse_format(output_format).unwrap()

desugar = desugar_assignment(assignment_parsed)

identifiable_assignment = to_identifiable(desugar, input_formats_parsed, output_format_parsed)

graph = to_iteration_graph(desugar, input_formats_parsed, output_format_parsed)
problem = Problem(identifiable_assignment, input_formats_parsed, output_format_parsed)

ir = generate_c_code(problem, graph, kernel_type).finalize()

return ast_to_c(peephole(ir))
30 changes: 0 additions & 30 deletions src/tensora/native_generator.py

This file was deleted.

12 changes: 6 additions & 6 deletions tests/codegen/test_ast_to_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def clean(string: str) -> str:
(Declaration(Variable("x"), float), "double x"),
(Declaration(Variable("x"), tensor), "taco_tensor_t x"),
(Declaration(Variable("x"), hash_table), "hash_table_t x"),
(Declaration(Variable("x"), Pointer(float)), "double * restrict x"),
(Declaration(Variable("x"), Pointer(Pointer(integer))), "int32_t * restrict * restrict x"),
(Declaration(Variable("x"), Array(float)), "double[] x"),
(Declaration(Variable("x"), Array(Array(integer))), "int32_t[][] x"),
(Declaration(Variable("x"), FixedArray(mode, 3)), "taco_mode_t[3] x"),
(Declaration(Variable("x"), FixedArray(FixedArray(mode, 3), 2)), "taco_mode_t[3][2] x"),
(Declaration(Variable("x"), Pointer(float)), "double* restrict x"),
(Declaration(Variable("x"), Pointer(Pointer(integer))), "int32_t* restrict* restrict x"),
(Declaration(Variable("x"), Array(float)), "double x[]"),
(Declaration(Variable("x"), Array(Array(integer))), "int32_t x[][]"),
(Declaration(Variable("x"), FixedArray(mode, 3)), "taco_mode_t x[3]"),
(Declaration(Variable("x"), FixedArray(FixedArray(mode, 3), 2)), "taco_mode_t x[3][2]"),
# Assignment
(Assignment(Variable("x"), Variable("y")), "x = y"),
(Assignment(Variable("x"), Add(Variable("x"), IntegerLiteral(1))), "x++"),
Expand Down
8 changes: 0 additions & 8 deletions tests/test_combinatorically.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,3 @@ def test_matrix_multiply_add(dense1, dense2, dense3, format1, format2, format3,
in1=(dense1, format1),
in2=(dense2, format2),
in3=(dense3, format3))

@pytest.mark.parametrize('dense1', [[[[0, 2, 4], [0, -1, 0]], [[0, 0, 0], [0, 0, 0]]], [[[0, 2, 4], [0, -1, 0]], [[0, 0, 0], [0, 0, 0]]]])
@pytest.mark.parametrize('format1', ['sss'])
@pytest.mark.parametrize('format_out', ['ss'])
def test_inner_contract(dense1, format1, format_out):
assert_same_as_dense('out(i,k) = in1(i,j,k) * in2(i,j,k)', format_out,
in1=(dense1, format1),
in2=(dense1, format1))
10 changes: 10 additions & 0 deletions tests/test_native.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from tensora.native import generate_code, KernelType


def test_native_codegen():
assignment = "f(i) = A0(i) + A1(i,j) * x(j) + A2(i,k,l) * x(k) * x(l)"
output_format = "d"
input_formats = {"A0": "d", "A1": "ds", "A2": "dss", "x": "d"}
kernel_type = KernelType.compute
code = generate_code(assignment, output_format, input_formats, kernel_type)
assert isinstance(code, str)
Loading

0 comments on commit 2655b99

Please sign in to comment.