From e8bc45df911df2f0968268392a7395ffd4af03a1 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 11 Jun 2024 22:51:13 -0700 Subject: [PATCH 1/5] [BACKEND][AMD] Disable linear layout due to perf regression (#4126) We have identified a 20% perf regression in our downstream flash attention perf kernel after switching to linear layout. Initial analysis shows register pressure is increased to cause spills. Further analysis is still ongoing. So this commit introduces a minimal way to selectively disable linear layout only on AMD backend to avoid affecting NVIDIA backend while continuing bring it up on AMD side. --- include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h | 5 +++++ include/triton/Conversion/TritonGPUToLLVM/Utility.h | 2 +- third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 2 ++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 0ee1707a4c..f977d30c02 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -56,6 +56,11 @@ class TargetInfoBase { StringRef message, StringRef file, StringRef func, int line) const = 0; + // Whether to enable linear layout. This is a per-backend temporary escape + // hatch to disable linear layout while figuring out issues. Eventually we + // want to enable linear layout everywhere and delete this control. + virtual bool enableLinearLayout() const { return true; } + virtual ~TargetInfoBase() {} }; } // namespace mlir::triton diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 9f6d17c935..6ad63e3796 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1219,7 +1219,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, bool allowLL = true) { // Eventually the LinearLayout path will be the only one. For now we allow // both paths so we can test that they produce the same results. - if (allowLL) { + if (allowLL && target.enableLinearLayout()) { std::optional>> llOffsets = emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type, withCTAOffset); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 84733661cf..4e86beb3ca 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -58,6 +58,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase { StringRef message, StringRef file, StringRef func, int line) const override; + bool enableLinearLayout() const override { return false; } + private: void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, ConversionPatternRewriter &rewriter, bool useStdErr) const; From 27353b7301d586363972ce80e200f392acf03159 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 11 Jun 2024 23:46:17 -0700 Subject: [PATCH 2/5] [FRONTEND] added missing constexpr checks (#4127) also moved things around to remove duplication of `_unwrap_if_constexpr`. 
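For reference, a minimal sketch of the shared helper's behavior (illustrative only; see the diff below for the actual definition):

```python
import triton.language as tl
from triton.language.core import _unwrap_if_constexpr  # single canonical location after this change

# Unwraps compile-time constants and passes every other value through untouched.
assert _unwrap_if_constexpr(tl.constexpr(4)) == 4
assert _unwrap_if_constexpr("float16") == "float16"
```

`code_generator.py` and `standard.py` now rely on this one definition instead of keeping their own copies.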
--- python/triton/compiler/code_generator.py | 5 +- python/triton/language/core.py | 336 ++++++++++++----------- python/triton/language/standard.py | 8 +- 3 files changed, 178 insertions(+), 171 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 6903052ca2..9e8c53919b 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -9,6 +9,7 @@ from .. import language from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty +from ..language.core import _unwrap_if_constexpr from ..runtime.jit import _normalize_ty # ideally we wouldn't need any runtime component from ..runtime import JITFunction @@ -62,10 +63,6 @@ def _is_list_like(o: Any) -> bool: return isinstance(o, (list, tuple)) -def _unwrap_if_constexpr(o: Any): - return o.value if isinstance(o, constexpr) else o - - def _check_fn_args(node, fn, args): if fn.noinline: for idx, arg in enumerate(args): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 51be31a346..683cd0766d 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -149,6 +149,176 @@ def _to_tensor(x, builder): assert False, f"cannot convert {x} of type {type(x)} to tensor" +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. 
+ def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, constexpr) else o + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type 
'{value.dtype}'. This may result in undefined behavior." + ) + + +# ----------------------- +# dtype +# ----------------------- + + class dtype: SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] @@ -161,8 +331,7 @@ class SIGNEDNESS(Enum): UNSIGNED = 1 def __init__(self, name): - if hasattr(name, 'value'): - name = name.value + name = _unwrap_if_constexpr(name) self.name = name assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name if name in dtype.SINT_TYPES: @@ -388,6 +557,7 @@ def __repr__(self): class pointer_type(dtype): def __init__(self, element_ty: dtype, address_space: int = 1): + element_ty = _unwrap_if_constexpr(element_ty) if not isinstance(element_ty, dtype): raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') self.element_ty = element_ty @@ -551,166 +721,10 @@ def get_int_dtype(bitwidth: int, signed: bool) -> dtype: # ----------------------- -# constexpr +# tensor # ----------------------- -class const: - """ - This class is used as a type annotation to mark pointers to constant data. - The `store` function cannot be called with a pointer to const. Constness - is part of the pointer type and the usual Triton type consistency rules - apply. For example you cannot have a function that returns constant pointer - in one return statement and non-constant pointer in another. - """ - pass - - -class constexpr: - """ - This class is used to store a value that is known at compile-time. - """ - - def __init__(self, value): - if isinstance(value, constexpr): - self.value = value.value - else: - self.value = value - - def __repr__(self) -> str: - return f"constexpr[{self.value}]" - - def __index__(self): - return self.value - - # In interpreter mode, constant values are not wrapped in constexpr, - # and therefore do not have a .value attribute. - # As a result, from here and below, we need to call the _constexpr_to_value - # function to obtain either constexpr.value or the value itself. 
- def __add__(self, other): - return constexpr(self.value + _constexpr_to_value(other)) - - def __radd__(self, other): - return constexpr(_constexpr_to_value(other) + self.value) - - def __sub__(self, other): - return constexpr(self.value - _constexpr_to_value(other)) - - def __rsub__(self, other): - return constexpr(_constexpr_to_value(other) - self.value) - - def __mul__(self, other): - return constexpr(self.value * _constexpr_to_value(other)) - - def __mod__(self, other): - return constexpr(self.value % _constexpr_to_value(other)) - - def __rmul__(self, other): - return constexpr(_constexpr_to_value(other) * self.value) - - def __truediv__(self, other): - return constexpr(self.value / _constexpr_to_value(other)) - - def __rtruediv__(self, other): - return constexpr(_constexpr_to_value(other) / self.value) - - def __floordiv__(self, other): - return constexpr(self.value // _constexpr_to_value(other)) - - def __rfloordiv__(self, other): - return constexpr(_constexpr_to_value(other) // self.value) - - def __gt__(self, other): - return constexpr(self.value > _constexpr_to_value(other)) - - def __rgt__(self, other): - return constexpr(_constexpr_to_value(other) > self.value) - - def __ge__(self, other): - return constexpr(self.value >= _constexpr_to_value(other)) - - def __rge__(self, other): - return constexpr(_constexpr_to_value(other) >= self.value) - - def __lt__(self, other): - return constexpr(self.value < _constexpr_to_value(other)) - - def __rlt__(self, other): - return constexpr(_constexpr_to_value(other) < self.value) - - def __le__(self, other): - return constexpr(self.value <= _constexpr_to_value(other)) - - def __rle__(self, other): - return constexpr(_constexpr_to_value(other) <= self.value) - - def __eq__(self, other): - return constexpr(self.value == _constexpr_to_value(other)) - - def __ne__(self, other): - return constexpr(self.value != _constexpr_to_value(other)) - - def __bool__(self): - return bool(self.value) - - def __neg__(self): - return constexpr(-self.value) - - def __and__(self, other): - return constexpr(self.value & _constexpr_to_value(other)) - - def logical_and(self, other): - return constexpr(self.value and _constexpr_to_value(other)) - - def __or__(self, other): - return constexpr(self.value | _constexpr_to_value(other)) - - def __xor__(self, other): - return constexpr(self.value ^ _constexpr_to_value(other)) - - def logical_or(self, other): - return constexpr(self.value or _constexpr_to_value(other)) - - def __pos__(self): - return constexpr(+self.value) - - def __invert__(self): - return constexpr(~self.value) - - def __pow__(self, other): - return constexpr(self.value**_constexpr_to_value(other)) - - def __rpow__(self, other): - return constexpr(_constexpr_to_value(other)**self.value) - - def __rshift__(self, other): - return constexpr(self.value >> _constexpr_to_value(other)) - - def __lshift__(self, other): - return constexpr(self.value << _constexpr_to_value(other)) - - def __not__(self): - return constexpr(not self.value) - - def __iter__(self): - return iter(self.value) - - def __call__(self, *args, **kwds): - return self.value(*args, **kwds) - - -CONSTEXPR_0 = constexpr(0) - - -def check_bit_width(value, shift_value): - if isinstance(value, tensor) and isinstance(shift_value, constexpr): - bitwidth = value.type.scalar.primitive_bitwidth - if shift_value.value >= bitwidth: - warn( - f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." 
- ) - - class tensor: """Represents an N-dimensional array of values or pointers. @@ -986,8 +1000,8 @@ def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: """ # Triton doesn't like core functions calling other core functions, so we # just copy-paste the implementation of cast here. It's not too bad. - if isinstance(bitcast, constexpr): - bitcast = bitcast.value + dtype = _unwrap_if_constexpr(dtype) + bitcast = _unwrap_if_constexpr(bitcast) if bitcast: return semantic.bitcast(self, dtype, _builder) return semantic.cast(self, dtype, _builder, fp_downcast_rounding) diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index b66ff9a9ae..6da4208967 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -7,10 +7,6 @@ # constexpr utilities -def _unwrap_if_constexpr(o): - return o.value if isinstance(o, core.constexpr) else o - - def _log2(i: core.constexpr): log2 = 0 n = i.value @@ -395,8 +391,8 @@ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTE def _get_flip_dim(dim, shape): - dim = _unwrap_if_constexpr(dim) - shape = _unwrap_if_constexpr(shape) + dim = core._unwrap_if_constexpr(dim) + shape = core._unwrap_if_constexpr(shape) if dim is None: dim = len(shape) - 1 assert dim == len(shape) - 1, "Currently only support flipping the last dimension" From 6c3005c8be4f418776f9434df3eddff397c5cead Mon Sep 17 00:00:00 2001 From: pawelszczerbuk <153013546+pawelszczerbuk@users.noreply.github.com> Date: Wed, 12 Jun 2024 08:48:20 -0700 Subject: [PATCH 3/5] Persistent kernel tutorial for fp16 (#4121) * Added fp16 support to the cuBLAS wrapper * Added test for fp16 cuBLAS matmul * Added fp16 tutorial code and measured performance against torch matmul (identical) * Fixed an issue in cuBLAS config, boosting performance for small shapes. 
--------- Co-authored-by: Philippe Tillet --- python/test/unit/runtime/test_cublas.py | 18 +- ...-fp8-matmul.py => 09-persistent-matmul.py} | 255 +++++++++++------- third_party/nvidia/include/cublas_instance.h | 26 +- third_party/nvidia/triton_nvidia.cc | 22 +- 4 files changed, 203 insertions(+), 118 deletions(-) rename python/tutorials/{09-persistent-fp8-matmul.py => 09-persistent-matmul.py} (63%) diff --git a/python/test/unit/runtime/test_cublas.py b/python/test/unit/runtime/test_cublas.py index d40afdcf41..a4315fc3cb 100644 --- a/python/test/unit/runtime/test_cublas.py +++ b/python/test/unit/runtime/test_cublas.py @@ -14,9 +14,13 @@ def is_cuda(): @pytest.mark.parametrize("m, n, k", [(16, 16, 16), (32, 16, 16), (16, 32, 16), (16, 16, 32)]) -def test_cublas_fp8(m, n, k, device): - if not (is_cuda() and torch.cuda.get_device_capability()[0] >= 9): - pytest.skip("test_cublas_fp8 is only supported on CUDA with cc >= 90") +@pytest.mark.parametrize("dtype_str", ["float8_e4m3fn", "float16"]) +def test_cublas(m, n, k, dtype_str, device): + dtype = getattr(torch, dtype_str) + if not is_cuda(): + pytest.skip("test_cublas is only supported on CUDA") + if dtype == torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("fp8 is only supported on CUDA with cc >= 90") from triton._C.libtriton import nvidia @@ -29,16 +33,16 @@ def limited_rand(elements, shape): return elements[indices].view(shape) elements = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32, device=device) - a = limited_rand(elements, (m, k)).to(torch.float8_e4m3fn) - b = limited_rand(elements, (k, n)).to(torch.float8_e4m3fn) - c = torch.zeros((m, n), dtype=torch.float8_e4m3fn, device=device) + a = limited_rand(elements, (m, k)).to(dtype) + b = limited_rand(elements, (k, n)).to(dtype) + c = torch.zeros((m, n), dtype=dtype, device=device) b = b.T.contiguous() workspace = torch.empty(workspace_size, dtype=torch.int8, device=device) cublas = nvidia.cublas.CublasLt(workspace) - cublas.fp8_matmul(a, b, c) + cublas.matmul(a, b, c) ref = torch.matmul(a.to(torch.float16), b.to(torch.float16).T) diff --git a/python/tutorials/09-persistent-fp8-matmul.py b/python/tutorials/09-persistent-matmul.py similarity index 63% rename from python/tutorials/09-persistent-fp8-matmul.py rename to python/tutorials/09-persistent-matmul.py index 0b24ce4f42..305b70e242 100644 --- a/python/tutorials/09-persistent-fp8-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -16,11 +16,16 @@ import triton.language as tl import triton.profiler as proton -from triton._C.libtriton import nvidia - if torch.cuda.is_available(): + from triton._C.libtriton import nvidia cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) cublas = nvidia.cublas.CublasLt(cublas_workspace) +else: + cublas = None + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" def _matmul_launch_metadata(grid, kernel, args): @@ -28,7 +33,11 @@ def _matmul_launch_metadata(grid, kernel, args): M, N, K = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" ret["flops8"] = 2. 
* M * N * K - ret["bytes"] = M * K + N * K + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret["bytes"] = bytes_per_elem * (M * K + N * K) return ret @@ -77,7 +86,10 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, # a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - c = accumulator.to(tl.float8e4nv) + if (c_ptr.dtype == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -87,19 +99,23 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, # def matmul(a, b): - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 256 - BLOCK_SIZE_K = 128 - GROUP_SIZE = 8 - num_stages = 3 - num_warps = 8 - + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" M, K = a.shape K, N = b.shape + dtype = a.dtype - c = torch.empty((M, N), device=a.device, dtype=torch.float8_e4m3fn) + c = torch.empty((M, N), device=a.device, dtype=dtype) # 1D launch kernel where each block gets its own program. grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) matmul_kernel[grid]( @@ -108,12 +124,12 @@ def matmul(a, b): a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # - BLOCK_SIZE_M=BLOCK_SIZE_M, # - BLOCK_SIZE_N=BLOCK_SIZE_N, # - BLOCK_SIZE_K=BLOCK_SIZE_K, # - GROUP_SIZE_M=GROUP_SIZE, # - num_stages=num_stages, # - num_warps=num_warps, # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # ) return c @@ -185,27 +201,33 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - c = accumulator.to(tl.float8e4nv) + if (c_ptr.dtype == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) tl.store(c_ptrs, c, mask=c_mask) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) def matmul_persistent(a, b): - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 256 - BLOCK_SIZE_K = 128 - GROUP_SIZE = 8 - num_stages = 3 - num_warps = 8 - + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" - + assert a.dtype == b.dtype, "Incompatible dtypes" NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count M, K = a.shape K, N = b.shape + dtype = a.dtype # Allocates output. 
- c = torch.empty((M, N), device=a.device, dtype=a.dtype) + c = torch.empty((M, N), device=a.device, dtype=dtype) # 1D launch kernel where each block gets its own program. grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) matmul_kernel_persistent[grid]( @@ -214,13 +236,13 @@ def matmul_persistent(a, b): a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # - BLOCK_SIZE_M=BLOCK_SIZE_M, # - BLOCK_SIZE_N=BLOCK_SIZE_N, # - BLOCK_SIZE_K=BLOCK_SIZE_K, # - GROUP_SIZE_M=GROUP_SIZE, # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # NUM_SMS=NUM_SMS, # - num_stages=num_stages, # - num_warps=num_warps, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # ) return c @@ -232,7 +254,9 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # BLOCK_SIZE_N: tl.constexpr, # BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # NUM_SMS: tl.constexpr): # + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -273,44 +297,50 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # offs_k = ki * BLOCK_SIZE_K - a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float8e4nv) - b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], tl.float8e4nv) + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) accumulator = tl.dot(a, b.T, accumulator) if ki == k_tiles - 1: - c = accumulator.to(tl.float8e4nv) + c = accumulator.to(dtype) tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) def matmul_tma_persistent(a, b): - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 256 - BLOCK_SIZE_K = 128 - GROUP_SIZE = 8 - num_stages = 3 - num_warps = 8 + # Autotuner does not work with TMA. Use manual config. + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } # Check constraints. 
assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" M, K = a.shape N, K = b.shape + dtype = a.dtype - c = torch.zeros((M, N), device=a.device, dtype=torch.float8_e4m3fn) + c = torch.zeros((M, N), device=a.device, dtype=dtype) TMA_SIZE = 128 desc_a = np.empty(TMA_SIZE, dtype=np.int8) desc_b = np.empty(TMA_SIZE, dtype=np.int8) desc_c = np.empty(TMA_SIZE, dtype=np.int8) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_SIZE_M, BLOCK_SIZE_K, - a.element_size(), desc_a) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), N, K, BLOCK_SIZE_N, BLOCK_SIZE_K, - b.element_size(), desc_b) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, - c.element_size(), desc_c) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), M, K, configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_K"], a.element_size(), desc_a) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), N, K, configs[dtype]["BLOCK_SIZE_N"], + configs[dtype]["BLOCK_SIZE_K"], b.element_size(), desc_b) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), M, N, configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_N"], c.element_size(), desc_c) desc_a = torch.tensor(desc_a, device="cuda") desc_b = torch.tensor(desc_b, device="cuda") @@ -322,13 +352,14 @@ def matmul_tma_persistent(a, b): matmul_kernel_tma_persistent[grid]( desc_a, desc_b, desc_c, # M, N, K, # - BLOCK_SIZE_M=BLOCK_SIZE_M, # - BLOCK_SIZE_N=BLOCK_SIZE_N, # - BLOCK_SIZE_K=BLOCK_SIZE_K, # - GROUP_SIZE_M=GROUP_SIZE, # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # NUM_SMS=NUM_SMS, # - num_stages=num_stages, # - num_warps=num_warps, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # ) return c @@ -336,30 +367,48 @@ def matmul_tma_persistent(a, b): def cublas_matmul(a, b): # Check constraints. assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed - M, K = a.shape N, K = b.shape + dtype = a.dtype + c = torch.empty((M, N), device=a.device, dtype=dtype) + bytes_per_elem = a.element_size() + flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops" + with proton.scope(f"cublas M={M}, N={N}, K={K}", + {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. * M * N * K}): + cublas.matmul(a, b, c) + return c - c = torch.empty((M, N), device=a.device, dtype=torch.float8_e4m3fn) - with proton.scope(f"cublas M={M}, N={N}, K={K}", {"bytes": M * K + N * K, "flops8": 2. * M * N * K}): - cublas.fp8_matmul(a, b, c) +def torch_matmul(a, b): + M, K = a.shape + N, K = b.shape + dtype = a.dtype + bytes_per_elem = a.element_size() + flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops" + with proton.scope(f"torch M={M}, N={N}, K={K}", + {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. 
* M * N * K}): + c = torch.matmul(a, b.T) return c -def bench(K, reps=10): +def bench(K, dtype, reps=10): M = 8192 N = 8192 - a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) - b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) b = b.T.contiguous() proton.activate(0) - for _ in range(reps): - cublas_matmul(a, b) + if cublas is not None: + for _ in range(reps): + cublas_matmul(a, b) time.sleep(0.01) + if dtype == torch.float16: + for _ in range(reps): + torch_matmul(a, b) + time.sleep(0.01) for _ in range(reps): matmul(a, b.T) time.sleep(0.01) @@ -373,46 +422,60 @@ def bench(K, reps=10): proton.deactivate(0) -def validate(M, N, K): - a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) - b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) +def validate(M, N, K, dtype): + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) b = b.T.contiguous() - cublas_result = cublas_matmul(a, b) + + torch_result = torch_matmul(a, b) if dtype == torch.float16 else None + cublas_result = cublas_matmul(a, b) if cublas is not None else None naive_result = matmul(a, b.T) persistent_result = matmul_persistent(a, b.T) tma_persistent_result = matmul_tma_persistent(a, b) - naive_vs_cublas = "✅" if torch.allclose(naive_result.to(torch.float16), cublas_result.to(torch.float16), - atol=1.0) else "❌" + if torch_result is not None: + naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16), + atol=1.0) else "❌" + if cublas_result is not None: + naive_vs_cublas = "✅" if torch.allclose(naive_result.to(torch.float16), cublas_result.to(torch.float16), + atol=1.0) else "❌" naive_vs_persistent = "✅" if torch.allclose(naive_result.to(torch.float16), persistent_result.to(torch.float16), atol=1.0) else "❌" naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16), tma_persistent_result.to(torch.float16), atol=1.0) else "❌" - print( - f"M={M}, N={N}, K={K} verification naive vs: cublas {naive_vs_cublas}, persistent {naive_vs_persistent}, TMA persistent {naive_vs_tma_persistent}" - ) + print(f"M={M}, N={N}, K={K} verification naive vs: ", end="") + if torch_result is not None: + print(f"torch: {naive_vs_torch} ", end="") + if cublas_result is not None: + print(f"cublas: {naive_vs_cublas} ", end="") + print(f"persistent: {naive_vs_persistent} ", end="") + print(f"TMA persistent: {naive_vs_tma_persistent}") if __name__ == "__main__": - if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9: - parser = argparse.ArgumentParser() - parser.add_argument("-K", type=int, required=False) - parser.add_argument("--K_range", type=int, nargs=2) - parser.add_argument("--K_step", type=int, default=512) - args = parser.parse_args() - - if args.K: - args.K_range = [args.K, args.K] - args.K_step = 1 # doesn't matter as long as it's not 0 - - torch.manual_seed(0) - - validate(32, 32, 32) - validate(8192, 8192, 512) - - proton.start("matmul", hook="triton") - for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): - bench(K) - proton.finalize() - else: - print("This tutorial fp8_matmul is only supported on CUDA with cc >= 90") + parser = argparse.ArgumentParser() + parser.add_argument("-K", 
type=int, required=False, default=512) + parser.add_argument("--K_range", type=int, nargs=2) + parser.add_argument("--K_step", type=int, default=512) + parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp8") + args = parser.parse_args() + + if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()): + print("This example requires CUDA with fp8 support.") + exit(1) + + dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16 + + if args.K and args.K_range is None: + args.K_range = [args.K, args.K] + args.K_step = 1 # doesn't matter as long as it's not 0 + + torch.manual_seed(0) + + validate(32, 32, 32, dtype) + validate(8192, 8192, 512, dtype) + + proton.start("matmul", hook="triton") + for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + bench(K, dtype) + proton.finalize() diff --git a/third_party/nvidia/include/cublas_instance.h b/third_party/nvidia/include/cublas_instance.h index 3ab42f482a..d79d4d76bf 100644 --- a/third_party/nvidia/include/cublas_instance.h +++ b/third_party/nvidia/include/cublas_instance.h @@ -120,7 +120,8 @@ class CublasLtInstance { } // Simple wrapper around the cublasLtMatmul function - void fp8Matmul_impl(int m, int n, int k, uint64_t A, uint64_t B, uint64_t D) { + void matmul_impl(int m, int n, int k, uint64_t A, uint64_t B, uint64_t D, + cudaDataType_t dtype) { cublasLtMatmulDesc_t matmulDesc = NULL; cublasOperation_t transa = CUBLAS_OP_T; @@ -140,14 +141,16 @@ class CublasLtInstance { matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); successOrExit(cublasLtMatmulDescSetAttribute( matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); - successOrExit(cublasLtMatmulDescSetAttribute( - matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccum, - sizeof(fastAccum))); + if (dtype == CUDA_R_8F_E4M3) { + successOrExit(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccum, + sizeof(fastAccum))); + } - successOrExit(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8F_E4M3, k, m, k)); - successOrExit(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8F_E4M3, k, n, k)); - successOrExit(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16BF, m, n, m)); - successOrExit(cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_8F_E4M3, m, n, m)); + successOrExit(cublasLtMatrixLayoutCreate(&Adesc, dtype, k, m, k)); + successOrExit(cublasLtMatrixLayoutCreate(&Bdesc, dtype, k, n, k)); + successOrExit(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16F, m, n, m)); + successOrExit(cublasLtMatrixLayoutCreate(&Ddesc, dtype, m, n, m)); successOrExit(cublasLtMatmulAlgoGetHeuristic( ltHandle, matmulDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, @@ -159,12 +162,10 @@ class CublasLtInstance { float alpha = 1.0f; float beta = 0.0f; - successOrExit(cublasLtMatmul(ltHandle, matmulDesc, &alpha, (void *)A, Adesc, (void *)B, Bdesc, &beta, nullptr, Cdesc, (void *)D, Ddesc, &heuristicResult.algo, (void *)workspace, workspaceSize, 0)); - if (Ddesc) successOrExit(cublasLtMatrixLayoutDestroy(Ddesc)); if (Cdesc) @@ -201,10 +202,11 @@ class CublasLtInstance { // *will-not* transpose the matrices, so the caller is responsible for // ensuring that the matrices are in the correct format and have the correct // dimensions. 
- void fp8Matmul(int m, int n, int k, uint64_t A, uint64_t B, uint64_t C) { + void matmul(int m, int n, int k, uint64_t A, uint64_t B, uint64_t C, + cudaDataType_t dtype) { // CUDA is column-major, while triton is row-major, therefore we need to // reverse the order of the matrices ( A * B = (B^T * A^T)^T ). - fp8Matmul_impl(n, m, k, B, A, C); + matmul_impl(n, m, k, B, A, C, dtype); } }; diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 7881f94b4b..1269dcda00 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -94,8 +94,8 @@ void init_triton_nvidia(py::module &&m) { workspace.attr("element_size")().cast(); return new CublasLtInstance(wrk_ptr, wrk_size); })) - .def("fp8_matmul", [](CublasLtInstance &self, py::object &A, - py::object &B, py::object &C) { + .def("matmul", [](CublasLtInstance &self, py::object &A, py::object &B, + py::object &C) { auto A_ptr = A.attr("data_ptr")().cast(); auto B_ptr = B.attr("data_ptr")().cast(); auto C_ptr = C.attr("data_ptr")().cast(); @@ -104,6 +104,21 @@ void init_triton_nvidia(py::module &&m) { auto B_shape = B.attr("shape").cast>(); auto C_shape = C.attr("shape").cast>(); + auto A_dtype = A.attr("dtype").attr("__str__")().cast(); + auto B_dtype = B.attr("dtype").attr("__str__")().cast(); + auto C_dtype = C.attr("dtype").attr("__str__")().cast(); + + assert(A_dtype == B_dtype && A_dtype == C_dtype); + assert(A_dtype == "torch.float8_e4m3fn" || A_dtype == "torch.float16"); + + std::string dtype_str = A_dtype.substr(A_dtype.find_last_of('.') + 1); + cudaDataType_t dtype; + if (dtype_str == "float8_e4m3fn") { + dtype = CUDA_R_8F_E4M3; + } else if (dtype_str == "float16") { + dtype = CUDA_R_16F; + } + if (A_shape.size() != 2 || B_shape.size() != 2 || C_shape.size() != 2) { throw std::runtime_error("Only 2D matrices are supported."); } @@ -140,6 +155,7 @@ void init_triton_nvidia(py::module &&m) { "that B needs to be transposed."); } - self.fp8Matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr); + self.matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr, + dtype); }); } From c1776fadc77ca279023142d82aea155b2886a2b7 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Wed, 12 Jun 2024 15:24:52 -0400 Subject: [PATCH 4/5] [PROTON][AMD] Add Proton HIP GPU Utilization Metrics (#4119) This PR adds Proton HIP GPU utilization metrics and an associated test. 
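A rough usage sketch of the new HIP path in the viewer, mirroring the added test (the import path is an assumption; the helpers live in `proton/viewer.py`):

```python
from triton.profiler.viewer import get_raw_metrics, get_min_time_flops  # import path is an assumption

# example_hip.json is the HIP trace added by this PR; it records two devices
# (gfx90a and gfx941), so both peak-FLOPS branches are exercised.
with open("example_hip.json") as f:
    gf, _, device_info = get_raw_metrics(f)

# Roofline-minimum kernel time derived from the per-arch peak FLOPS table.
min_time = get_min_time_flops(gf.dataframe, device_info)
print(min_time)
```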
--- .../proton/csrc/include/Driver/Device.h | 4 +- .../proton/csrc/include/Driver/GPU/HipApi.h | 5 ++ .../proton/csrc/lib/Driver/GPU/CudaApi.cpp | 3 +- .../proton/csrc/lib/Driver/GPU/HipApi.cpp | 26 +++++++- third_party/proton/proton/viewer.py | 11 +++- .../test/{example.json => example_cuda.json} | 4 +- third_party/proton/test/example_hip.json | 64 +++++++++++++++++++ third_party/proton/test/test_viewer.py | 25 +++++++- 8 files changed, 128 insertions(+), 14 deletions(-) rename third_party/proton/test/{example.json => example_cuda.json} (96%) create mode 100644 third_party/proton/test/example_hip.json diff --git a/third_party/proton/csrc/include/Driver/Device.h b/third_party/proton/csrc/include/Driver/Device.h index 79a9bf11ec..3e414c824b 100644 --- a/third_party/proton/csrc/include/Driver/Device.h +++ b/third_party/proton/csrc/include/Driver/Device.h @@ -27,13 +27,13 @@ struct Device { uint64_t memoryClockRate; // khz uint64_t busWidth; uint64_t numSms; - uint64_t arch; + std::string arch; Device() = default; Device(DeviceType type, uint64_t id, uint64_t clockRate, uint64_t memoryClockRate, uint64_t busWidth, uint64_t numSms, - uint64_t arch) + std::string arch) : type(type), id(id), clockRate(clockRate), memoryClockRate(memoryClockRate), busWidth(busWidth), numSms(numSms), arch(arch) {} diff --git a/third_party/proton/csrc/include/Driver/GPU/HipApi.h b/third_party/proton/csrc/include/Driver/GPU/HipApi.h index 942e310c13..fadb9c425c 100644 --- a/third_party/proton/csrc/include/Driver/GPU/HipApi.h +++ b/third_party/proton/csrc/include/Driver/GPU/HipApi.h @@ -16,8 +16,13 @@ hipError_t deviceGetAttribute(int *value, hipDeviceAttribute_t attribute, template hipError_t getDeviceCount(int *count); +template +hipError_t getDeviceProperties(hipDeviceProp_t *prop, int deviceId); + Device getDevice(uint64_t index); +const std::string getHipArchName(uint64_t index); + const char *getKernelNameRef(const hipFunction_t f); const char *getKernelNameRefByPtr(const void *hostFunction, hipStream_t stream); diff --git a/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp index d58127dcb5..aae8b4ceb3 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp @@ -49,7 +49,8 @@ Device getDevice(uint64_t index) { int minor; cuda::deviceGetAttribute( &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); - auto arch = major * 10 + minor; + std::string arch = std::to_string(major * 10 + minor); + return Device(DeviceType::CUDA, index, clockRate, memoryClockRate, busWidth, numSms, arch); } diff --git a/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp index 80e3fbab59..18de4a4f62 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp @@ -23,6 +23,9 @@ DEFINE_DISPATCH(ExternLibHip, deviceGetAttribute, hipDeviceGetAttribute, int *, DEFINE_DISPATCH(ExternLibHip, getDeviceCount, hipGetDeviceCount, int *); +DEFINE_DISPATCH(ExternLibHip, getDeviceProperties, hipGetDeviceProperties, + hipDeviceProp_t *, int); + Device getDevice(uint64_t index) { int clockRate; (void)hip::deviceGetAttribute(&clockRate, hipDeviceAttributeClockRate, @@ -37,13 +40,30 @@ Device getDevice(uint64_t index) { (void)hip::deviceGetAttribute( &smCount, hipDeviceAttributeMultiprocessorCount, index); - // TODO: Compute capability is a NVIDIA concept. It doesn't map naturally to - // AMD GPUs. 
Figure out a better way to support this. - uint64_t arch = 0; + std::string arch = getHipArchName(index); + return Device(DeviceType::HIP, index, clockRate, memoryClockRate, busWidth, smCount, arch); } +// TODO: hipDeviceProp_t was updated to point from hipDeviceProp_tR0000 -> +// hipDeviceProp_tR0600 as part of a breaking API change in Rocm 6.0 +// https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/driver.c +// uses hipDeviceProp_tR0000 and imports the hip_deprecated.h header file to be +// be back compatible with ROCm 5.x. PyTorch stills needs to support 5.x and the +// hipDeviceProp_tR0600 symbol does not exist pre-Rocm 6.0. Calling +// hipDeviceProp_tR0000 here with Rocm 6.1 causes a stack corruption. Therefore +// were will use hipDeviceProp_t and investigate if we can unify the definitions +// in the two files. + +const std::string getHipArchName(uint64_t index) { + hipDeviceProp_t devProp; + (void)hip::getDeviceProperties(&devProp, index); + std::string gcnArchName(devProp.gcnArchName); + std::string hipArch = gcnArchName.substr(0, 6); + return hipArch; +} + const char *getKernelNameRef(const hipFunction_t f) { typedef const char *(*hipKernelNameRef_t)(const hipFunction_t); static hipKernelNameRef_t func = nullptr; diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index fee1b1fa61..3ef3a4c93a 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -43,14 +43,19 @@ def get_min_time_flops(df, device_info): continue max_flops = 0 if device_type == "CUDA": - if arch == 80: + if arch == "80": max_flops = 624e12 / (width / 8) - elif arch == 89: + elif arch == "89": # TODO(Keren): Implement fp16 acc-> 660.6 fp8 max_flops = (330.3 * 1e12) / (width / 8) - elif arch == 90: + elif arch == "90": # 114 sms and 1755mhz is the base number of sms and clock rate of H100 pcie max_flops = ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / (width / 8) + elif device_type == "HIP": + if arch == "gfx90a": + max_flops = 383e12 / (width / 8) + elif arch == "gfx941" or arch == "gfx942": + max_flops = 2614.9e12 / (width / 8) else: raise ValueError(f"Unsupported device type: {device_type}") min_time_flops.loc[idx, "min_time"] += device_frames[f"flops{width}"].fillna(0) / max_flops diff --git a/third_party/proton/test/example.json b/third_party/proton/test/example_cuda.json similarity index 96% rename from third_party/proton/test/example.json rename to third_party/proton/test/example_cuda.json index ea65853cd1..9e148ff791 100644 --- a/third_party/proton/test/example.json +++ b/third_party/proton/test/example_cuda.json @@ -46,14 +46,14 @@ { "CUDA": { "0": { - "arch": 89, + "arch": "89", "bus_width": 384, "clock_rate": 2625000, "memory_clock_rate": 10501000, "num_sms": 128 }, "1": { - "arch": 90, + "arch": "90", "bus_width": 6144, "clock_rate": 1980000, "memory_clock_rate": 2619000, diff --git a/third_party/proton/test/example_hip.json b/third_party/proton/test/example_hip.json new file mode 100644 index 0000000000..2fcfad3c5d --- /dev/null +++ b/third_party/proton/test/example_hip.json @@ -0,0 +1,64 @@ + [ + { + "children": [ + { + "children": [], + "frame": { + "name": "foo0", + "type": "function" + }, + "metrics": { + "Count": 1, + "DeviceId": "1", + "DeviceType": "HIP", + "Time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e8 + } + }, + { + "children": [], + "frame": { + "name": "foo1", + "type": "function" + }, + "metrics": { + "Count": 1, + "DeviceId": "0", + "DeviceType": "HIP", + "Time (ns)": 204800, 
+ "flops8": 1e10, + "bytes": 1e7 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "Count": 0, + "Time (ns)": 0, + "flops8": 0, + "bytes": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + }, + "1": { + "arch": "gfx941", + "bus_width": 8192, + "clock_rate": 5200000, + "memory_clock_rate": 2525000, + "num_sms": 304 + } + } + } +] diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py index 57295c7919..63a74b06ce 100644 --- a/third_party/proton/test/test_viewer.py +++ b/third_party/proton/test/test_viewer.py @@ -3,7 +3,8 @@ import numpy as np file_path = __file__ -example_file = file_path.replace("test_viewer.py", "example.json") +cuda_example_file = file_path.replace("test_viewer.py", "example_cuda.json") +hip_example_file = file_path.replace("test_viewer.py", "example_hip.json") def test_help(): @@ -13,7 +14,7 @@ def test_help(): def test_min_time_flops(): - with open(example_file, "r") as f: + with open(cuda_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_flops(gf.dataframe, device_info) device0_idx = gf.dataframe["DeviceId"] == "0" @@ -22,10 +23,19 @@ def test_min_time_flops(): np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000025]], atol=1e-5) # sm90 np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.00005]], atol=1e-5) + with open(hip_example_file, "r") as f: + gf, _, device_info = get_raw_metrics(f) + ret = get_min_time_flops(gf.dataframe, device_info) + device0_idx = gf.dataframe["DeviceId"] == "0" + device1_idx = gf.dataframe["DeviceId"] == "1" + # MI200 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000026]], atol=1e-5) + # MI300 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.000038]], atol=1e-5) def test_min_time_bytes(): - with open(example_file, "r") as f: + with open(cuda_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_bytes(gf.dataframe, device_info) device0_idx = gf.dataframe["DeviceId"] == "0" @@ -34,3 +44,12 @@ def test_min_time_bytes(): np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[9.91969e-06]], atol=1e-6) # sm90 np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[2.48584e-05]], atol=1e-6) + with open(hip_example_file, "r") as f: + gf, _, device_info = get_raw_metrics(f) + ret = get_min_time_bytes(gf.dataframe, device_info) + device0_idx = gf.dataframe["DeviceId"] == "0" + device1_idx = gf.dataframe["DeviceId"] == "1" + # MI200 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[6.10351e-06]], atol=1e-6) + # MI300 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[1.93378e-05]], atol=1e-6) From 49014e72b25d7cdefe254f1235a67275fb30fd60 Mon Sep 17 00:00:00 2001 From: Andrew James Date: Wed, 12 Jun 2024 14:49:45 -0500 Subject: [PATCH 5/5] [FRONTEND] Add TRITON_DISABLE_PYTHON_STACKTRACE envvar (#4130) Used to disable stacktrace handler registration in the python module. fixes #4129 The stacktrace handler disrupts normal signal propagation, so we introduce an environment variable to disable registration at import time. 
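Usage sketch: the check happens when `libtriton` is first imported, so the variable has to be set beforehand, either in the shell (`TRITON_DISABLE_PYTHON_STACKTRACE=1 python script.py`) or in Python before the import:

```python
import os

# Must be set before triton (and thus libtriton) is imported; the handler is
# registered conditionally at module-load time.
os.environ["TRITON_DISABLE_PYTHON_STACKTRACE"] = "1"

import triton  # LLVM's print-stack-trace-on-error-signal handler is now skipped
```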
--- include/triton/Tools/Sys/GetEnv.hpp | 3 +++ python/src/llvm.cc | 7 +++++++ python/src/main.cc | 3 ++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 8b51143e39..7d0d41075a 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -33,7 +33,10 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { }; inline const std::set CACHE_NEUTRAL_ENV_VARS = { + // clang-format off "TRITON_REPRODUCER_PATH", + "TRITON_DISABLE_PYTHON_STACKTRACE" + // clang-format on }; namespace tools { diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 851b61eb76..f4c023f232 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -17,6 +17,7 @@ #include "llvm/Passes/PassPlugin.h" #include "llvm/Passes/StandardInstrumentations.h" #include "llvm/Support/CodeGen.h" +#include "llvm/Support/Signals.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" @@ -420,3 +421,9 @@ void init_triton_llvm(py::module &&m) { } }); } + +void init_triton_stacktrace_hook(pybind11::module &m) { + if (!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_PYTHON_STACKTRACE")) { + llvm::sys::PrintStackTraceOnErrorSignal("triton_python"); + } +} diff --git a/python/src/main.cc b/python/src/main.cc index fc142e70d8..82289edc0f 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -40,11 +40,12 @@ void init_triton_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); +void init_triton_stacktrace_hook(pybind11::module &m); FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; - llvm::sys::PrintStackTraceOnErrorSignal("triton_python"); + init_triton_stacktrace_hook(m); init_triton_env_vars(m); init_triton_ir(m.def_submodule("ir")); init_triton_passes(m.def_submodule("passes"));