diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea237fc848..6f10daec56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,7 @@ repos: - id: check-symlinks - id: destroyed-symlinks - id: trailing-whitespace + exclude: .*.patch - id: end-of-file-fixer - id: check-yaml - id: check-toml diff --git a/python/src/ir.cc b/python/src/ir.cc index 23bb86e5eb..53ba39ae10 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) { "Function argument index out of range"); return self.getArgument(idx); }) + .def("get_num_args", &FuncOp::getNumArguments) .def( "add_entry_block", [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index f0ab578cbd..1d2f692841 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -17,7 +17,7 @@ def kernel(): a += 1 # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "is not defined" in str(e.value), "error should mention the undefined variable" @@ -32,7 +32,7 @@ def kernel(): 0 + "a" with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the 0" @@ -47,7 +47,7 @@ def kernel(): tl.static_assert(isinstance(0, tl.tensor)) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert isinstance(e.value, CompileTimeAssertionFailure) @@ -66,7 +66,7 @@ def kernel(): not (0, 0) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert e.value.__cause__ is None @@ -83,7 +83,7 @@ def kernel(): 1.0 << 1 with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the 1.0" @@ -107,7 +107,7 @@ def kernel(): nested_call() with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: inner = e.value.__cause__ @@ -130,7 +130,7 @@ def kernel(): tl.expand_dims(None, -1) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: inner = e.value.__cause__ @@ -157,7 +157,7 @@ def kernel(): a = two_returns() a + tl.arange(0, 4) # only works if we took the first return - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) def test_not_const_annotate_no_err(): @@ -166,7 +166,7 @@ def 
test_not_const_annotate_no_err(): def kernel(N: int = 1): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) @triton.jit @@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 4) - triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0})) + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0})) @triton.jit def kernel2(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 8) - triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1})) + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1})) @triton.jit @@ -211,7 +211,7 @@ def kernel(N: int): returns_branched_on_non_constexpr(N) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the function call" @@ -227,7 +227,7 @@ def kernel(): tl.arange(2, 7) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert str(e.value.__cause__) == "arange's range must be a power of 2" @@ -238,7 +238,7 @@ def kernel(): tl.full((33, ), 0, dtype=tl.int64) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" @@ -251,7 +251,7 @@ def kernel(): a = CAPTURED # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert "CAPTURED is not defined" in str(e.value) @@ -265,7 +265,7 @@ def kernel(): a = GLOBAL # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert "global variable" in str(e.value) @@ -279,7 +279,7 @@ def kernel(): a = CONSTEXPR_ANNOTATED_GLOBAL # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) CONSTEXPR_GLOBAL = tl.constexpr(42) @@ -292,7 +292,7 @@ def kernel(): a = CONSTEXPR_GLOBAL # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) TYPE_ALIAS = tl.pointer_type(tl.int32) @@ -305,7 +305,7 @@ def kernel(): a = TYPE_ALIAS # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) def test_global_access_in_fn_default_arg(): @@ -315,7 +315,7 @@ def kernel(a=GLOBAL): pass # No error. 
- triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={})) def test_defaults_assign_no_err(): @@ -324,7 +324,7 @@ def test_defaults_assign_no_err(): def kernel(a=1, B: tl.constexpr = ""): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""})) def test_where_warning(fresh_triton_cache): @@ -337,7 +337,7 @@ def kernel(): tl.where(a, b, c) with pytest.warns(UserWarning): - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) @@ -371,7 +371,8 @@ def dtype_kernel(dtype: tl.constexpr): ctx = pytest.raises(CompilationError, match="") with ctx as e: - triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype})) + triton.compile( + triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) if dtype not in supported_dtypes: try: @@ -426,7 +427,7 @@ def dot_kernel(): tl.dot(a, b, max_num_imprecise_acc=128) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={})) try: assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") except AssertionError as assertion_err: diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 345178ccf6..7354d663a4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4494,15 +4494,17 @@ def kernel(x): def test_value_specialization(value: int, value_type: str, device) -> None: def repr(specialization): - spec_type = specialization.signature["VALUE"] - return f"kernel_{spec_type}" + ty = specialization.signature["value1"] + cst = '_'.join([k for k, v in specialization.constants.items() if v == 1]) + return f"kernel_{ty}_{cst}" @triton.jit(repr=repr) - def kernel(VALUE, X): + def kernel(value1, is_one, X): pass x = torch.tensor([3.14159], device=device) - h = kernel[(1, )](value, x) + h = kernel[(1, )](value, 1, x) + assert "is_one" in h.name assert value_type in h.name @@ -6346,6 +6348,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) +def test_dtype(device): + + @triton.jit + def kernel(X): + dtype_x: tl.constexpr = X.dtype.element_ty + tl.static_assert(dtype_x == tl.int32) + tl.static_assert(dtype_x == tl.constexpr(tl.int32)) + tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32)) + + X = torch.zeros(1, dtype=torch.int32, device=device) + kernel[(1, )](X) + + def test_side_effectful_scan(device): if device != "cuda": pytest.xfail() diff --git a/python/test/unit/language/test_decorator.py b/python/test/unit/language/test_decorator.py index fbbfb71446..42207cc1fa 100644 --- a/python/test/unit/language/test_decorator.py +++ b/python/test/unit/language/test_decorator.py @@ -23,7 +23,7 @@ def kernel(): pass try: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, 
constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) except Exception as e: pytest.fail(f"triton compile failed with error: {e}") diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py new file mode 100644 index 0000000000..7eef4b4cd3 --- /dev/null +++ b/python/test/unit/language/test_tuple.py @@ -0,0 +1,100 @@ +import pytest +import triton +import triton.language as tl +import torch + + +@triton.jit +def _tuple_increment(values): + for i in tl.static_range(len(values)): + values[i] = values[i] + 1 + return values + + +@triton.jit +def _tuple_index_func(Ptrs, values): + for i in tl.static_range(len(values)): + tl.store(Ptrs[i], values[i]) + + +@triton.jit +def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): + values = _tuple_increment(values) + _tuple_index_func(Ptrs, values) + + +@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) +def test_index(size, device="xpu"): + vals = tuple([i + 1 for i in range(size)]) + rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) + _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) + assert vals == tuple([x.item() - 1 for x in rets]) + + +# ---- + + +@triton.jit +def _tuple_assign(XPtrs, YPtrs, values): + # assign from tuple + X0, X1 = XPtrs + x0, x1 = values + tl.store(X0, x0) + tl.store(X1, x1) + # assign to tuple + Y0, Y1, Y2 = YPtrs + Y = Y0, Y1, Y2 + y = x0, 10, x1 + tl.store(Y[0], y[0]) + tl.store(Y[1], y[1]) + tl.store(Y[2], y[2]) + + +def test_assign(device="xpu"): + vals = (2., 3.) + x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) + y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) + _tuple_assign[(1, )](x, y, vals) + assert x[0] == vals[0] + assert x[1] == vals[1] + assert y[0] == vals[0] + assert y[1] == 10 + assert y[2] == vals[1] + + +# ------- + + +@triton.jit +def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): + tl.static_assert(tuple1[1] is None) + tl.store(Ptr + 5, cst2) + tl.store(Ptr + 6, tuple1[0]) + tl.store(Ptr + 7, tl.load(tuple1[2][0])) + tl.store(Ptr + 8, tuple1[2][1][0]) + tl.store(Ptr + 9, tl.load(tuple1[2][1][2])) + + +# test serialization/deserialization of tuple arguments in +# the frontend. 
+@triton.jit +def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2): + tl.static_assert(N1 is None) + tl.static_assert(tuple1[1][1] is None) + tl.store(Ptr + 0, tl.load(tuple1[0])) + tl.store(Ptr + 1, tuple1[1][0]) + tl.store(Ptr + 2, tl.load(tuple1[1][2])) + tl.store(Ptr + 3, cst1 + val1) + tl.store(Ptr + 4, tl.load(tuple2[0])) + _tuple_fn0(Ptr, 15, (-1, None, tuple1)) + + +def test_serialize(device="xpu"): + x0 = torch.tensor([8], dtype=torch.int32, device=device) + x1 = torch.tensor([12], dtype=torch.int32, device=device) + y0 = torch.tensor([10], dtype=torch.int32, device=device) + z = torch.empty((10, ), dtype=torch.int32, device=device) + # we want to check that JIT specialization propagates to tuples: + _tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, )) + ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device) + assert torch.equal(z, ref) diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index 206d132301..e621eefc01 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -63,15 +63,12 @@ def walk_fn(op): backend = triton.compiler.compiler.make_backend(target) src = triton.compiler.compiler.ASTSource( fn=kernel, - signature={ - kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(args) - if i not in kernel.constexprs - }, - constants={kernel.arg_names[i]: arg - for i, arg in enumerate(args) - if not isinstance(arg, torch.Tensor)}, - attrs=backend.get_attrs_descriptor(args, kernel.params), + signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args)}, + constexprs={kernel.arg_names[i]: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=backend.get_attrs_descriptor(kernel.params, args), ) context = triton._C.libtriton.ir.context() diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index ab08dfe86b..7bcd1c1b0e 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -592,10 +592,10 @@ def cache_hook(*args, **kwargs): JITFunction.cache_hook = cache_hook # In warmup we assume that the pointer range is 32 bits kernel_add.warmup(torch.float32, grid=(1, )) - assert pointer_range_32 == [0] + assert pointer_range_32 == [(0, )] # Torch tensor > 2GB kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) assert len(pointer_range_32) == 0 # Torch tensor <= 2GB kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) - assert pointer_range_32 == [0] + assert pointer_range_32 == [(0, )] diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index e6868563e7..fa90f18129 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -19,8 +19,8 @@ def kernel_sub(a, b, o, N: tl.constexpr): src = ASTSource( fn=kernel_sub, - constants={'N': 32}, - signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"}, + constexprs={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'}, attrs=attrs, ) triton.compile(src=src, target=target) @@ -44,7 +44,7 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={}) + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constexprs={}) triton.compile(src=src, target=target) @@ -65,7 +65,7 @@ 
def empty_kernel(): import gc gc.collect() - src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={}) + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constexprs={}) triton.compile(src=src, target=target) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 6646d94f50..461dcb46b4 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -92,7 +92,7 @@ def matmul_kernel( "stride_cm": "i32", "stride_cn": "i32", }, - constants={}, + constexprs={}, )) captured = capfd.readouterr() @@ -136,8 +136,9 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) "in_ptr2": "*fp16", "in_ptr3": "*fp32", "out_ptr0": "*fp16", + "XBLOCK": "constexpr", }, - constants={"XBLOCK": XBLOCK}, + constexprs={"XBLOCK": XBLOCK}, ), options={"num_warps": 1}, ) diff --git a/python/triton/_utils.py b/python/triton/_utils.py index ca60c8c3cb..0ce1a53a70 100644 --- a/python/triton/_utils.py +++ b/python/triton/_utils.py @@ -20,3 +20,52 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]: idx += size assert idx == len(flat) return ret + + +def find_paths_if(iterable, pred): + from .language import core + is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type)) + ret = dict() + + def _impl(current, path): + path = (path[0], ) if len(path) == 1 else tuple(path) + if is_iterable(current): + for idx, item in enumerate(current): + _impl(item, path + (idx, )) + elif pred(path, current): + if len(path) == 1: + ret[(path[0], )] = current + else: + ret[tuple(path)] = current + + if is_iterable(iterable): + _impl(iterable, []) + elif pred(list(), iterable): + ret = {tuple(): iterable} + else: + ret = dict() + return ret + + +def parse_list_string(s): + s = s.strip() + if s.startswith('[') and s.endswith(']'): + s = s[1:-1] + result = [] + current = '' + depth = 0 + for c in s: + if c == '[': + depth += 1 + current += c + elif c == ']': + depth -= 1 + current += c + elif c == ',' and depth == 0: + result.append(current.strip()) + current = '' + else: + current += c + if current.strip(): + result.append(current.strip()) + return result diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index a24f768192..98e6f95c8c 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -3,11 +3,11 @@ import hashlib import subprocess import sysconfig - from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import Dict, List, Tuple, Union from types import ModuleType +from .._utils import find_paths_if # Table that associates strings to AttrsDescriptor (sub)classes. 
# In this way we can dynamically select the correct class @@ -52,7 +52,8 @@ class AttrsDescriptor: `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant """ - __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') + __slots__ = ('divisibility_16', 'equal_to_1', 'equal_to_none', 'arg_properties', 'property_values', + 'constant_properties') def __init__(self, params=None, values=None): """ @@ -67,6 +68,7 @@ def __init__(self, params=None, values=None): # Default initialization self.arg_properties = {} self.property_values = {} + self.equal_to_none = {} self.constant_properties = set() self._add_common_properties(params, values) @@ -86,17 +88,30 @@ def _add_common_properties(self, params, values): assert (len(params) == len(values)) # Divisibility property - self.arg_properties["tt.divisibility"] = [ - param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) - and not param.do_not_specialize and not param.do_not_specialize_on_alignment - ] + divisibility_16 = [] + for param, arg in zip(params, values): + if param.do_not_specialize or \ + param.do_not_specialize_on_alignment: + continue + paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_divisible_by_16(val)) + divisibility_16 += [(param.num, ) + x for x in paths] + self.arg_properties["tt.divisibility"] = divisibility_16 # Equal to 1 property - self.arg_properties["tt.equal_to"] = [ - param.num - for param, arg in zip(params, values) - if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize - ] + equal_to_1 = [] + for param, arg in zip(params, values): + if param.do_not_specialize: + continue + paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_equal_to_1(val)) + equal_to_1 += [(param.num, ) + x for x in paths] + self.arg_properties["tt.equal_to"] = equal_to_1 + + # Equal to None property + equal_to_none = [] + for param, arg in zip(params, values): + paths = find_paths_if(arg, lambda path, val: val is None) + equal_to_none += [(param.num, ) + x for x in paths] + self.equal_to_none = equal_to_none def _add_backend_properties(self, params=None, values=None): """ This method is for different subclasses to implement their own compile-time properties """ @@ -130,6 +145,8 @@ def get_constants(self) -> Dict: for prop_name in self.constant_properties: for p in self.arg_properties.get(prop_name, []): constants[p] = self.property_values[prop_name] + for v in self.equal_to_none: + constants[v] = None return constants def filter_out_constants(self): @@ -166,7 +183,7 @@ def from_dict(data): """ attrs_descriptor = _descriptor_table[data["cls"]]() for prop_name, param_ids in data["arg_properties"].items(): - attrs_descriptor.arg_properties[prop_name] = param_ids + attrs_descriptor.arg_properties[prop_name] = list(map(tuple, param_ids)) attrs_descriptor._init_slots() return attrs_descriptor diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index e42f9f44e4..20249336e4 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -14,12 +14,15 @@ from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction -from .._utils import list_list_flatten, list_list_unflatten +from .._utils import find_paths_if, list_list_flatten, list_list_unflatten +from functools import reduce from .errors import (CompilationError, 
CompileTimeAssertionFailure, UnsupportedLanguageConstruct) def mangle_ty(ty): + if ty.is_tuple(): + return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T' if ty.is_ptr(): return 'P' + mangle_ty(ty.element_ty) if ty.is_int(): @@ -58,7 +61,7 @@ def _is_triton_tensor(o: Any) -> bool: def _is_constexpr(o: Any) -> bool: - return isinstance(o, constexpr) + return o is None or isinstance(o, (constexpr, language.core.dtype)) def _is_triton_scalar(o: Any) -> bool: @@ -191,11 +194,66 @@ def visit_Call(self, node: ast.Call) -> bool: return self.visit(node.func) +class ASTFunction: + + def get_path(self, x, path): + return reduce(lambda a, idx: a[idx], path, x) + + def set_path(self, x, path, val): + prev = x if len(path) == 1 else self.get_path(x, path[:-1]) + prev[path[-1]] = val + + def __init__(self, ret_types, arg_types, constexprs, constants, attrs): + self.ret_types = ret_types + self.arg_types = arg_types + self.constexprs = constexprs + self.constants = constants + self.attrs = attrs + + def serialize(self, builder: ir.builder): + # fill up IR values in template + # > build function + is_val = lambda path, _: path not in self.constexprs and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(arg_types, ret_types) + + def deserialize(self, fn): + # create "template" + def make_template(val): + if isinstance(val, (list, tuple, language.tuple_type)): + return language.tuple([make_template(x) for x in val]) + return language.constexpr(None) + + vals = make_template(self.arg_types) + is_val = lambda path, _: path not in self.constexprs and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + # > set attributes + for attr_path, attr_specs in self.attrs.items(): + for attr_name, attr_val in attr_specs: + if attr_path in val_paths: + fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + if isinstance(ty, nv_tma_desc_type): + fn.set_arg_attr(i, "tt.nv_tma_desc", 1) + # > add IR values to the template + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + self.set_path(vals, path, language.tensor(fn.args(i), ty)) + # > add constexpr values to the template + constants = self.constants | self.constexprs + for path, val in constants.items(): + self.set_path(vals, path, language.constexpr(val)) + return vals + + class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, - codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, - noinline=False, file_name: Optional[str] = None, begin_line=0): + def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, + module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, + file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -225,8 +283,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.gscope[k] = v self.lscope = {} - self.attributes = attributes - self.constants = constants self.jit_fn = jit_fn self.function_name = function_name self.is_kernel = is_kernel @@ -344,7 
+400,6 @@ def visit_compound_statement(self, stmts): stmts = [stmts] for stmt in stmts: self.visit(stmt) - # Stop parsing as soon as we hit a `return` statement; everything # after this is dead code. if isinstance(stmt, ast.Return): @@ -356,7 +411,7 @@ def visit_Module(self, node): def visit_List(self, node): ctx = self.visit(node.ctx) assert ctx is None - elts = [self.visit(elt) for elt in node.elts] + elts = language.tuple([self.visit(elt) for elt in node.elts]) return elts # By design, only non-kernel functions can return @@ -365,16 +420,15 @@ def visit_Return(self, node): if ret_value is None: self.builder.ret([]) ret_ty = language.void - elif isinstance(ret_value, tuple): - ret_values = [semantic.to_tensor(v, self.builder) for v in ret_value] + elif isinstance(ret_value, language.tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value.values] ret_types = [v.type for v in ret_values] self.builder.ret([v.handle for v in ret_values]) - ret_ty = tuple(ret_types) + ret_ty = language.tuple_type(ret_types) else: ret = semantic.to_tensor(ret_value, self.builder) self.builder.ret([ret.handle]) ret_ty = ret.type - if self.ret_type is None: self.ret_type = ret_ty elif self.ret_type != ret_ty: @@ -399,7 +453,6 @@ def visit_FunctionDef(self, node): init_node = ast.Assign(targets=[st_target], value=default_value) else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) - try: assert not self.visiting_arg_default_value self.visiting_arg_default_value = True @@ -409,34 +462,15 @@ def visit_FunctionDef(self, node): # initialize function visibility = "public" if self.is_kernel else "private" - self.fn = self.builder.get_or_insert_function(self.module, self.function_name, - self.prototype.to_ir(self.builder), visibility, self.noinline) + fn_ty = self.prototype.serialize(self.builder) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline) self.module.push_back(self.fn) entry = self.fn.add_entry_block() - arg_values = [] - idx = 0 - for i in range(len(arg_names)): - if i in self.constants: - cst = self.constants[i] - if not _is_constexpr(cst): - cst = constexpr(self.constants[i]) - arg_values.append(cst) - continue - else: - if i in self.attributes: - for name, value in self.attributes[i]: - self.fn.set_arg_attr(idx, name, value) - - # Mark this argument as a pass-by-value TMA descriptor (nvidia) - if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): - self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) - - arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) - idx += 1 - - insert_pt = self.builder.get_insertion_block() + arg_values = self.prototype.deserialize(self.fn) + # bind arguments to symbols for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) + insert_pt = self.builder.get_insertion_block() self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) @@ -447,8 +481,11 @@ def visit_FunctionDef(self, node): self.ret_type = language.void self.builder.ret([]) else: - self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] - self.fn.reset_type(self.prototype.to_ir(self.builder)) + if isinstance(self.ret_type, language.tuple_type): + self.prototype.ret_types = self.ret_type.types + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.serialize(self.builder)) 
self.builder.ret([ self.builder.create_poison(ty.to_ir(self.builder)) for ty in self.prototype.ret_types @@ -480,37 +517,41 @@ def visit_AnnAssign(self, node): if target in self.lscope: raise ValueError(f'{target} is already defined.' f' constexpr cannot be reassigned.') - if not _is_constexpr(value): - value = constexpr(value) + value = constexpr(value) self.lscope[target] = value return self.lscope[target] # default: call visit_Assign return self.visit_Assign(node) + def assignTarget(self, target, value): + if isinstance(target, ast.Subscript): + assert target.ctx.__class__.__name__ == "Store" + return self.visit_Subscript_Store(target, value) + if isinstance(target, ast.Tuple): + assert target.ctx.__class__.__name__ == "Store" + for i, name in enumerate(target.elts): + self.set_value(self.visit(name), value.values[i]) + return + assert isinstance(target, ast.Name) + self.set_value(self.visit(target), value) + def visit_Assign(self, node): - _names = [] - if isinstance(node, ast.AnnAssign): - _names += [self.visit(node.target)] - else: - for target in node.targets: - _names += [self.visit(target)] - if len(_names) > 1: - raise self._unsupported(node, "simultaneous multiple assignment is not supported.") - names = _names[0] - values = self.visit(node.value) - if not _is_list_like(names): - names = [names] - if not _is_list_like(values): - values = [values] - native_nontensor_types = (language.dtype, ) - for name, value in zip(names, values): - # by default, constexpr are assigned into python variable + # construct values to assign + def _sanitize_value(value): + if isinstance(value, language.tuple): + return language.tuple([_sanitize_value(v) for v in value.values]) + native_nontensor_types = (language.dtype, language.tuple) value = _unwrap_if_constexpr(value) if value is not None and \ - not _is_triton_value(value) and \ - not isinstance(value, native_nontensor_types): - value = semantic.to_tensor(value, self.builder) - self.set_value(name, value) + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): + value = language.semantic.to_tensor(value, self.builder) + return value + + values = _sanitize_value(self.visit(node.value)) + targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets + assert len(targets) == 1 + self.assignTarget(targets[0], values) def visit_AugAssign(self, node): name = node.target.id @@ -533,7 +574,7 @@ def visit_Load(self, node): def visit_Tuple(self, node): args = [self.visit(x) for x in node.elts] - return tuple(args) + return language.tuple(args) def _apply_binary_method(self, method_name, lhs, rhs): # TODO: raise something meaningful if getattr fails below, esp for reverse method @@ -905,7 +946,7 @@ def visit_While(self, node): assert False, "Not implemented" ast.NodeVisitor.generic_visit(self, stmt) - def visit_Subscript(self, node): + def visit_Subscript_Load(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) @@ -913,6 +954,16 @@ def visit_Subscript(self, node): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] + def visit_Subscript_Store(self, node, value): + assert node.ctx.__class__.__name__ == "Store" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + assert isinstance(lhs, language.tuple) + lhs.__setitem__(slices, value) + + def visit_Subscript(self, node): + return self.visit_Subscript_Load(node) + def visit_ExtSlice(self, node): return [self.visit(dim) for dim in node.dims] @@ -1069,7 +1120,7 @@ def 
visit_Slice(self, node): lower = self.visit(node.lower) upper = self.visit(node.upper) step = self.visit(node.step) - return slice(lower, upper, step) + return language.slice(lower, upper, step) def visit_Index(self, node): return self.visit(node.value) @@ -1085,24 +1136,26 @@ def visit_Assert(self, node) -> Any: def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] - # generate function def - attributes = {} - constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) + for i, arg in enumerate(args): + if isinstance(arg, (language.dtype, float, int, bool)): + args[i] = language.core.constexpr(arg) + args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) + args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values() + # mangle + fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) # generate function def if necessary if not self.module.has_function(fn_name): - prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller file_name, begin_line = get_jit_fn_file_line(fn) - generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + arg_types = [ + language.core.constexpr if arg is None or isinstance(arg, + (bool, int, language.core.dtype)) else arg.type + for arg in args + ] + prototype = ASTFunction([], arg_types, args_cst, dict(), dict()) + generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn, + function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, options=self.builder.options, codegen_fns=self.builder.codegen_fns, module_map=self.builder.module_map) @@ -1117,8 +1170,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): else: callee_ret_type = self.function_ret_types[fn_name] symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0 or callee_ret_type is None: + args_val = [arg.handle for arg in args_val] + call_op = self.builder.call(symbol, args_val) + if callee_ret_type is None: return None elif call_op.get_num_results() == 1: return tensor(call_op.get_result(0), callee_ret_type) @@ -1126,8 +1180,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): # should return a tuple of tl.tensor results = [] for i in range(call_op.get_num_results()): - results.append(tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) + results.append(tensor(call_op.get_result(i), callee_ret_type.types[i])) + return language.tuple(results) def visit_Call(self, node): fn = _unwrap_if_constexpr(self.visit(node.func)) @@ -1146,7 +1200,11 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: - return fn(*args, **extra_kwargs, **kws) + ret = fn(*args, **extra_kwargs, **kws) + # builtin functions return plain 
tuples for readability + if isinstance(ret, tuple): + ret = language.tuple(ret) + return ret except Exception as e: # Normally when we raise a CompilationError, we raise it as # `from None`, because the original fileline from the exception @@ -1287,38 +1345,29 @@ def kernel_suffix(signature, specialization): suffix = '' for i, _ in enumerate(signature): suffix += str(i) - if i in specialization.equal_to_1: + if (i, ) in specialization.equal_to_1: suffix += 'c' - if i in specialization.divisibility_16: + if (i, ) in specialization.divisibility_16: suffix += 'd' return suffix def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): + constexprs = specialization.constexprs + arg_idx = lambda x: (fn.arg_names.index(x), ) if isinstance(x, str) else x + constants = specialization.attrs.get_constants() + constexprs = {arg_idx(k): v for k, v in constexprs.items()} + arg_types = [str_to_ty(ty) for ty in specialization.signature.values()] + # find index of constants in serialized order attrs = specialization.attrs - # create kernel prototype - cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in specialization.constants.items()} - # visit kernel AST - gscope = fn.__globals__.copy() - function_name = fn.repr(specialization) - tys = list(specialization.signature.values()) - new_constants = attrs.get_constants() - for k in new_constants: - if k in tys and tys[k] == "i1" and new_constants[k] == 1: - new_constants[k] = True - new_attrs = attrs.filter_out_constants() fn_attrs = new_attrs.get_fn_attrs() - all_constants = constants.copy() - all_constants.update(new_constants) - arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + fn_attrs = {k: v for k, v in fn_attrs.items() if k not in constants} file_name, begin_line = get_jit_fn_file_line(fn) - - prototype = language.function_type([], arg_types) - generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, - begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) + prototype = ASTFunction([], arg_types, constexprs, constants, fn_attrs) + generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(specialization), + jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, + codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) ret = generator.module diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index c9427c78fd..6a3da233be 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -51,12 +51,12 @@ def convert_type_repr(x): class ASTSource: - def __init__(self, fn, signature, constants=None, attrs=None) -> None: + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: self.fn = fn self.ext = "ttir" self.name = fn.__name__ self.signature = signature - self.constants = constants + self.constexprs = constexprs self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} @@ -64,20 +64,19 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: for k in self.signature.keys(): if not isinstance(k, str): raise TypeError("Signature keys must be string") - if self.constants is None: - 
self.constants = {} - else: - for k in self.constants.keys(): - if not isinstance(k, str): - raise TypeError("Constants keys must be string") + if self.constexprs is None: + self.constexprs = {} if self.attrs is None: self.attrs = AttrsDescriptor() + # this is the constexprs plus the specialized constants + spec_constants = {self.fn.arg_names[k[0]]: v for k, v in self.attrs.get_constants().items() if len(k) == 1} + self.constants = self.constexprs | spec_constants def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] # Note - we stringify the keys here to allow sorting to work for cases # where constants have mixed int/str keys. - sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + sorted_constants = sorted((str(k), v) for k, v in self.constexprs.items()) key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() @@ -276,11 +275,11 @@ def compile(src, target=None, options=None): codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() - try: - module = src.make_ir(options, codegen_fns, module_map, context) - except Exception as e: - filter_traceback(e) - raise + # try: + module = src.make_ir(options, codegen_fns, module_map, context) + # except Exception as e: + # filter_traceback(e) + # raise use_ir_loc = os.environ.get("USE_IR_LOC", None) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) @@ -412,7 +411,7 @@ def launch_metadata(self, grid, stream, *args): arg_idx = 0 for i, arg_name in enumerate(self.src.fn.arg_names): if i in self.src.fn.constexprs: - arg_dict[arg_name] = self.src.constants[arg_name] + arg_dict[arg_name] = self.src.constexprs[arg_name] else: arg_dict[arg_name] = args[arg_idx] arg_idx += 1 diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0c8965fc52..5f5d464d63 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,6 +1,7 @@ """isort:skip_file""" # Import order is significant here. +from .._utils import parse_list_string from . import math from . 
import extra from .standard import ( @@ -69,7 +70,6 @@ float8e5, float8e5b16, full, - function_type, gather, histogram, inline_asm_elementwise, @@ -95,6 +95,7 @@ range, reduce, reshape, + slice, split, static_assert, static_print, @@ -102,6 +103,8 @@ store, tensor, trans, + tuple, + tuple_type, uint16, uint32, uint64, @@ -188,7 +191,6 @@ "floor", "fma", "full", - "function_type", "gather", "histogram", "inline_asm_elementwise", @@ -232,6 +234,7 @@ "reduce", "reshape", "rsqrt", + "slice", "sigmoid", "sin", "softmax", @@ -248,6 +251,7 @@ "tensor", "trans", "triton", + "tuple", "uint16", "uint32", "uint64", @@ -264,6 +268,9 @@ def str_to_ty(name): + if name == "none": + return None + if name[0] == "*": name = name[1:] const = False @@ -273,9 +280,17 @@ def str_to_ty(name): ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) + if name[0] == "[": + names = parse_list_string(name) + tys = [str_to_ty(x) for x in names] + return tuple_type(types=tys) + if name == "nvTmaDesc": return nv_tma_desc_type() + if name == "constexpr": + return constexpr + tys = { "fp8e4nv": float8e4nv, "fp8e4b8": float8e4b8, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 13d097dc34..dcce42908c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -140,6 +140,7 @@ def __init__(self, value): self.value = value.value else: self.value = value + self.type = constexpr def __repr__(self) -> str: return f"constexpr[{self.value}]" @@ -473,6 +474,10 @@ def is_ptr(): def is_const(): return False + @staticmethod + def is_tuple(): + return False + def __eq__(self, other: dtype): if not isinstance(other, dtype): return False @@ -608,11 +613,10 @@ def __init__(self, element_ty: dtype, shape: List): # Note that block_type's shape is a list of int # while tensor's shape is a list of constexpr. - - assert (isinstance(shape, list)) + assert (isinstance(shape, (list, tuple))) # shape can be empty ([]) when an input is a 0D tensor. 
- self.shape = _unwrap_shape(shape) + self.shape = tuple(_unwrap_shape(shape)) if not self.shape: raise TypeError('0d block_type is forbidden') @@ -647,19 +651,32 @@ def scalar(self): return self.element_ty -class function_type(dtype): +class tuple_type(dtype): - def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: - self.ret_types = ret_types - self.param_types = param_types + def __init__(self, types): + self.types = types + self.name = f"[{','.join(map(str, self.types))}]" def __str__(self): - return f'fn ({self.param_types}) -> {self.ret_types}' + return self.name + + def __iter__(self): + return iter(self.types) def to_ir(self, builder: ir.builder): - ir_param_types = [ty.to_ir(builder) for ty in self.param_types] - ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] - return builder.get_function_ty(ir_param_types, ret_types) + return [ty.to_ir(builder) for ty in self.types] + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def is_tuple(self): + return True + + +class slice_type(dtype): + + def __init__(self): + self.name = 'slice_type' # scalar types @@ -761,7 +778,7 @@ def __init__(self, handle, type: dtype): self.type = type # Tensor type (can be block_type) # Following the practice in pytorch, dtype is scalar type self.dtype = type.scalar - self.shape = [constexpr(s) for s in self.shape] + self.shape = tuple([constexpr(s) for s in self.shape]) def _flatten_ir(self): return [self.handle] @@ -982,13 +999,16 @@ def __not__(self, _builder=None): @builtin def __getitem__(self, slices, _builder=None): - if isinstance(slices, (slice, constexpr)) or slices is None: + import builtins + if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None: slices = [slices] + if isinstance(slices, tuple): + slices = slices.values ret = self for dim, sl in enumerate(slices): if sl is None or isinstance(sl, constexpr) and sl.value is None: ret = semantic.expand_dims(ret, dim, _builder) - elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None: pass else: raise ValueError(f"unsupported tensor index: {sl}") @@ -1141,6 +1161,77 @@ def flip(self, dim=None) -> tensor: ... 
+class tuple: + + def __init__(self, args: list): + self.values = [i for i in args] + + @property + def type(self): + + def get_type(x): + if isinstance(x, dtype): + return dtype + return x.type + + return tuple_type([get_type(x) for x in self.values]) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + import builtins + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + # TODO: remove + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + if isinstance(other, list): + other = tuple(other) + return tuple(self.values + other.values) + # return tuple(a + b for a, b in zip(self.values, other.values)) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + import builtins + if isinstance(other, (list, builtins.tuple)): + other = tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + import builtins + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + +class slice: + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = slice_type() + + class _experimental_tensor_descriptor_base(_value): """" A tensor descriptor with unknown shape and strides @@ -1556,7 +1647,7 @@ def expand_dims(input, axis, _builder=None): """ input = semantic.to_tensor(input, _builder) axis = _constexpr_to_value(axis) - axes = list(axis) if isinstance(axis, Sequence) else [axis] + axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] new_ndim = len(input.shape) + len(axes) axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] @@ -2210,14 +2301,12 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] def make_combine_region(reduce_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = reduce_op.get_region(0) with _insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] @@ -2311,14 +2400,12 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] def make_combine_region(scan_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = scan_op.get_region(0) with 
_insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 3f97eda1d0..e5dbf16ccb 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -767,14 +767,14 @@ def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> # Add new axes to lhs for _ in range(len(lhs_shape), len(rhs_shape)): lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), - tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) lhs_ty = lhs.type lhs_shape = lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for _ in range(len(rhs_shape), len(lhs_shape)): rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), - tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 07b82df414..36f368426a 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -308,6 +308,8 @@ def mangle_type(arg, is_const=False): return "fp32" elif hasattr(arg, "tma_desc_cpu_ptr"): return "nvTmaDesc" + elif isinstance(arg, tuple): + return "[" + ",".join(map(mangle_type, arg)) + "]" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) @@ -335,8 +337,8 @@ def serialize_specialization_data(name, signature, constants, attrs, options, ke constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} import json obj = { - 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': - options.__dict__, 'key': key + 'name': name, 'signature': signature, 'constant_keys': list(constants.keys()), 'constant_vals': + list(constants.values()), 'attrs': attrs.to_dict(), 'options': options.__dict__, 'key': key } serialized_obj = json.dumps(obj) return serialized_obj @@ -368,6 +370,7 @@ def create_function_from_signature(sig, kparams, backend): func_args.append(f"{name}=default_{name}") dict_entries.append(f"'{name}': {name}") if kp.is_constexpr: + signature_types.append('"constexpr"') constexpr_vals.append(name) else: non_constexpr_vals.append(name) @@ -599,32 +602,23 @@ def run(self, *args, grid, warmup, **kwargs): # done here rather than when we build the signature as otherwise # the kernel cache key could not distinguish between byte pointers # and None arguments, resulting in a downstream mismatch: - sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigkeys = [param.name for param in self.params] sigvals = sig_and_spec[:len(sigkeys)] - signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - - configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) - constant_params = configs[0].get_constants() - constants = { - 
p.name: v - for (v, p) in zip(bound_vals, self.params) - if p.is_constexpr or (p.num in constant_params) or v is None - } - for i, arg in constants.items(): + signature = {k: v for (k, v) in zip(sigkeys, sigvals)} + + attrs = backend.get_attrs_descriptor(self.params, bound_vals) + constexprs = {p.name: v for (v, p) in zip(bound_vals, self.params) if p.is_constexpr} + for i, arg in constexprs.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True): return None # compile the kernel - src = self.ASTSource(self, signature, constants, configs[0]) - kernel = self.compile( - src, - target=target, - options=options.__dict__, - ) + src = self.ASTSource(self, signature, constexprs, attrs) + kernel = self.compile(src, target=target, options=options.__dict__) kernel_cache[key] = kernel - self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) + self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False) # Check that used global values have not changed. not_present = object() @@ -637,15 +631,11 @@ def run(self, *args, grid, warmup, **kwargs): # canonicalize grid assert grid is not None if callable(grid): - # Arguments are passed as a dict to `grid`, by contract. - # TODO(jlebar): In the new launch API, pass the compiler flags as a - # second parameter to `grid`. grid = grid(bound_args) grid_size = len(grid) grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 - # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, @@ -734,9 +724,11 @@ def preload(self, specialization_data): if deserialized_obj['name'] != self.fn.__name__: raise RuntimeError( f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constant_keys = deserialized_obj['constant_keys'] + constant_vals = deserialized_obj['constant_vals'] constants = { key: tl.dtype(value) if tl.dtype.is_dtype(value) else value - for key, value in deserialized_obj['constants'].items() + for key, value in zip(constant_keys, constant_vals) } signature = dict(deserialized_obj['signature'].items()) src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 6adf7794cc..50483b2362 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -91,15 +91,13 @@ def constexpr(s): pass return None - hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} hints = {k: v for k, v in hints.items() if v is not None} constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} constants = {k: v for k, v in constants.items() if v is not None} - signature = { - kernel.arg_names[i]: s.split(":")[0] - for i, s in enumerate(signature) - if kernel.arg_names[i] not in constants - } + signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)} + for key in constants: + signature[key] = 'constexpr' const_sig = 'x'.join([str(v) for v in constants.values()]) 
diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py
index 6adf7794cc..50483b2362 100644
--- a/python/triton/tools/compile.py
+++ b/python/triton/tools/compile.py
@@ -91,15 +91,13 @@ def constexpr(s):
             pass
         return None

-    hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
+    hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
     hints = {k: v for k, v in hints.items() if v is not None}
     constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
     constants = {k: v for k, v in constants.items() if v is not None}
-    signature = {
-        kernel.arg_names[i]: s.split(":")[0]
-        for i, s in enumerate(signature)
-        if kernel.arg_names[i] not in constants
-    }
+    signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
+    for key in constants:
+        signature[key] = 'constexpr'
     const_sig = 'x'.join([str(v) for v in constants.values()])
     doc_string = [f"{k}={v}" for k, v in constants.items()]
     doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
@@ -109,8 +107,8 @@ def constexpr(s):
         assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
     attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints)
     for p, v in attrs.get_constants().items():
-        constants.update({kernel.arg_names[p]: v})
-    src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs)
+        constants.update({kernel.arg_names[p[0]]: v})
+    src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
     opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
     ccinfo = triton.compile(src, options=opts)
     if ccinfo.metadata.global_scratch_size > 0:
@@ -126,7 +124,7 @@ def constexpr(s):
             arg_types.append(signature[arg_name])
             arg_names_not_1.append(arg_name)
             arg_types_not_1.append(signature[arg_name])
-        elif i in attrs.equal_to_1:
+        elif (i, ) in attrs.equal_to_1:
             arg_names.append(arg_name)
             arg_types.append(signature[arg_name])
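In the AOT tool, constexpr arguments now stay in the signature (typed 'constexpr') instead of being dropped, and specialization hints are keyed by one-element index tuples rather than flat integers. A hedged re-implementation of just the parsing step shown above; the helper name `parse_aot_signature` is invented for this sketch and is not part of the tool:

```python
def parse_aot_signature(arg_names, signature_str):
    """Illustrative re-implementation of the signature parsing above.

    '*fp32:16, i32:1, 64' with arg_names ['x', 'n', 'BLOCK'] becomes:
      hints     = {(0,): 16, (1,): 1}
      constants = {'BLOCK': 64}
      signature = {'x': '*fp32', 'n': 'i32', 'BLOCK': 'constexpr'}
    """
    items = [s.strip() for s in signature_str.split(",")]

    def as_int(s):
        try:
            return int(s)
        except ValueError:
            return None

    hints = {(i,): as_int(s.split(":")[1]) for i, s in enumerate(items) if ":" in s}
    hints = {k: v for k, v in hints.items() if v is not None}
    constants = {arg_names[i]: as_int(s) for i, s in enumerate(items)}
    constants = {k: v for k, v in constants.items() if v is not None}
    signature = {arg_names[i]: s.split(":")[0] for i, s in enumerate(items)}
    for key in constants:
        signature[key] = "constexpr"
    return hints, constants, signature
```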
diff --git a/scripts/patch-pytorch.sh b/scripts/patch-pytorch.sh
index 5e35d25441..63be29146f 100755
--- a/scripts/patch-pytorch.sh
+++ b/scripts/patch-pytorch.sh
@@ -6,6 +6,7 @@ set -euo pipefail

 REPO_ROOT="$(git rev-parse --show-toplevel)"
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

 if [[ ! $REPO_ROOT ]]; then
     echo "Failed to identify root of the repository."
@@ -16,3 +17,5 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 cd "$REPO_ROOT"

 curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
+# REVERT ME: it's just a trigger for pytorch rebuild
+git apply "${SCRIPT_DIR}/pytorch.patch"
diff --git a/scripts/pytorch.patch b/scripts/pytorch.patch
new file mode 100644
index 0000000000..25b236dba1
--- /dev/null
+++ b/scripts/pytorch.patch
@@ -0,0 +1,258 @@
+diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py
+index 84264bf1b0..eb3fac0f39 100644
+--- a/test/inductor/test_codegen_triton.py
++++ b/test/inductor/test_codegen_triton.py
+@@ -48,7 +48,7 @@ class TestCodegenTriton(InductorTestCase):
+             return config.divisible_by_16
+
+         self.assertEqual(
+-            (2,),
++            [(2,)],
+             _check_divisibility(
+                 triton_utils.config_of(
+                     [
+@@ -63,7 +63,7 @@ class TestCodegenTriton(InductorTestCase):
+         )
+
+         self.assertEqual(
+-            (0, 2, 4, 5, 6),
++            [(0, 2, 4, 5, 6)],
+             _check_divisibility(
+                 triton_utils.config_of(
+                     [
+diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
+index ace56135fe..568cbde0a1 100644
+--- a/torch/_higher_order_ops/triton_kernel_wrap.py
++++ b/torch/_higher_order_ops/triton_kernel_wrap.py
+@@ -238,7 +238,7 @@ def generate_ttir(
+
+         target = triton.runtime.driver.active.get_current_target()
+         backend = triton.compiler.compiler.make_backend(target)
+-        return backend.get_attrs_descriptor(args, kernel.params)
++        return backend.get_attrs_descriptor(kernel.params, args)
+     except ImportError:
+         return kernel._get_config(*args)
+
+@@ -251,7 +251,6 @@ def generate_ttir(
+     signature = {
+         name: kernel._type_of(kernel._key_of(arg))
+         for i, (name, arg) in enumerate(ordered_args.items())
+-        if i not in kernel.constexprs
+     }
+
+     triton._C.libtriton.ir.load_dialects(context)
+diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
+index 00031a56b8..b941e2aaa6 100644
+--- a/torch/_inductor/codegen/triton.py
++++ b/torch/_inductor/codegen/triton.py
+@@ -2980,6 +2980,7 @@ class TritonKernel(SIMDKernel):
+             code.splice(self.imports_for_benchmark_kernel())
+
+         argdefs, _, signature, _ = self.args.python_argdefs()
++        # breakpoint()
+         # maps actual expression to SizeArg if it is in sizevars replacements
+         for i, arg in enumerate(signature):
+             if isinstance(arg, SizeArg):
+@@ -3030,7 +3031,7 @@ class TritonKernel(SIMDKernel):
+         triton_meta = {
+             "signature": triton_meta_signature,
+             "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
+-            "constants": {},
++            "constexprs": {},
+         }
+
+         # Skip memory optimization for forward of the training loop where we expect
+@@ -3065,20 +3066,12 @@ class TritonKernel(SIMDKernel):
+                 argdefs.append(f"{tree.prefix}numel")
+                 # constexpr version causes issues, see
+                 # https://github.com/pytorch/torchdynamo/pull/1362
+-                # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
++                # triton_meta["constexprs"][len(argdefs)] = V.graph.sizevars.size_hint(
+                 #     tree.numel
+                 # )
+                 # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
+         triton_meta["configs"] = [config_of(signature)]
+
+-        # Triton compiler includes equal_to_1 args into constants even
+-        # when they are not constexpr. otherwise there may be a segfault
+-        # during launching the Inductor-compiled Triton kernel.
+-        # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
+-        # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
+-        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
+-            triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index]
+-
+         self.triton_meta = triton_meta
+
+         for tree in self.range_trees:
+@@ -3087,9 +3080,14 @@ class TritonKernel(SIMDKernel):
+                 continue
+             if tree.tensor_dim is None:
+                 continue
+-            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")
++            const_name = f"{tree.prefix.upper()}BLOCK"
++            triton_meta['signature'][const_name] = 'constexpr'
++            triton_meta['constexprs'][const_name] = tree.numel
++            argdefs.append(f"{const_name} : tl.constexpr")

+         if self.cooperative_reduction:
++            triton_meta['signature']['RSPLIT'] = 'constexpr'
++            triton_meta['constexprs']['RSPLIT'] = tree.numel
+             argdefs.append("RSPLIT : tl.constexpr")
+
+         self.codegen_body()
+diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
+index 8b8c29bbb1..3e5abaa824 100644
+--- a/torch/_inductor/codegen/triton_utils.py
++++ b/torch/_inductor/codegen/triton_utils.py
+@@ -157,13 +157,13 @@ def config_of(
+             raise NotImplementedError(f"unhandled {type(x)}: {x}")
+
+     if config.triton.divisible_by_16:
+-        divisible_by_16 = tuple(
++        divisible_by_16 = [tuple(
+             i
+             for i, arg in zip(indices, args)
+             if is_aligned(arg, alignment=16, include_tensor=True)
+-        )
++        )]
+     else:
+-        divisible_by_16 = ()
++        divisible_by_16 = []
+
+     equal_to_1 = tuple(
+         i
+@@ -172,5 +172,7 @@ def config_of(
+         and isinstance(arg.expr, (int, sympy.Integer))
+         and V.graph.sizevars.statically_known_equals(arg.expr, 1)  # type: ignore[arg-type]
+     )
++    if equal_to_1 != tuple():
++        equal_to_1 = [equal_to_1]
+
+     return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
+diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
+index 2ab2b32635..5f08c3c0b7 100644
+--- a/torch/_inductor/codegen/wrapper.py
++++ b/torch/_inductor/codegen/wrapper.py
+@@ -1535,16 +1535,21 @@ class PythonWrapperCodegen(CodeGen):
+
+         signature: List[KernelArgType] = []
+         constants: Dict[str, Any] = {}
++        constexprs = {}
+         non_constant_indices = []
+         equal_to_1_args: List[str] = []
+         for idx, key in enumerate(kernel.arg_names):
+             if key not in kwargs:
++                if idx in kernel.constexprs:
++                    constexprs[key] = 'constexpr'
+                 continue
+             arg = kwargs[key]
+             if idx in kernel.constexprs:
+                 constants[key] = arg
++                constexprs[key] = 'constexpr'
+             elif kwargs[key] is None:
+                 constants[key] = None
++                constexprs[key] = 'constexpr'
+             else:
+                 non_constant_indices.append(idx)
+                 if isinstance(arg, ir.TMADescriptor):
+@@ -1596,9 +1601,8 @@ class PythonWrapperCodegen(CodeGen):
+             # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
+             # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
+             # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
+-            "constants": {
++            "constexprs": {
+                 **constants,
+-                **dict.fromkeys(equal_to_1_args, 1),
+             },
+             "configs": [
+                 config_of(
+@@ -1607,6 +1611,8 @@ class PythonWrapperCodegen(CodeGen):
+                 )
+             ],
+         }
++        for constexpr_name in constexprs.keys():
++            triton_meta['signature'][constexpr_name] = 'constexpr'
+
+         if restore_value_args:
+             triton_meta["restore_value"] = tuple(restore_value_args)
+diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
+index 276c01f3f4..4e6e1ab9ce 100644
+--- a/torch/_inductor/runtime/hints.py
++++ b/torch/_inductor/runtime/hints.py
+@@ -53,6 +53,7 @@ if _is_triton_available():
+             }
+
+             # Instantiate AttrsDescriptor with the prepared arguments
++            # breakpoint()
+             res = AttrsDescriptor.from_dict(
+                 {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
+             )
+diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
+index af8530e94d..a1935831e2 100644
+--- a/torch/_inductor/runtime/triton_heuristics.py
++++ b/torch/_inductor/runtime/triton_heuristics.py
+@@ -407,6 +407,7 @@ class CachingAutotuner(KernelInterface):
+
+     def _precompile_config(self, cfg: Config, warm_cache_only: bool):
+         """Ahead of time compile a given autotuner config."""
++        # print(f"self.triton_meta: {self.triton_meta}")
+         compile_meta = copy.deepcopy(self.triton_meta)
+         for k, v in cfg.kwargs.items():
+             if self.device_props.type == "hip":
+@@ -419,7 +420,7 @@ class CachingAutotuner(KernelInterface):
+                 if k == "kpack":
+                     compile_meta["kpack"] = v
+                     continue
+-            compile_meta["constants"][k] = v
++            compile_meta["constexprs"][k] = v
+         compile_meta["num_warps"] = cfg.num_warps
+         compile_meta["num_stages"] = cfg.num_stages
+         compile_meta["debug"] = self.inductor_meta.get(
+@@ -435,12 +436,13 @@ class CachingAutotuner(KernelInterface):
+         else:
+             triton_helpers.set_driver_to_gpu()
+
++        # print(compile_meta)
+         if ASTSource:
+             compile_args = (
+                 ASTSource(
+                     self.fn,
+                     compile_meta["signature"],
+-                    compile_meta["constants"],
++                    compile_meta["constexprs"],
+                     compile_meta["configs"][0],
+                 ),
+             )
+@@ -527,7 +529,7 @@ class CachingAutotuner(KernelInterface):
+         We also don't want to modify self.fn.
+
+         We know that we removed something from the signature if:
+-        1. It's in compile_meta["constants"]
++        1. It's in compile_meta["constexprs"]
+         2. It isn't a constant we already know about
+         Note: The value of interest has already been added to compile_meta['constants'],
+               so we use self.fn.constexprs instead.
+@@ -538,7 +540,7 @@
+         }
+         none_args = {
+             k
+-            for k, v in compile_meta["constants"].items()
++            for k, v in compile_meta["constexprs"].items()
+             if v is None and k not in known_constants
+         }
+         none_args = none_args.difference(set(compile_meta["signature"].keys()))
+@@ -548,12 +550,14 @@
+             for i, arg in enumerate(self.fn.arg_names)
+             if i not in self.fn.constexprs and arg not in none_args
+         ]
++        # print(f"call_args: {call_args}")
+
+         def_args = [
+             name
+             for name in self.fn.arg_names
+             if name not in cfg.kwargs and name not in none_args
+         ]
++        # print(f"def_args: {def_args}\n")
+         binary_shared = (
+             binary.shared if hasattr(binary, "shared") else binary.metadata.shared
+         )
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index 77a1233dbb..ad3e654faf 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -1,5 +1,6 @@
 from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor
 from triton._C.libtriton import ir, passes, llvm, amd
+from triton._utils import find_paths_if
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
 from types import ModuleType
@@ -104,10 +105,14 @@ def _add_backend_properties(self, params=None, values=None):
         if params is None or values is None:
             return

-        self.arg_properties["tt.pointer_range"] = [
-            param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg)
-            and not param.do_not_specialize and not param.do_not_specialize_on_alignment
-        ]
+        pointer_range = []
+        for param, arg in zip(params, values):
+            if param.do_not_specialize or \
+               param.do_not_specialize_on_alignment:
+                continue
+            paths = find_paths_if(arg, lambda path, val: HIPAttrsDescriptor.is_within2gb(val))
+            pointer_range += [(param.num, ) + x for x in paths]
+        self.arg_properties["tt.pointer_range"] = pointer_range

     @staticmethod
     def is_within2gb(arg):
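The HIP backend change above records `tt.pointer_range` entries per index path into a (possibly nested) argument, using `find_paths_if` from `triton._utils`. The real helper lives in the Triton source tree and is not shown in this patch; the following is only an assumed sketch that matches how it is called here, i.e. the predicate receives `(path, value)` and the result is a list of index tuples (the empty tuple for a plain scalar):

```python
def find_paths_if(obj, pred, _path=()):
    # Assumed sketch: walk nested tuples and collect the index paths whose
    # leaf value satisfies pred(path, value); a bare scalar yields path ().
    paths = []
    if isinstance(obj, tuple):
        for i, item in enumerate(obj):
            paths += find_paths_if(item, pred, _path + (i,))
    elif pred(_path, obj):
        paths.append(_path)
    return paths
```

For example, with `arg = (a, (b, c))` and a predicate that accepts `a` and `c`, this returns `[(0,), (1, 1)]`, which the backend then prefixes with `param.num`.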
f"[{val}]" + return ty_to_cpp(ty) def format_of(ty): + if ty == "hipDeviceptr_t": + return "O" + if ty[0] == "[": + if ty == "[]": + return "()" + tys = parse_list_string(ty) + val = ''.join(map(format_of, tys)) + return f"({val})" return { "PyObject*": "O", "float": "f", @@ -227,14 +223,22 @@ def format_of(ty): "uint64_t": "K", }[ty] + signature = {k: v for k, v in signature.items() if v != 'constexpr'} args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiKKOOOO" + args_format + signature = ','.join(signature.values()).replace('[', '').replace(']', '') + signature = list(filter(bool, signature.split(','))) + signature = {i: s for i, s in enumerate(signature)} args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) libhip_path = _get_path_to_hip_runtime_dylib() # generate glue code - params = [f"&arg{i}" for i in signature.keys() if i not in constants] + params = list(range(len(signature))) + params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"] params.append("&global_scratch") src = f""" #define __HIP_PLATFORM_AMD__ @@ -416,8 +420,8 @@ def format_of(ty): // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); if(launch_exit_hook != Py_None){{ PyObject* args = Py_BuildValue("(O)", launch_metadata); @@ -468,9 +472,8 @@ class HIPLauncher(object): def __init__(self, src, metadata): ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} + constants = {idx: value for idx, value in constants.items()} + signature = {idx: value for idx, value in src.signature.items()} src = make_launcher(constants, signature, ids, metadata.warp_size) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index f2a237c35b..7be311ce4f 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -12,6 +12,7 @@ from triton.runtime.cache import get_cache_manager from triton.backends.compiler import GPUTarget from 
diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py
index f2a237c35b..7be311ce4f 100644
--- a/third_party/intel/backend/driver.py
+++ b/third_party/intel/backend/driver.py
@@ -12,6 +12,7 @@ from triton.runtime.cache import get_cache_manager
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import DriverBase
+from triton._utils import parse_list_string


 def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
@@ -178,7 +179,7 @@ def wait(self):


 def ty_to_cpp(ty):
-    if ty[0] == '*':
+    if ty[0] == '*' or ty == "none":
         return "void*"
     return {
         "i1": "int32_t",
@@ -200,16 +201,27 @@ def ty_to_cpp(ty):


 def make_launcher(constants, signature, ids):
-    # Record the end of regular arguments;
-    # subsequent arguments are architecture-specific descriptors.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

     def _extracted_type(ty):
-        if ty[0] == '*':
+        if ty[0] == '*' or ty == "none":
             return "PyObject*"
+        if ty[0] == '[':
+            if ty == "[]":
+                return "[]"
+            tys = parse_list_string(ty)
+            val = ','.join(map(_extracted_type, tys))
+            return f"[{val}]"
         return ty_to_cpp(ty)

     def format_of(ty):
+        if ty == "void*":
+            return "O"
+        if ty[0] == "[":
+            if ty == "[]":
+                return "()"
+            tys = parse_list_string(ty)
+            val = ''.join(map(format_of, tys))
+            return f"({val})"
         return {
             "PyObject*": "O",
             "float": "f",
@@ -225,10 +237,18 @@ def format_of(ty):
             "uint64_t": "K",
         }[ty]

+    signature = {k: v for k, v in signature.items() if v != 'constexpr'}
     args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
     format = "iiiOOOOOO" + args_format
+    signature = ','.join(signature.values()).replace('[', '').replace(']', '')
+    signature = list(filter(bool, signature.split(',')))
+    signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors.
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+
     # generate glue code
     src = f"""
 #include
@@ -329,7 +349,7 @@ def format_of(ty):
 static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   std::string kernel_name = kernel_ptr.get_info();
-  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
+  void *params[] = {{ {', '.join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
   uint32_t num_params = sizeof(params)/sizeof(params[0]);
   uint32_t expected_num_params = kernel_ptr.get_info();
   size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -347,7 +367,7 @@ def format_of(ty):
   assert(num_params == expected_num_params && "number of kernel param not matched");
   // Submit the imported kernel.
   auto cgf = [&](sycl::handler &cgh) {{
-    {" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants]))}
+    {" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))}
     if (shared_memory) {{
       using share_mem_t = sycl::local_accessor;
       share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
@@ -410,8 +430,8 @@ def format_of(ty):
   if(kernel_ptr == nullptr) return NULL;
   sycl::kernel kernel = *kernel_ptr;

-  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
+  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
+  sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});

   if(launch_exit_hook != Py_None){{
     PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -515,9 +535,8 @@ class XPULauncher(object):

     def __init__(self, src, metadata):
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else dict()
-        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        self.constants = {cst_key(key): value for key, value in constants.items()}
-        self.signature = {cst_key(key): value for key, value in src.signature.items()}
+        self.constants = {idx: value for idx, value in constants.items()}
+        self.signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(self.constants, self.signature, ids)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 8bfd010773..d54a2d2f03 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -10,6 +10,7 @@ from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
+from triton._utils import parse_list_string

 dirname = os.path.dirname(os.path.realpath(__file__))
 include_dir = [os.path.join(dirname, "include")]
@@ -95,7 +96,7 @@ def __init__(self):


 def ty_to_cpp(ty):
-    if ty[0] == '*':
+    if ty[0] == '*' or ty == "none":
         return "CUdeviceptr"
     return {
         "i1": "int32_t",
@@ -118,19 +119,29 @@ def make_launcher(constants, signature, ids):
-    # Record the end of regular arguments;
-    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

     def _extracted_type(ty):
-        if ty[0] == '*':
+        if ty[0] == '*' or ty == "none":
             return "PyObject*"
         if ty == "nvTmaDesc":
             return "PyObject*"
-
+        if ty[0] == '[':
+            if ty == "[]":
+                return "[]"
+            tys = parse_list_string(ty)
+            val = ','.join(map(_extracted_type, tys))
+            return f"[{val}]"
         return ty_to_cpp(ty)

     def format_of(ty):
+        if ty == "CUdeviceptr":
+            return "O"
+        if ty[0] == "[":
+            if ty == "[]":
+                return "()"
+            tys = parse_list_string(ty)
+            val = ''.join(map(format_of, tys))
+            return f"({val})"
         return {
             "PyObject*": "O",
             "float": "f",
@@ -146,22 +157,29 @@ def format_of(ty):
             "uint64_t": "K",
         }[ty]

+    signature = {k: v for k, v in signature.items() if v != 'constexpr'}
     args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
     format = "iiiKKpOOOOO" + args_format
+    signature = ','.join(signature.values()).replace('[', '').replace(']', '')
+    signature = list(filter(bool, signature.split(',')))
+    signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
-
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
     internal_args_list = []
     for i, ty in signature.items():
-        if ty[0] == "*":
+        if ty[0] == "*" or ty == "none":
             internal_args_list.append(f"ptr_info{i}.dev_ptr")
         elif ty == "nvTmaDesc":
             # Note: we have to dereference the pointer
             internal_args_list.append(f"*tma_ptr{i}")
         else:
             internal_args_list.append(f"_arg{i}")
+    params = range(len(signature))

     # generate glue code
-    params = [f"&arg{i}" for i in signature.keys() if i not in constants]
+    params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"]
     params.append("&global_scratch")
     src = f"""
 #include \"cuda.h\"
@@ -452,7 +470,7 @@ def format_of(ty):
   }}

   // raise exception asap
-  {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
+  {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
   {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])};
   Py_BEGIN_ALLOW_THREADS;
   _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
@@ -503,9 +521,8 @@ class CudaLauncher(object):

     def __init__(self, src, metadata):
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else dict()
-        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        constants = {cst_key(key): value for key, value in constants.items()}
-        signature = {cst_key(key): value for key, value in src.signature.items()}
+        constants = {idx: value for idx, value in constants.items()}
+        signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(constants, signature, ids)
         mod = compile_module_from_src(src, "__triton_launcher")
        self.launch = mod.launch
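`parse_list_string`, imported from `triton._utils` by the AMD, Intel and NVIDIA drivers above, is not shown in this patch. An assumed sketch of what such a helper has to do, given how it is used here (split a bracketed type list into its top-level elements while respecting nesting):

```python
def parse_list_string(s):
    # Assumed sketch of the helper imported from triton._utils above:
    # split '[*fp32,[i32,i64]]' into its top-level elements.
    s = s.strip()
    assert s.startswith("[") and s.endswith("]")
    inner, depth, current, out = s[1:-1], 0, "", []
    for ch in inner:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        if ch == "," and depth == 0:
            out.append(current)
            current = ""
        else:
            current += ch
    if current:
        out.append(current)
    return out


# parse_list_string("[*fp32,[i32,i64]]") -> ["*fp32", "[i32,i64]"]
```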