Reapply "[FRONTEND] added support for tuples (#5220)" #3043

Merged
merged 12 commits on Dec 21, 2024
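
For context, a minimal sketch (not part of this diff) of the tuple support being reapplied: tuples can be passed as kernel arguments, indexed with `tl.static_range`, and unpacked inside `@triton.jit` functions, mirroring the new tests in `python/test/unit/language/test_tuple.py` below. The kernel and tensor names here are hypothetical; the device string follows the new tests.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _store_pair(Ptrs, values):
    # Ptrs and values are tuples of equal length; static_range unrolls the loop
    # so each tuple element can be indexed at compile time.
    for i in tl.static_range(len(values)):
        tl.store(Ptrs[i], values[i])


outs = tuple(torch.zeros(1, dtype=torch.float32, device="xpu") for _ in range(2))
_store_pair[(1, )](outs, (3.0, 4.0))
assert outs[0].item() == 3.0 and outs[1].item() == 4.0
```
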
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -6,6 +6,7 @@ repos:
- id: check-symlinks
- id: destroyed-symlinks
- id: trailing-whitespace
exclude: .*.patch
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
1 change: 1 addition & 0 deletions python/src/ir.cc
@@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) {
"Function argument index out of range");
return self.getArgument(idx);
})
.def("get_num_args", &FuncOp::getNumArguments)
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
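
A hedged sketch of how the new `get_num_args` binding might be exercised from the Python side; `fn` is assumed to be a Triton `FuncOp` handle obtained from the internal IR builder during compilation (the acquisition path is not part of this diff, and the helper name is illustrative).

```python
def check_arity(fn, expected: int) -> None:
    # fn is assumed to be a FuncOp handle exposing the binding added above;
    # get_num_args mirrors MLIR's FuncOp::getNumArguments.
    num_args = fn.get_num_args()
    assert num_args == expected, f"expected {expected} arguments, got {num_args}"
```
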
49 changes: 25 additions & 24 deletions python/test/unit/language/test_compile_errors.py
@@ -17,7 +17,7 @@ def kernel():
a += 1 # noqa

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "is not defined" in str(e.value), "error should mention the undefined variable"
@@ -32,7 +32,7 @@ def kernel():
0 + "a"

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 0"
@@ -47,7 +47,7 @@ def kernel():
tl.static_assert(isinstance(0, tl.tensor))

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert isinstance(e.value, CompileTimeAssertionFailure)
@@ -66,7 +66,7 @@ def kernel():
not (0, 0)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert e.value.__cause__ is None
@@ -83,7 +83,7 @@ def kernel():
1.0 << 1

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 1.0"
@@ -107,7 +107,7 @@ def kernel():
nested_call()

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
inner = e.value.__cause__
@@ -130,7 +130,7 @@ def kernel():
tl.expand_dims(None, -1)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
inner = e.value.__cause__
@@ -157,7 +157,7 @@ def kernel():
a = two_returns()
a + tl.arange(0, 4) # only works if we took the first return

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


def test_not_const_annotate_no_err():
@@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
def kernel(N: int = 1):
pass

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))


@triton.jit
@@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 4)

triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0}))

@triton.jit
def kernel2(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 8)

triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1}))


@triton.jit
@@ -211,7 +211,7 @@ def kernel(N: int):
returns_branched_on_non_constexpr(N)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the function call"
@@ -227,7 +227,7 @@ def kernel():
tl.arange(2, 7)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert str(e.value.__cause__) == "arange's range must be a power of 2"


@@ -238,7 +238,7 @@ def kernel():
tl.full((33, ), 0, dtype=tl.int64)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"


@@ -251,7 +251,7 @@ def kernel():
a = CAPTURED # noqa

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert "CAPTURED is not defined" in str(e.value)


@@ -265,7 +265,7 @@ def kernel():
a = GLOBAL # noqa

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert "global variable" in str(e.value)


@@ -279,7 +279,7 @@ def kernel():
a = CONSTEXPR_ANNOTATED_GLOBAL # noqa

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


CONSTEXPR_GLOBAL = tl.constexpr(42)
@@ -292,7 +292,7 @@ def kernel():
a = CONSTEXPR_GLOBAL # noqa

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


TYPE_ALIAS = tl.pointer_type(tl.int32)
@@ -305,7 +305,7 @@ def kernel():
a = TYPE_ALIAS # noqa

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


def test_global_access_in_fn_default_arg():
Expand All @@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
pass

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={}))


def test_defaults_assign_no_err():
Expand All @@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
def kernel(a=1, B: tl.constexpr = ""):
pass

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""}))


def test_where_warning(fresh_triton_cache):
Expand All @@ -337,7 +337,7 @@ def kernel():
tl.where(a, b, c)

with pytest.warns(UserWarning):
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
@@ -371,7 +371,8 @@ def dtype_kernel(dtype: tl.constexpr):
ctx = pytest.raises(CompilationError, match="")

with ctx as e:
triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))
triton.compile(
triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))

if dtype not in supported_dtypes:
try:
@@ -426,7 +427,7 @@ def dot_kernel():
tl.dot(a, b, max_num_imprecise_acc=128)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={}))
try:
assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
except AssertionError as assertion_err:
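
The test changes above reflect the reworked `ASTSource` interface: constexpr parameters are now declared in `signature` as `"constexpr"` and their values are passed through `constexprs` rather than `constants`. A minimal sketch of the new call shape (kernel name and values are illustrative):

```python
import triton
import triton.language as tl


@triton.jit
def _iota(BLOCK: tl.constexpr):
    tl.arange(0, BLOCK)


src = triton.compiler.ASTSource(
    fn=_iota,
    signature={"BLOCK": "constexpr"},
    constexprs={"BLOCK": 16},
)
triton.compile(src)
```
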
23 changes: 19 additions & 4 deletions python/test/unit/language/test_core.py
@@ -4494,15 +4494,17 @@ def kernel(x):
def test_value_specialization(value: int, value_type: str, device) -> None:

def repr(specialization):
spec_type = specialization.signature["VALUE"]
return f"kernel_{spec_type}"
ty = specialization.signature["value1"]
cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
return f"kernel_{ty}_{cst}"

@triton.jit(repr=repr)
def kernel(VALUE, X):
def kernel(value1, is_one, X):
pass

x = torch.tensor([3.14159], device=device)
h = kernel[(1, )](value, x)
h = kernel[(1, )](value, 1, x)
assert "is_one" in h.name
assert value_type in h.name


@@ -6346,6 +6348,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))


def test_dtype(device):

@triton.jit
def kernel(X):
dtype_x: tl.constexpr = X.dtype.element_ty
tl.static_assert(dtype_x == tl.int32)
tl.static_assert(dtype_x == tl.constexpr(tl.int32))
tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32))

X = torch.zeros(1, dtype=torch.int32, device=device)
kernel[(1, )](X)


def test_side_effectful_scan(device):
if device != "cuda":
pytest.xfail()
2 changes: 1 addition & 1 deletion python/test/unit/language/test_decorator.py
@@ -23,7 +23,7 @@ def kernel():
pass

try:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
except Exception as e:
pytest.fail(f"triton compile failed with error: {e}")

100 changes: 100 additions & 0 deletions python/test/unit/language/test_tuple.py
@@ -0,0 +1,100 @@
import pytest
import triton
import triton.language as tl
import torch


@triton.jit
def _tuple_increment(values):
for i in tl.static_range(len(values)):
values[i] = values[i] + 1
return values


@triton.jit
def _tuple_index_func(Ptrs, values):
for i in tl.static_range(len(values)):
tl.store(Ptrs[i], values[i])


@triton.jit
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
values = _tuple_increment(values)
_tuple_index_func(Ptrs, values)


@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
def test_index(size, device="xpu"):
vals = tuple([i + 1 for i in range(size)])
rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
_tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
assert vals == tuple([x.item() - 1 for x in rets])


# ----


@triton.jit
def _tuple_assign(XPtrs, YPtrs, values):
# assign from tuple
X0, X1 = XPtrs
x0, x1 = values
tl.store(X0, x0)
tl.store(X1, x1)
# assign to tuple
Y0, Y1, Y2 = YPtrs
Y = Y0, Y1, Y2
y = x0, 10, x1
tl.store(Y[0], y[0])
tl.store(Y[1], y[1])
tl.store(Y[2], y[2])


def test_assign(device="xpu"):
vals = (2., 3.)
x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
_tuple_assign[(1, )](x, y, vals)
assert x[0] == vals[0]
assert x[1] == vals[1]
assert y[0] == vals[0]
assert y[1] == 10
assert y[2] == vals[1]


# -------


@triton.jit
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
tl.static_assert(tuple1[1] is None)
tl.store(Ptr + 5, cst2)
tl.store(Ptr + 6, tuple1[0])
tl.store(Ptr + 7, tl.load(tuple1[2][0]))
tl.store(Ptr + 8, tuple1[2][1][0])
tl.store(Ptr + 9, tl.load(tuple1[2][1][2]))


# test serialization/deserialization of tuple arguments in
# the frontend.
@triton.jit
def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
tl.static_assert(N1 is None)
tl.static_assert(tuple1[1][1] is None)
tl.store(Ptr + 0, tl.load(tuple1[0]))
tl.store(Ptr + 1, tuple1[1][0])
tl.store(Ptr + 2, tl.load(tuple1[1][2]))
tl.store(Ptr + 3, cst1 + val1)
tl.store(Ptr + 4, tl.load(tuple2[0]))
_tuple_fn0(Ptr, 15, (-1, None, tuple1))


def test_serialize(device="xpu"):
x0 = torch.tensor([8], dtype=torch.int32, device=device)
x1 = torch.tensor([12], dtype=torch.int32, device=device)
y0 = torch.tensor([10], dtype=torch.int32, device=device)
z = torch.empty((10, ), dtype=torch.int32, device=device)
# we want to check that JIT specialization propagates to tuples:
_tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, ))
ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
assert torch.equal(z, ref)