diff --git a/src/tensora/cli.py b/src/tensora/cli.py index 9e24f72..2929139 100644 --- a/src/tensora/cli.py +++ b/src/tensora/cli.py @@ -6,6 +6,7 @@ import typer from parsita import Failure, Success +from .desugar.exceptions import DiagonalAccessError, NoKernelFoundError from .expression import parse_assignment from .format import Format, Mode, parse_format from .kernel_type import KernelType @@ -97,7 +98,11 @@ def tensora( formats[variable_name] = Format((Mode.dense,) * order, tuple(range(order))) # Generate code - code = generate_c_code_from_parsed(sugar, formats, kernel_types) + try: + code = generate_c_code_from_parsed(sugar, formats, kernel_types) + except (DiagonalAccessError, NoKernelFoundError) as error: + typer.echo(error, err=True) + raise typer.Exit(1) if output_path is None: typer.echo(code) diff --git a/src/tensora/desugar/exceptions.py b/src/tensora/desugar/exceptions.py index e0d2f73..2a17518 100644 --- a/src/tensora/desugar/exceptions.py +++ b/src/tensora/desugar/exceptions.py @@ -12,7 +12,7 @@ class DiagonalAccessError(Exception): def __str__(self) -> str: return ( f"Diagonal access to a tensor (i.e. repeating the same index within a tensor) is not " - f"currently supported: {self.tensor.name}({', '.join(self.tensor.indexes)}" + f"currently supported: {self.tensor.name}({', '.join(self.tensor.indexes)})" ) @@ -20,6 +20,6 @@ def __str__(self) -> str: class NoKernelFoundError(Exception): def __str__(self) -> str: return ( - "Was unable to find a kernel to solve the given tensor expression. This is likely " - "due to sparse tensors needing to be iterated in opposite orders." + "Tensora's tensor algebra compiler was unable to find a kernel for the given problem. " + "This is likely due to sparse tensors needing to be iterated in opposite orders." )