Skip to content

Commit

Permalink
Use Returns Result in more places
Browse files Browse the repository at this point in the history
  • Loading branch information
drhagen committed Dec 2, 2023
1 parent 9a9fb3a commit 7632df9
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 26 deletions.
14 changes: 8 additions & 6 deletions src/tensora/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from parsita import ParseError
from returns.result import Failure, Success

from .desugar import DiagonalAccessError, NoKernelFoundError
from .expression import parse_assignment
from .format import parse_named_format
from .generate import TensorCompiler, generate_c_code
Expand Down Expand Up @@ -108,11 +107,14 @@ def tensora(
raise NotImplementedError()

# Generate code
try:
code = generate_c_code(problem, kernel_types, tensor_compiler)
except (DiagonalAccessError, NoKernelFoundError) as error:
typer.echo(error, err=True)
raise typer.Exit(1)
match generate_c_code(problem, kernel_types, tensor_compiler):
case Failure(error):
typer.echo(str(error), err=True)
raise typer.Exit(1)
case Success(code):
pass
case _:
raise NotImplementedError()

if output_path is None:
typer.echo(code)
Expand Down
11 changes: 7 additions & 4 deletions src/tensora/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from weakref import WeakKeyDictionary

from cffi import FFI
from returns.result import Failure, Success

from tensora.generate import TensorCompiler

from .generate import generate_c_code
from .generate import TensorCompiler, generate_c_code
from .kernel_type import KernelType
from .problem import Problem

Expand Down Expand Up @@ -116,7 +115,11 @@ def generate_library(
signature, and the second element is the compiled FFILibrary which has a single method `evaluate` which expects
cffi pointers to taco_tensor_t instances in order specified by the list of variable names.
"""
source = generate_c_code(problem, [KernelType.evaluate], compiler)
match generate_c_code(problem, [KernelType.evaluate], compiler):
case Failure(error):
raise error
case Success(source):
pass

# Determine signature
# 1) Find function by name and capture its parameter list
Expand Down
19 changes: 12 additions & 7 deletions src/tensora/desugar/best_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
__all__ = ["best_algorithm"]

from returns.result import Failure, Result, Success

from ..format import Format
from ..iteration_graph.iteration_graph import IterationGraph
from . import ast
from .exceptions import NoKernelFoundError
from .exceptions import DiagonalAccessError, NoKernelFoundError
from .to_iteration_graphs import to_iteration_graphs


def best_algorithm(
assignment: ast.Assignment, formats: dict[str, Format | None]
) -> IterationGraph:
match next(to_iteration_graphs(assignment, formats), None):
case None:
raise NoKernelFoundError()
case graph:
return graph
) -> Result[IterationGraph, DiagonalAccessError | NoKernelFoundError]:
try:
match next(to_iteration_graphs(assignment, formats), None):
case None:
return Failure(NoKernelFoundError())
case graph:
return Success(graph)
except DiagonalAccessError as e:
return Failure(e)
46 changes: 37 additions & 9 deletions src/tensora/generate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
__all__ = ["TensorCompiler", "generate_c_code", "generate_c_code_tensora", "generate_c_code_taco"]

__all__ = [
"TensorCompiler",
"generate_c_code",
"generate_c_code_tensora",
"generate_c_code_taco",
"TacoError",
]

from dataclasses import dataclass
from enum import Enum
from pathlib import Path

from returns.result import Failure, Result, Success

from .desugar import DiagonalAccessError, NoKernelFoundError
from .kernel_type import KernelType
from .problem import Problem

Expand All @@ -18,15 +28,17 @@ def __str__(self) -> str:

def generate_c_code(
problem: Problem, kernel_types: list[KernelType], tensor_compiler: TensorCompiler
) -> str:
) -> Result[str, Exception]:
match tensor_compiler:
case TensorCompiler.tensora:
return generate_c_code_tensora(problem, kernel_types)
case TensorCompiler.taco:
return generate_c_code_taco(problem, kernel_types)


def generate_c_code_tensora(problem: Problem, kernel_types: list[KernelType]) -> str:
def generate_c_code_tensora(
problem: Problem, kernel_types: list[KernelType]
) -> Result[str, DiagonalAccessError | NoKernelFoundError]:
from .codegen import ir_to_c
from .desugar import best_algorithm, desugar_assignment, index_dimensions, to_identifiable
from .ir import SourceBuilder, peephole
Expand All @@ -40,19 +52,35 @@ def generate_c_code_tensora(problem: Problem, kernel_types: list[KernelType]) ->

definition = Definition(output_variable, formats, index_dimensions(desugar))

graph = best_algorithm(desugar, formats)
match best_algorithm(desugar, formats):
case Failure() as result:
return result
case Success(graph):
pass
case _:
raise NotImplementedError()

ir = SourceBuilder()
for kernel_type in kernel_types:
ir.append(generate_ir(definition, graph, kernel_type).finalize())

return ir_to_c(peephole(ir.finalize()))
return Success(ir_to_c(peephole(ir.finalize())))


@dataclass(frozen=True, slots=True)
class TacoError(Exception):
message: str

def __str__(self) -> str:
return self.message


taco_binary = Path(__file__).parent.joinpath("taco/bin/taco")


def generate_c_code_taco(problem: Problem, kernel_types: list[KernelType]) -> str:
def generate_c_code_taco(
problem: Problem, kernel_types: list[KernelType]
) -> Result[str, Exception]:
import subprocess

from .expression import deparse_to_taco
Expand Down Expand Up @@ -86,6 +114,6 @@ def generate_c_code_taco(problem: Problem, kernel_types: list[KernelType]) -> st
)

if result.returncode != 0:
raise RuntimeError(result.stderr)
return Failure(TacoError(result.stderr))

return result.stdout
return Success(result.stdout)

0 comments on commit 7632df9

Please sign in to comment.