Skip to content

Commit

Permalink
Add CLI tests
Browse files Browse the repository at this point in the history
  • Loading branch information
drhagen committed Dec 2, 2023
1 parent c242b23 commit 9a9fb3a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
30 changes: 28 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from tempfile import NamedTemporaryFile

import pytest
from typer.testing import CliRunner

from tensora.cli import app
Expand All @@ -23,9 +24,11 @@ def test_cli():
assert result.stdout.startswith("int32_t compute(taco_tensor_t* restrict y,")


def test_multiple_kernels():
@pytest.mark.parametrize("compiler", [[], ["-c", "tensora"], ["-c", "taco"]])
def test_multiple_kernels(compiler):
result = runner.invoke(
app, ["y(i) = A(i,j) * x(j)", "-t", "compute", "-t", "evaluate", "-t", "assemble"]
app,
["y(i) = A(i,j) * x(j)", "-t", "compute", "-t", "evaluate", "-t", "assemble"] + compiler,
)

assert result.exit_code == 0
Expand All @@ -42,3 +45,26 @@ def test_write_to_file():
assert result.stdout == ""

assert Path(f.name).read_text().startswith("int32_t compute(taco_tensor_t* restrict y,")


@pytest.mark.parametrize("compiler", ["tensora", "taco"])
@pytest.mark.parametrize(
"command",
[
["a(i) = b(i) +"],
["y(i) = A(i,j) * x(j)", "-f=ds"],
["y(i) = A(i,j) * x(j)", "-f=A:d1s2"],
["y(i) = A(i,j) * x(j)", "-f=A:ds", "-f=A:dd"],
["y(i) = A(i,j) * x(j)", "-f=A:d"],
["y(i) = A(i,j) * x(j)", "-f=B:ds"],
["a(i) = A(i,i)"],
["A(i,j) = B(i,j) + C(j,i)", "-f=A:ds", "-f=B:ds", "-f=C:ds"],
],
)
def test_bad_input(compiler, command):
if command[0] == "a(i) = A(i,i)" and compiler == "taco":
pytest.xfail("Taco CLI succeeds but generates invalid code")

result = runner.invoke(app, command + ["-c", compiler], catch_exceptions=False)

assert result.exit_code == 1
2 changes: 1 addition & 1 deletion tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_diagonal_error(evaluate, compiler):

def test_no_solution(evaluate, compiler):
if compiler == TensorCompiler.taco:
pytest.skip("Taco currently segfaults on this")
pytest.xfail("Taco currently segfaults on this")

with pytest.raises(NoKernelFoundError):
tensor_method("A(i,j) = B(i,j) + C(j,i)", dict(B="ds", C="ds"), "ds", compiler)

0 comments on commit 9a9fb3a

Please sign in to comment.