Skip to content

Commit

Permalink
Add concept of Problem
Browse files Browse the repository at this point in the history
* Use Problem as validation of assignment and formats
* Use this validation before invoking code generation
* Centralize code generation
* Add option to CLI to choose TACO as the tensor compiler
  • Loading branch information
drhagen committed Dec 2, 2023
1 parent dd2c73f commit c242b23
Show file tree
Hide file tree
Showing 20 changed files with 343 additions and 188 deletions.
2 changes: 1 addition & 1 deletion src/tensora/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .compile import TensorCompiler
from .format import Format, Mode
from .function import evaluate, evaluate_taco, evaluate_tensora, tensor_method
from .generate import TensorCompiler
from .tensor import Tensor
70 changes: 40 additions & 30 deletions src/tensora/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from typing import Annotated, Optional

import typer
from parsita import Failure, Success
from parsita import ParseError
from returns.result import Failure, Success

from .desugar.exceptions import DiagonalAccessError, NoKernelFoundError
from .desugar import DiagonalAccessError, NoKernelFoundError
from .expression import parse_assignment
from .format import Format, Mode, parse_format
from .format import parse_named_format
from .generate import TensorCompiler, generate_c_code
from .kernel_type import KernelType
from .native import generate_c_code_from_parsed
from .problem import make_problem

app = typer.Typer()

Expand Down Expand Up @@ -43,6 +45,14 @@ def tensora(
help="The type of kernel that will be generated. Can be mentioned multiple times.",
),
] = [KernelType.compute],
tensor_compiler: Annotated[
TensorCompiler,
typer.Option(
"--compiler",
"-c",
help="The tensor algebra compiler to use to generate the kernel.",
),
] = TensorCompiler.tensora,
output_path: Annotated[
Optional[Path],
typer.Option(
Expand All @@ -61,45 +71,45 @@ def tensora(
case Failure(error):
typer.echo(f"Failed to parse assignment:\n{error}", err=True)
raise typer.Exit(1)
case Success(sugar):
sugar = sugar
case Success(parsed_assignment):
pass
case _:
raise NotImplementedError()

# Parse formats
parsed_formats = {}
for target_format_string in target_format_strings:
split_format = target_format_string.split(":")
if len(split_format) != 2:
typer.echo(
f"Format must be of the form 'target:format_string': {target_format_string}",
err=True,
)
raise typer.Exit(1)

target, format_string = split_format
match parse_named_format(target_format_string):
case Failure(ParseError(_) as error):
typer.echo(f"Failed to parse format:\n{error}", err=True)
raise typer.Exit(1)
case Failure(error):
typer.echo(str(error), err=True)
raise typer.Exit(1)
case Success((target, format)):
pass
case _:
raise NotImplementedError()

if target in parsed_formats:
typer.echo(f"Format for {target} was mentioned multiple times", err=True)
raise typer.Exit(1)

match parse_format(format_string):
case Failure(error):
typer.echo(f"Failed to parse format:\n{error}", err=True)
typer.Exit(1)
case Success(format):
parsed_formats[target] = format
parsed_formats[target] = format

# Fill in missing formats with dense formats
# Use the order of variable_orders to determine the parameter order
formats = {}
for variable_name, order in sugar.variable_orders().items():
if variable_name in parsed_formats:
formats[variable_name] = parsed_formats[variable_name]
else:
formats[variable_name] = Format((Mode.dense,) * order, tuple(range(order)))
# Validate and standardize assignment and formats
match make_problem(parsed_assignment, parsed_formats):
case Failure(error):
typer.echo(str(error), err=True)
raise typer.Exit(1)
case Success(problem):
pass
case _:
raise NotImplementedError()

# Generate code
try:
code = generate_c_code_from_parsed(sugar, formats, kernel_types)
code = generate_c_code(problem, kernel_types, tensor_compiler)
except (DiagonalAccessError, NoKernelFoundError) as error:
typer.echo(error, err=True)
raise typer.Exit(1)
Expand Down
89 changes: 23 additions & 66 deletions src/tensora/compile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = [
"taco_kernel",
"generate_library",
"allocate_taco_structure",
"taco_structure_to_cffi",
"take_ownership_of_arrays",
Expand All @@ -8,20 +8,19 @@
]

import re
import subprocess
import tempfile
import threading
from enum import Enum, auto
from pathlib import Path
from typing import Any, List, Tuple
from typing import Any
from weakref import WeakKeyDictionary

from cffi import FFI

from .expression import deparse_to_taco
from .expression.ast import Assignment
from .format import Format
from .native import generate_c_code_from_parsed
from tensora.generate import TensorCompiler

from .generate import generate_c_code
from .kernel_type import KernelType
from .problem import Problem

lock = threading.Lock()

Expand Down Expand Up @@ -97,69 +96,27 @@
tensor_lib = tensor_cdefs.dlopen(None)


def format_to_taco_format(format: Format):
return (
"".join(mode.character for mode in format.modes)
+ ":"
+ ",".join(map(str, format.ordering))
)


class TensorCompiler(Enum):
taco = auto()
tensora = auto()


def taco_kernel(
expression: Assignment,
formats: dict[str, Format],
compiler: TensorCompiler = TensorCompiler.tensora,
) -> Tuple[List[str], Any]:
"""Call taco with expression and compile resulting function.
def generate_library(
problem: Problem, compiler: TensorCompiler = TensorCompiler.tensora
) -> tuple[list[str], Any]:
"""Generate source, compile it, and load it.
Given an expression and a set of formats:
(1) call out to taco to get the source code for the evaluate function that runs that expression for those formats
Given a problem:
(1) invoke the tensor algebra compiler to generate C code for evaluate
(2) parse the signature in the source to determine the order of arguments
(3) compile the source with cffi
(4) return the list of parameter names and the compiled library
Because compilation can take a non-trivial amount of time, the results of this function is cached by a
`functools.lru_cache`, which is configured to store the results of the 256 most recent calls to this function.
Args:
expression: An expression that can parsed by taco.
formats: A frozen set of pairs of strings. It must be a frozen set because `lru_cache` requires that the
arguments be hashable and therefore immutable. The first element of each pair is a variable name; the second
element is the format in taco format (e.g. 'dd:1,0', 'dss:0,1,2'). Scalar variables must not be listed because
taco does not understand them having a format.
problem: A valid tensor algebra expression and associated tensor formats
compiler: The tensor algebra compiler to use to generate the C code
Returns:
A tuple where the first element is the list of variable names in the order they appear in the function
signature, and the second element is the compiled FFILibrary which has a single method `evaluate` which expects
cffi pointers to taco_tensor_t instances in order specified by the list of variable names.
"""
match compiler:
case TensorCompiler.taco:
expression_string = deparse_to_taco(expression)
format_strings = frozenset(
(parameter_name, format_to_taco_format(format))
for parameter_name, format in formats.items()
if format.order != 0 # Taco does not like formats for scalars
)
# Call taco to write the kernels to standard out
result = subprocess.run(
[taco_binary, expression_string, "-print-evaluate", "-print-nocolor"]
+ [f"-f={name}:{format}" for name, format in format_strings],
capture_output=True,
text=True,
)

if result.returncode != 0:
raise RuntimeError(result.stderr)

source = result.stdout
case TensorCompiler.tensora:
source = generate_c_code_from_parsed(expression, formats)
source = generate_c_code(problem, [KernelType.evaluate], compiler)

# Determine signature
# 1) Find function by name and capture its parameter list
Expand Down Expand Up @@ -197,7 +154,7 @@ def taco_kernel(


def allocate_taco_structure(
mode_types: Tuple[int, ...], dimensions: Tuple[int, ...], mode_ordering: Tuple[int, ...]
mode_types: tuple[int, ...], dimensions: tuple[int, ...], mode_ordering: tuple[int, ...]
):
"""Allocate all parts of a taco tensor except growable arrays.
Expand Down Expand Up @@ -285,12 +242,12 @@ def allocate_taco_structure(


def taco_structure_to_cffi(
indices: List[List[List[int]]],
vals: List[float],
indices: list[list[list[int]]],
vals: list[float],
*,
mode_types: Tuple[int, ...],
dimensions: Tuple[int, ...],
mode_ordering: Tuple[int, ...],
mode_types: tuple[int, ...],
dimensions: tuple[int, ...],
mode_ordering: tuple[int, ...],
):
"""Build a cffi taco tensor from Python data.
Expand Down Expand Up @@ -485,5 +442,5 @@ def take_ownership_of_tensor(cffi_tensor) -> None:
take_ownership_of_tensor_members(cffi_tensor)


def weakly_increasing(list: List[int]):
def weakly_increasing(list: list[int]):
return all(x <= y for x, y in zip(list, list[1:]))
4 changes: 3 additions & 1 deletion src/tensora/desugar/best_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from .to_iteration_graphs import to_iteration_graphs


def best_algorithm(assignment: ast.Assignment, formats: dict[str, Format]) -> IterationGraph:
def best_algorithm(
assignment: ast.Assignment, formats: dict[str, Format | None]
) -> IterationGraph:
match next(to_iteration_graphs(assignment, formats), None):
case None:
raise NoKernelFoundError()
Expand Down
2 changes: 1 addition & 1 deletion src/tensora/desugar/to_iteration_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def merge_assignment(


def to_iteration_graphs(
assignment: ast.Assignment, formats: dict[str, Format]
assignment: ast.Assignment, formats: dict[str, Format | None]
) -> Iterator[ig.IterationGraph]:
output_format = formats[assignment.target.name]
output_layers = {
Expand Down
2 changes: 1 addition & 1 deletion src/tensora/expression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import ast
from .deparse_to_taco import deparse_to_taco
from .exceptions import InconsistentVariableSizeError, MutatingAssignmentError
from .exceptions import InconsistentDimensionsError, MutatingAssignmentError
from .parser import parse_assignment
4 changes: 2 additions & 2 deletions src/tensora/expression/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class Assignment:
expression: Expression

def __post_init__(self):
from .exceptions import InconsistentVariableSizeError, MutatingAssignmentError
from .exceptions import InconsistentDimensionsError, MutatingAssignmentError

target_name = self.target.name

Expand All @@ -215,7 +215,7 @@ def __post_init__(self):

for variable in rest:
if first.order != variable.order:
raise InconsistentVariableSizeError(self, first, variable)
raise InconsistentDimensionsError(self, first, variable)

variable_orders[name] = first.order

Expand Down
4 changes: 2 additions & 2 deletions src/tensora/expression/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["MutatingAssignmentError", "InconsistentVariableSizeError"]
__all__ = ["MutatingAssignmentError", "InconsistentDimensionsError"]

from dataclasses import dataclass

Expand All @@ -17,7 +17,7 @@ def __str__(self):


@dataclass(frozen=True, slots=True)
class InconsistentVariableSizeError(Exception):
class InconsistentDimensionsError(Exception):
assignment: Assignment
first: Variable
second: Variable
Expand Down
8 changes: 3 additions & 5 deletions src/tensora/expression/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from returns import result

from .ast import Add, Assignment, Float, Integer, Multiply, Scalar, Subtract, Tensor
from .exceptions import InconsistentVariableSizeError, MutatingAssignmentError
from .exceptions import InconsistentDimensionsError, MutatingAssignmentError


def make_expression(first, rest):
Expand Down Expand Up @@ -45,10 +45,8 @@ class TensorExpressionParsers(ParserContext, whitespace=r"[ ]*"):

def parse_assignment(
string: str
) -> result.Result[
Assignment, ParseError | MutatingAssignmentError | InconsistentVariableSizeError
]:
) -> result.Result[Assignment, ParseError | MutatingAssignmentError | InconsistentDimensionsError]:
try:
return TensorExpressionParsers.assignment.parse(string)
except (MutatingAssignmentError, InconsistentVariableSizeError) as e:
except (MutatingAssignmentError, InconsistentDimensionsError) as e:
return result.Failure(e)
3 changes: 2 additions & 1 deletion src/tensora/format/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .exceptions import InvalidModeOrderingError
from .format import Format, Mode
from .parser import parse_format
from .format_order import format_order
from .parser import parse_format, parse_named_format
7 changes: 7 additions & 0 deletions src/tensora/format/format_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
__all__ = ["format_order"]

from .format import Format


def format_order(format: Format | None) -> int | None:
return len(format.modes) if format is not None else None
18 changes: 15 additions & 3 deletions src/tensora/format/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["parse_format"]
__all__ = ["parse_format", "parse_named_format"]

from parsita import ParseError, ParserContext, lit, reg, rep
from parsita.util import constant
Expand Down Expand Up @@ -30,9 +30,21 @@ class FormatParsers(ParserContext):

format = format_without_orderings | format_with_orderings

variable = reg(r"[a-zA-Z_][a-zA-Z0-9_]*")
named_format = variable << ":" & format > tuple

def parse_format(format: str) -> result.Result[Format, ParseError | InvalidModeOrderingError]:

def parse_format(string: str, /) -> result.Result[Format, ParseError | InvalidModeOrderingError]:
try:
return FormatParsers.format.parse(string)
except InvalidModeOrderingError as e:
return result.Failure(e)


def parse_named_format(
string: str, /
) -> result.Result[tuple[str, Format], ParseError | InvalidModeOrderingError]:
try:
return FormatParsers.format.parse(format)
return FormatParsers.named_format.parse(string)
except InvalidModeOrderingError as e:
return result.Failure(e)
Loading

0 comments on commit c242b23

Please sign in to comment.