diff --git a/src/tensora/__init__.py b/src/tensora/__init__.py index a858d3f..9b3a383 100644 --- a/src/tensora/__init__.py +++ b/src/tensora/__init__.py @@ -1,4 +1,4 @@ -from .compile import TensorCompiler from .format import Format, Mode from .function import evaluate, evaluate_taco, evaluate_tensora, tensor_method +from .generate import TensorCompiler from .tensor import Tensor diff --git a/src/tensora/cli.py b/src/tensora/cli.py index 2929139..4f1b75e 100644 --- a/src/tensora/cli.py +++ b/src/tensora/cli.py @@ -4,13 +4,15 @@ from typing import Annotated, Optional import typer -from parsita import Failure, Success +from parsita import ParseError +from returns.result import Failure, Success -from .desugar.exceptions import DiagonalAccessError, NoKernelFoundError +from .desugar import DiagonalAccessError, NoKernelFoundError from .expression import parse_assignment -from .format import Format, Mode, parse_format +from .format import parse_named_format +from .generate import TensorCompiler, generate_c_code from .kernel_type import KernelType -from .native import generate_c_code_from_parsed +from .problem import make_problem app = typer.Typer() @@ -43,6 +45,14 @@ def tensora( help="The type of kernel that will be generated. Can be mentioned multiple times.", ), ] = [KernelType.compute], + tensor_compiler: Annotated[ + TensorCompiler, + typer.Option( + "--compiler", + "-c", + help="The tensor algebra compiler to use to generate the kernel.", + ), + ] = TensorCompiler.tensora, output_path: Annotated[ Optional[Path], typer.Option( @@ -61,45 +71,45 @@ def tensora( case Failure(error): typer.echo(f"Failed to parse assignment:\n{error}", err=True) raise typer.Exit(1) - case Success(sugar): - sugar = sugar + case Success(parsed_assignment): + pass + case _: + raise NotImplementedError() # Parse formats parsed_formats = {} for target_format_string in target_format_strings: - split_format = target_format_string.split(":") - if len(split_format) != 2: - typer.echo( - f"Format must be of the form 'target:format_string': {target_format_string}", - err=True, - ) - raise typer.Exit(1) - - target, format_string = split_format + match parse_named_format(target_format_string): + case Failure(ParseError(_) as error): + typer.echo(f"Failed to parse format:\n{error}", err=True) + raise typer.Exit(1) + case Failure(error): + typer.echo(str(error), err=True) + raise typer.Exit(1) + case Success((target, format)): + pass + case _: + raise NotImplementedError() if target in parsed_formats: typer.echo(f"Format for {target} was mentioned multiple times", err=True) raise typer.Exit(1) - match parse_format(format_string): - case Failure(error): - typer.echo(f"Failed to parse format:\n{error}", err=True) - typer.Exit(1) - case Success(format): - parsed_formats[target] = format + parsed_formats[target] = format - # Fill in missing formats with dense formats - # Use the order of variable_orders to determine the parameter order - formats = {} - for variable_name, order in sugar.variable_orders().items(): - if variable_name in parsed_formats: - formats[variable_name] = parsed_formats[variable_name] - else: - formats[variable_name] = Format((Mode.dense,) * order, tuple(range(order))) + # Validate and standardize assignment and formats + match make_problem(parsed_assignment, parsed_formats): + case Failure(error): + typer.echo(str(error), err=True) + raise typer.Exit(1) + case Success(problem): + pass + case _: + raise NotImplementedError() # Generate code try: - code = generate_c_code_from_parsed(sugar, formats, kernel_types) + code = generate_c_code(problem, kernel_types, tensor_compiler) except (DiagonalAccessError, NoKernelFoundError) as error: typer.echo(error, err=True) raise typer.Exit(1) diff --git a/src/tensora/compile.py b/src/tensora/compile.py index 6c9f152..e357b31 100644 --- a/src/tensora/compile.py +++ b/src/tensora/compile.py @@ -1,5 +1,5 @@ __all__ = [ - "taco_kernel", + "generate_library", "allocate_taco_structure", "taco_structure_to_cffi", "take_ownership_of_arrays", @@ -8,20 +8,19 @@ ] import re -import subprocess import tempfile import threading -from enum import Enum, auto from pathlib import Path -from typing import Any, List, Tuple +from typing import Any from weakref import WeakKeyDictionary from cffi import FFI -from .expression import deparse_to_taco -from .expression.ast import Assignment -from .format import Format -from .native import generate_c_code_from_parsed +from tensora.generate import TensorCompiler + +from .generate import generate_c_code +from .kernel_type import KernelType +from .problem import Problem lock = threading.Lock() @@ -97,69 +96,27 @@ tensor_lib = tensor_cdefs.dlopen(None) -def format_to_taco_format(format: Format): - return ( - "".join(mode.character for mode in format.modes) - + ":" - + ",".join(map(str, format.ordering)) - ) - - -class TensorCompiler(Enum): - taco = auto() - tensora = auto() - - -def taco_kernel( - expression: Assignment, - formats: dict[str, Format], - compiler: TensorCompiler = TensorCompiler.tensora, -) -> Tuple[List[str], Any]: - """Call taco with expression and compile resulting function. +def generate_library( + problem: Problem, compiler: TensorCompiler = TensorCompiler.tensora +) -> tuple[list[str], Any]: + """Generate source, compile it, and load it. - Given an expression and a set of formats: - (1) call out to taco to get the source code for the evaluate function that runs that expression for those formats + Given a problem: + (1) invoke the tensor algebra compiler to generate C code for evaluate (2) parse the signature in the source to determine the order of arguments (3) compile the source with cffi (4) return the list of parameter names and the compiled library - Because compilation can take a non-trivial amount of time, the results of this function is cached by a - `functools.lru_cache`, which is configured to store the results of the 256 most recent calls to this function. - Args: - expression: An expression that can parsed by taco. - formats: A frozen set of pairs of strings. It must be a frozen set because `lru_cache` requires that the - arguments be hashable and therefore immutable. The first element of each pair is a variable name; the second - element is the format in taco format (e.g. 'dd:1,0', 'dss:0,1,2'). Scalar variables must not be listed because - taco does not understand them having a format. + problem: A valid tensor algebra expression and associated tensor formats + compiler: The tensor algebra compiler to use to generate the C code Returns: A tuple where the first element is the list of variable names in the order they appear in the function signature, and the second element is the compiled FFILibrary which has a single method `evaluate` which expects cffi pointers to taco_tensor_t instances in order specified by the list of variable names. """ - match compiler: - case TensorCompiler.taco: - expression_string = deparse_to_taco(expression) - format_strings = frozenset( - (parameter_name, format_to_taco_format(format)) - for parameter_name, format in formats.items() - if format.order != 0 # Taco does not like formats for scalars - ) - # Call taco to write the kernels to standard out - result = subprocess.run( - [taco_binary, expression_string, "-print-evaluate", "-print-nocolor"] - + [f"-f={name}:{format}" for name, format in format_strings], - capture_output=True, - text=True, - ) - - if result.returncode != 0: - raise RuntimeError(result.stderr) - - source = result.stdout - case TensorCompiler.tensora: - source = generate_c_code_from_parsed(expression, formats) + source = generate_c_code(problem, [KernelType.evaluate], compiler) # Determine signature # 1) Find function by name and capture its parameter list @@ -197,7 +154,7 @@ def taco_kernel( def allocate_taco_structure( - mode_types: Tuple[int, ...], dimensions: Tuple[int, ...], mode_ordering: Tuple[int, ...] + mode_types: tuple[int, ...], dimensions: tuple[int, ...], mode_ordering: tuple[int, ...] ): """Allocate all parts of a taco tensor except growable arrays. @@ -285,12 +242,12 @@ def allocate_taco_structure( def taco_structure_to_cffi( - indices: List[List[List[int]]], - vals: List[float], + indices: list[list[list[int]]], + vals: list[float], *, - mode_types: Tuple[int, ...], - dimensions: Tuple[int, ...], - mode_ordering: Tuple[int, ...], + mode_types: tuple[int, ...], + dimensions: tuple[int, ...], + mode_ordering: tuple[int, ...], ): """Build a cffi taco tensor from Python data. @@ -485,5 +442,5 @@ def take_ownership_of_tensor(cffi_tensor) -> None: take_ownership_of_tensor_members(cffi_tensor) -def weakly_increasing(list: List[int]): +def weakly_increasing(list: list[int]): return all(x <= y for x, y in zip(list, list[1:])) diff --git a/src/tensora/desugar/best_algorithm.py b/src/tensora/desugar/best_algorithm.py index 3e2e3bc..bcf311c 100644 --- a/src/tensora/desugar/best_algorithm.py +++ b/src/tensora/desugar/best_algorithm.py @@ -7,7 +7,9 @@ from .to_iteration_graphs import to_iteration_graphs -def best_algorithm(assignment: ast.Assignment, formats: dict[str, Format]) -> IterationGraph: +def best_algorithm( + assignment: ast.Assignment, formats: dict[str, Format | None] +) -> IterationGraph: match next(to_iteration_graphs(assignment, formats), None): case None: raise NoKernelFoundError() diff --git a/src/tensora/desugar/to_iteration_graphs.py b/src/tensora/desugar/to_iteration_graphs.py index 18c1926..a98b338 100644 --- a/src/tensora/desugar/to_iteration_graphs.py +++ b/src/tensora/desugar/to_iteration_graphs.py @@ -249,7 +249,7 @@ def merge_assignment( def to_iteration_graphs( - assignment: ast.Assignment, formats: dict[str, Format] + assignment: ast.Assignment, formats: dict[str, Format | None] ) -> Iterator[ig.IterationGraph]: output_format = formats[assignment.target.name] output_layers = { diff --git a/src/tensora/expression/__init__.py b/src/tensora/expression/__init__.py index 5fe3bc9..715e3fb 100644 --- a/src/tensora/expression/__init__.py +++ b/src/tensora/expression/__init__.py @@ -1,4 +1,4 @@ from . import ast from .deparse_to_taco import deparse_to_taco -from .exceptions import InconsistentVariableSizeError, MutatingAssignmentError +from .exceptions import InconsistentDimensionsError, MutatingAssignmentError from .parser import parse_assignment diff --git a/src/tensora/expression/ast.py b/src/tensora/expression/ast.py index d26c541..25152c6 100644 --- a/src/tensora/expression/ast.py +++ b/src/tensora/expression/ast.py @@ -203,7 +203,7 @@ class Assignment: expression: Expression def __post_init__(self): - from .exceptions import InconsistentVariableSizeError, MutatingAssignmentError + from .exceptions import InconsistentDimensionsError, MutatingAssignmentError target_name = self.target.name @@ -215,7 +215,7 @@ def __post_init__(self): for variable in rest: if first.order != variable.order: - raise InconsistentVariableSizeError(self, first, variable) + raise InconsistentDimensionsError(self, first, variable) variable_orders[name] = first.order diff --git a/src/tensora/expression/exceptions.py b/src/tensora/expression/exceptions.py index c9f155b..2b97d6d 100644 --- a/src/tensora/expression/exceptions.py +++ b/src/tensora/expression/exceptions.py @@ -1,4 +1,4 @@ -__all__ = ["MutatingAssignmentError", "InconsistentVariableSizeError"] +__all__ = ["MutatingAssignmentError", "InconsistentDimensionsError"] from dataclasses import dataclass @@ -17,7 +17,7 @@ def __str__(self): @dataclass(frozen=True, slots=True) -class InconsistentVariableSizeError(Exception): +class InconsistentDimensionsError(Exception): assignment: Assignment first: Variable second: Variable diff --git a/src/tensora/expression/parser.py b/src/tensora/expression/parser.py index 6087424..ea58e76 100644 --- a/src/tensora/expression/parser.py +++ b/src/tensora/expression/parser.py @@ -7,7 +7,7 @@ from returns import result from .ast import Add, Assignment, Float, Integer, Multiply, Scalar, Subtract, Tensor -from .exceptions import InconsistentVariableSizeError, MutatingAssignmentError +from .exceptions import InconsistentDimensionsError, MutatingAssignmentError def make_expression(first, rest): @@ -45,10 +45,8 @@ class TensorExpressionParsers(ParserContext, whitespace=r"[ ]*"): def parse_assignment( string: str -) -> result.Result[ - Assignment, ParseError | MutatingAssignmentError | InconsistentVariableSizeError -]: +) -> result.Result[Assignment, ParseError | MutatingAssignmentError | InconsistentDimensionsError]: try: return TensorExpressionParsers.assignment.parse(string) - except (MutatingAssignmentError, InconsistentVariableSizeError) as e: + except (MutatingAssignmentError, InconsistentDimensionsError) as e: return result.Failure(e) diff --git a/src/tensora/format/__init__.py b/src/tensora/format/__init__.py index a282305..8fdcb1a 100644 --- a/src/tensora/format/__init__.py +++ b/src/tensora/format/__init__.py @@ -1,3 +1,4 @@ from .exceptions import InvalidModeOrderingError from .format import Format, Mode -from .parser import parse_format +from .format_order import format_order +from .parser import parse_format, parse_named_format diff --git a/src/tensora/format/format_order.py b/src/tensora/format/format_order.py new file mode 100644 index 0000000..67ecb84 --- /dev/null +++ b/src/tensora/format/format_order.py @@ -0,0 +1,7 @@ +__all__ = ["format_order"] + +from .format import Format + + +def format_order(format: Format | None) -> int | None: + return len(format.modes) if format is not None else None diff --git a/src/tensora/format/parser.py b/src/tensora/format/parser.py index a5476e6..49a49bf 100644 --- a/src/tensora/format/parser.py +++ b/src/tensora/format/parser.py @@ -1,4 +1,4 @@ -__all__ = ["parse_format"] +__all__ = ["parse_format", "parse_named_format"] from parsita import ParseError, ParserContext, lit, reg, rep from parsita.util import constant @@ -30,9 +30,21 @@ class FormatParsers(ParserContext): format = format_without_orderings | format_with_orderings + variable = reg(r"[a-zA-Z_][a-zA-Z0-9_]*") + named_format = variable << ":" & format > tuple -def parse_format(format: str) -> result.Result[Format, ParseError | InvalidModeOrderingError]: + +def parse_format(string: str, /) -> result.Result[Format, ParseError | InvalidModeOrderingError]: + try: + return FormatParsers.format.parse(string) + except InvalidModeOrderingError as e: + return result.Failure(e) + + +def parse_named_format( + string: str, / +) -> result.Result[tuple[str, Format], ParseError | InvalidModeOrderingError]: try: - return FormatParsers.format.parse(format) + return FormatParsers.named_format.parse(string) except InvalidModeOrderingError as e: return result.Failure(e) diff --git a/src/tensora/function.py b/src/tensora/function.py index fef9c17..4ab39a9 100644 --- a/src/tensora/function.py +++ b/src/tensora/function.py @@ -9,11 +9,14 @@ from functools import lru_cache from inspect import Parameter, Signature -from typing import Dict, Tuple -from .compile import TensorCompiler, allocate_taco_structure, taco_kernel, take_ownership_of_arrays -from .expression.ast import Assignment, Scalar +from returns.result import Failure, Success + +from .compile import allocate_taco_structure, generate_library, take_ownership_of_arrays +from .expression.ast import Assignment from .format import Format, parse_format +from .generate import TensorCompiler +from .problem import make_problem from .tensor import Tensor @@ -23,41 +26,15 @@ class PureTensorMethod: def __init__( self, assignment: Assignment, - input_formats: Dict[str, Format], - output_format: Format, + input_formats: dict[str, Format | None], + output_format: Format | None, compiler: TensorCompiler = TensorCompiler.tensora, ): - target_name = assignment.target.name - variable_orders = assignment.variable_orders() - - # Ensure that all parameters are defined - for variable_name in variable_orders.keys(): - if variable_name != target_name and variable_name not in input_formats: - raise ValueError( - f"Variable {variable_name} in {assignment} not listed in parameters" - ) - - # Ensure that no extraneous parameters are defined - for parameter_name in input_formats.keys(): - if parameter_name not in variable_orders: - raise ValueError(f"Parameter {parameter_name} not in {assignment} variables") - - # Verify that parameters have the correct order - for parameter_name, format in input_formats.items(): - if format.order != variable_orders[parameter_name]: - raise ValueError( - f"Parameter {parameter_name} has order {format.order}, but this variable in the " - f"assignment has order {variable_orders[parameter_name]}" - ) - - if isinstance(assignment.target, Scalar): - raise NotImplementedError("Tensora does not support scalar outputs yet") - - if output_format.order != assignment.target.order: - raise ValueError( - f"Output parameter has order {output_format.order}, but the output variable in the " - f"assignment has order {assignment.target.order}" - ) + match make_problem(assignment, {assignment.target.name: output_format, **input_formats}): + case Failure(error): + raise error + case Success(problem): + pass # Store validated attributes self.assignment = assignment @@ -73,8 +50,7 @@ def __init__( ) # Compile taco function - all_formats = {self.assignment.target.name: output_format, **input_formats} - self.parameter_order, self.cffi_lib = taco_kernel(assignment, all_formats, compiler) + self.parameter_order, self.cffi_lib = generate_library(problem, compiler) def __call__(self, *args, **kwargs): # Handle arguments like normal Python function @@ -152,7 +128,7 @@ def __call__(self, *args, **kwargs): def tensor_method( assignment: str, - input_formats: Dict[str, str], + input_formats: dict[str, str], output_format: str, compiler: TensorCompiler = TensorCompiler.tensora, ) -> PureTensorMethod: @@ -164,7 +140,7 @@ def tensor_method( @lru_cache() def cachable_tensor_method( assignment: str, - input_formats: Tuple[Tuple[str, str], ...], + input_formats: tuple[tuple[str, str], ...], output_format: str, compiler: TensorCompiler, ) -> PureTensorMethod: diff --git a/src/tensora/generate.py b/src/tensora/generate.py new file mode 100644 index 0000000..4cc16c1 --- /dev/null +++ b/src/tensora/generate.py @@ -0,0 +1,91 @@ +__all__ = ["TensorCompiler", "generate_c_code", "generate_c_code_tensora", "generate_c_code_taco"] + +from enum import Enum +from pathlib import Path + +from .kernel_type import KernelType +from .problem import Problem + + +class TensorCompiler(str, Enum): + # Python 3.10 does not support StrEnum, so do it manually + taco = "taco" + tensora = "tensora" + + def __str__(self) -> str: + return self.name + + +def generate_c_code( + problem: Problem, kernel_types: list[KernelType], tensor_compiler: TensorCompiler +) -> str: + match tensor_compiler: + case TensorCompiler.tensora: + return generate_c_code_tensora(problem, kernel_types) + case TensorCompiler.taco: + return generate_c_code_taco(problem, kernel_types) + + +def generate_c_code_tensora(problem: Problem, kernel_types: list[KernelType]) -> str: + from .codegen import ir_to_c + from .desugar import best_algorithm, desugar_assignment, index_dimensions, to_identifiable + from .ir import SourceBuilder, peephole + from .iteration_graph import Definition, generate_ir + + formats = problem.formats + + desugar = desugar_assignment(problem.assignment) + + output_variable = to_identifiable(desugar.target, formats) + + definition = Definition(output_variable, formats, index_dimensions(desugar)) + + graph = best_algorithm(desugar, formats) + + ir = SourceBuilder() + for kernel_type in kernel_types: + ir.append(generate_ir(definition, graph, kernel_type).finalize()) + + return ir_to_c(peephole(ir.finalize())) + + +taco_binary = Path(__file__).parent.joinpath("taco/bin/taco") + + +def generate_c_code_taco(problem: Problem, kernel_types: list[KernelType]) -> str: + import subprocess + + from .expression import deparse_to_taco + + formats = problem.formats + + expression_string = deparse_to_taco(problem.assignment) + format_string_arguments = [] + for name, format in formats.items(): + if format is not None and format.order != 0: # Taco does not like formats for scalars + mode_string = "".join(mode.character for mode in format.modes) + ordering_string = ",".join(map(str, format.ordering)) + format_string_arguments.append(f"-f={name}:{mode_string}:{ordering_string}") + + kernel_type_arguments = [] + for kernel_type in kernel_types: + if kernel_type == KernelType.evaluate: + kernel_type_arguments.append("-print-evaluate") + elif kernel_type == KernelType.compute: + kernel_type_arguments.append("-print-compute") + elif kernel_type == KernelType.assemble: + kernel_type_arguments.append("-print-assembly") + + # Call taco to write the kernels to standard out + result = subprocess.run( + [taco_binary, expression_string, "-print-nocolor"] + + kernel_type_arguments + + format_string_arguments, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(result.stderr) + + return result.stdout diff --git a/src/tensora/iteration_graph/definition.py b/src/tensora/iteration_graph/definition.py index d8ac093..0293e7e 100644 --- a/src/tensora/iteration_graph/definition.py +++ b/src/tensora/iteration_graph/definition.py @@ -15,5 +15,5 @@ class TensorDimension: @dataclass(frozen=True, slots=True) class Definition: output_variable: Variable - formats: dict[str, Format] + formats: dict[str, Format | None] indexes: dict[str, TensorDimension] diff --git a/src/tensora/native.py b/src/tensora/native.py deleted file mode 100644 index 38cbb15..0000000 --- a/src/tensora/native.py +++ /dev/null @@ -1,29 +0,0 @@ -__all__ = ["generate_c_code_from_parsed"] - -from .codegen import ir_to_c -from .desugar import best_algorithm, desugar_assignment, index_dimensions, to_identifiable -from .expression.ast import Assignment -from .format import Format -from .ir import SourceBuilder, peephole -from .iteration_graph import Definition, generate_ir -from .kernel_type import KernelType - - -def generate_c_code_from_parsed( - assignment: Assignment, - formats: dict[str, Format], - kernel_types: list[KernelType] = [KernelType.evaluate], -) -> str: - desugar = desugar_assignment(assignment) - - output_variable = to_identifiable(desugar.target, formats) - - problem = Definition(output_variable, formats, index_dimensions(desugar)) - - graph = best_algorithm(desugar, formats) - - ir = SourceBuilder() - for kernel_type in kernel_types: - ir.append(generate_ir(problem, graph, kernel_type).finalize()) - - return ir_to_c(peephole(ir.finalize())) diff --git a/src/tensora/problem.py b/src/tensora/problem.py new file mode 100644 index 0000000..0ff1cb3 --- /dev/null +++ b/src/tensora/problem.py @@ -0,0 +1,113 @@ +__all__ = [ + "Problem", + "make_problem", + "IncorrectDimensionsError", + "UndefinedReferenceError", + "UnusedFormatError", +] + +from dataclasses import dataclass + +from returns.result import Failure, Result, Success + +from .expression.ast import Assignment +from .format import Format, Mode, format_order + + +@dataclass(frozen=True, slots=True) +class IncorrectDimensionsError(Exception): + name: str + actual: int | None + expected: int | None + assignment: Assignment + + def __str__(self): + actual = self.actual if self.actual is not None else "scalar" + expected = self.expected if self.expected is not None else "scalar" + + return ( + f"Expected each reference in an assignment to have a number of indexes matching the " + f"order of the corresponding format, but variable {self.name} referenced in " + f"{self.assignment} indexes has order {actual} while its format has order {expected}" + ) + + +@dataclass(frozen=True, slots=True) +class UndefinedReferenceError(Exception): + name: str + assignment: Assignment + formats: list[str] + + def __str__(self): + return ( + f"Excepted each reference in an assignment to have a corresponding format, " + f"but variable {self.name} referenced in {self.assignment} was not found among the " + f"given formats {self.formats}" + ) + + +@dataclass(frozen=True, slots=True) +class UnusedFormatError(Exception): + name: str + assignment: Assignment + + def __str__(self): + return ( + f"Expected each format to be referenced in the assignment, " + f"but format {self.name} was not referenced in {self.assignment}" + ) + + +@dataclass(frozen=True, slots=True) +class Problem: + assignment: Assignment + formats: dict[str, Format | None] + + def __post_init__(self): + # This intentionally allows for names in formats that are not referenced in the assignment. + # The CLI and porcelain API will not allow this, but this is just as valid as defining a + # function with unused parameters. + + tensor_orders = self.assignment.variable_orders() + for name, order in tensor_orders.items(): + if name not in self.formats: + raise UndefinedReferenceError(name, self.assignment, list(self.formats.keys())) + elif order != format_order(self.formats[name]): + raise IncorrectDimensionsError( + name, format_order(self.formats[name]), order, self.assignment + ) + + +def make_problem( + assignment: Assignment, formats: dict[str, Format | None] +) -> Result[Problem, UnusedFormatError | UndefinedReferenceError | IncorrectDimensionsError]: + """Create a Problem while filling in default formats. + + This does three things that the `Problem` constructor does not do: + 1. It reorders the formats to match the order the variables appear in the assignment. + 2. It fills in any missing formats with all dense modes. + 3. It raises an exception if there are formats not referenced in the assignment. + """ + + tensor_orders = assignment.variable_orders() + + for name in formats.keys(): + if name not in tensor_orders: + return Failure(UnusedFormatError(name, assignment)) + + new_formats = {} + for name, order in tensor_orders.items(): + if name not in formats: + if order is None: + new_formats[name] = None + else: + new_formats[name] = Format(tuple([Mode.dense] * order), tuple(range(order))) + else: + new_formats[name] = formats[name] + + try: + problem = Problem(assignment, new_formats) + except (UndefinedReferenceError, IncorrectDimensionsError) as error: + return Failure(error) + + return Success(problem) diff --git a/tests/test_expression.py b/tests/test_expression.py index e77a4c6..1e4bd4c 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1,7 +1,7 @@ import pytest from tensora.expression import ( - InconsistentVariableSizeError, + InconsistentDimensionsError, MutatingAssignmentError, parse_assignment, ) @@ -71,7 +71,7 @@ def test_mutating_assignment(): ], ) def test_inconsistent_variable_size(assignment): - assert isinstance(parse_assignment(assignment).failure(), InconsistentVariableSizeError) + assert isinstance(parse_assignment(assignment).failure(), InconsistentDimensionsError) def parse(string): diff --git a/tests/test_format.py b/tests/test_format.py index ff24990..691dd1e 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -2,7 +2,7 @@ from returns.result import Failure from tensora import Format, Mode -from tensora.format import InvalidModeOrderingError, parse_format +from tensora.format import InvalidModeOrderingError, parse_format, parse_named_format format_strings = [ ("", Format((), ())), @@ -27,12 +27,28 @@ def test_deparse_format(string, format): assert actual == string -@pytest.mark.parametrize("string", ["df", "1d0s", "d0s", "d0s1s1", "d1s2s3"]) +@pytest.mark.parametrize("string", ["df", "1d0s", "d0s", "d0s1s1", "d1s2s3", "d3d1d2"]) def test_parse_bad_format(string): actual = parse_format(string) assert isinstance(actual, Failure) +def test_parse_named_format(): + actual = parse_named_format("A:d1s0s2").unwrap() + assert actual == ("A", Format((Mode.dense, Mode.compressed, Mode.compressed), (1, 0, 2))) + + +def test_parse_bad_named_format(): + actual = parse_named_format("d1s0s2s3") + assert isinstance(actual, Failure) + + +@pytest.mark.parametrize("string", ["A:d0s", "A:d3d1d2"]) +def test_parse_bad_ordering_in_named_format(string): + actual = parse_named_format(string) + assert isinstance(actual, Failure) + + def test_format_attributes(): format = Format((Mode.dense, Mode.compressed), (1, 0)) diff --git a/tests/test_function.py b/tests/test_function.py index 25b0a9d..d95bb2d 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -3,8 +3,9 @@ import pytest -from tensora import Tensor, TensorCompiler, evaluate_taco, evaluate_tensora, tensor_method +from tensora import Tensor, evaluate_taco, evaluate_tensora, tensor_method from tensora.desugar import DiagonalAccessError, NoKernelFoundError +from tensora.generate import TensorCompiler pytestmark = pytest.mark.parametrize( ("evaluate", "compiler"),