Merge commit '49014e72b25d7cdefe254f1235a67275fb30fd60'
whitneywhtsang committed Jun 12, 2024
2 parents 3d79d38 + 49014e7 commit 078f45b
Showing 28 changed files with 539 additions and 316 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build-test.yml
@@ -226,8 +226,8 @@ jobs:
run_tutorial_test "06-fused-attention"
run_tutorial_test "07-extern-functions"
run_tutorial_test "08-grouped-gemm"
run_tutorial_test "09-experimental-block-pointer"
TRITON_INTEL_ENABLE_BLOCK_PTR=1 run_tutorial_test "09-experimental-block-pointer"
run_tutorial_test "10-experimental-block-pointer"
TRITON_INTEL_ENABLE_BLOCK_PTR=1 run_tutorial_test "10-experimental-block-pointer"
- name: Run CXX unittests
run: |
2 changes: 1 addition & 1 deletion .github/workflows/conda-basekit-build-test.yml
@@ -79,7 +79,7 @@ jobs:
s/\(.*03-matrix-multiplication\)/#\1/
s/\(.*07-extern-functions\)/#\1/
s/\(.*08-grouped-gemm\)/#\1/
- s/\(.*09-experimental-block-pointer\)/#\1/
+ s/\(.*10-experimental-block-pointer\)/#\1/
' scripts/test-triton.sh
conda create -y -n dpcpp -c intel -c conda-forge dpcpp_linux-64=2024.1.*
2 changes: 1 addition & 1 deletion benchmarks/xetla_benchmark/gemm_benchmark.py
@@ -2,7 +2,7 @@
Gemm benchmark
============================
- This benchmark is come from the Triton tutorial 09-experimental-block-pointer.py
+ This benchmark is come from the Triton tutorial 10-experimental-block-pointer.py
To compare the performance to XeTLA kernel.
"""
5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -56,6 +56,11 @@ class TargetInfoBase {
StringRef message, StringRef file, StringRef func,
int line) const = 0;

+ // Whether to enable linear layout. This is a per-backend temporary escape
+ // hatch to disable linear layout while figuring out issues. Eventually we
+ // want to enable linear layout everywhere and delete this control.
+ virtual bool enableLinearLayout() const { return true; }

virtual ~TargetInfoBase() {}
};
} // namespace mlir::triton
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1219,7 +1219,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
bool allowLL = true) {
// Eventually the LinearLayout path will be the only one. For now we allow
// both paths so we can test that they produce the same results.
- if (allowLL) {
+ if (allowLL && target.enableLinearLayout()) {
std::optional<SmallVector<SmallVector<Value>>> llOffsets =
emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type,
withCTAOffset);
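The two hunks above add a per-backend escape hatch: TargetInfoBase::enableLinearLayout() defaults to true, and emitIndices() now takes the LinearLayout path only when both the call-site flag and the backend hook agree. Below is a minimal standalone sketch of that pattern, using simplified stand-in classes rather than the real Triton declarations (TargetInfoBase, emitIndices, and LegacyIndexTargetInfo here are illustrative only):

// Illustrative stand-ins, not the real Triton classes from the hunks above.
#include <iostream>

struct TargetInfoBase {
  // Default: the LinearLayout path is enabled.
  virtual bool enableLinearLayout() const { return true; }
  virtual ~TargetInfoBase() = default;
};

// Hypothetical backend that temporarily opts out while issues are investigated.
struct LegacyIndexTargetInfo : TargetInfoBase {
  bool enableLinearLayout() const override { return false; }
};

void emitIndices(const TargetInfoBase &target, bool allowLL) {
  // Mirrors the new guard: both the caller flag and the backend hook must agree.
  if (allowLL && target.enableLinearLayout())
    std::cout << "LinearLayout index emission\n";
  else
    std::cout << "legacy index emission\n";
}

int main() {
  TargetInfoBase defaultTarget;
  LegacyIndexTargetInfo legacyTarget;
  emitIndices(defaultTarget, /*allowLL=*/true); // LinearLayout path
  emitIndices(legacyTarget, /*allowLL=*/true);  // legacy path
}

A backend hitting LinearLayout issues can ship an override returning false and flip it back later without touching the shared emitIndices() code, which matches the "temporary escape hatch" intent stated in the new comment.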
3 changes: 3 additions & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -35,7 +35,10 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
};

inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
+ // clang-format off
"TRITON_REPRODUCER_PATH",
"TRITON_DISABLE_PYTHON_STACKTRACE"
// clang-format on
};

namespace tools {
7 changes: 7 additions & 0 deletions python/src/llvm.cc
@@ -17,6 +17,7 @@
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
@@ -423,3 +424,9 @@ void init_triton_llvm(py::module &&m) {
}
});
}

+ void init_triton_stacktrace_hook(pybind11::module &m) {
+ if (!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_PYTHON_STACKTRACE")) {
+ llvm::sys::PrintStackTraceOnErrorSignal("triton_python");
+ }
+ }
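Taken together with the GetEnv.hpp hunk above, the flow appears to be: TRITON_DISABLE_PYTHON_STACKTRACE is registered as cache-neutral (presumably because it only affects signal-handler setup, not generated code), and init_triton_stacktrace_hook reads it once when the libtriton module is initialized, so it needs to be set before libtriton is first imported. A minimal sketch of the same check in isolation, where the include paths and getBoolEnv come from the hunks above and the free-function name is illustrative:

// Install LLVM's crash stack-trace printer unless the user opted out.
// Setting the variable after module initialization has no effect.
#include "llvm/Support/Signals.h"
#include "triton/Tools/Sys/GetEnv.hpp"

void maybeInstallStacktraceHook() {
  if (!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_PYTHON_STACKTRACE"))
    llvm::sys::PrintStackTraceOnErrorSignal("triton_python");
}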
3 changes: 2 additions & 1 deletion python/src/main.cc
@@ -40,11 +40,12 @@ void init_triton_ir(pybind11::module &&m);
void init_triton_llvm(pybind11::module &&m);
void init_triton_interpreter(pybind11::module &&m);
void init_triton_passes(pybind11::module &&m);
+ void init_triton_stacktrace_hook(pybind11::module &m);
FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE)

PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
- llvm::sys::PrintStackTraceOnErrorSignal("triton_python");
+ init_triton_stacktrace_hook(m);
init_triton_env_vars(m);
init_triton_ir(m.def_submodule("ir"));
init_triton_passes(m.def_submodule("passes"));
18 changes: 11 additions & 7 deletions python/test/unit/runtime/test_cublas.py
@@ -14,9 +14,13 @@ def is_cuda():


@pytest.mark.parametrize("m, n, k", [(16, 16, 16), (32, 16, 16), (16, 32, 16), (16, 16, 32)])
- def test_cublas_fp8(m, n, k, device):
- if not (is_cuda() and torch.cuda.get_device_capability()[0] >= 9):
- pytest.xfail("test_cublas_fp8 is only supported on CUDA with cc >= 90")
+ @pytest.mark.parametrize("dtype_str", ["float8_e4m3fn", "float16"])
+ def test_cublas(m, n, k, dtype_str, device):
+ dtype = getattr(torch, dtype_str)
+ if not is_cuda():
+ pytest.skip("test_cublas is only supported on CUDA")
+ if dtype == torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9:
+ pytest.skip("fp8 is only supported on CUDA with cc >= 90")

from triton._C.libtriton import nvidia

@@ -29,16 +33,16 @@ def limited_rand(elements, shape):
return elements[indices].view(shape)

elements = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32, device=device)
- a = limited_rand(elements, (m, k)).to(torch.float8_e4m3fn)
- b = limited_rand(elements, (k, n)).to(torch.float8_e4m3fn)
- c = torch.zeros((m, n), dtype=torch.float8_e4m3fn, device=device)
+ a = limited_rand(elements, (m, k)).to(dtype)
+ b = limited_rand(elements, (k, n)).to(dtype)
+ c = torch.zeros((m, n), dtype=dtype, device=device)

b = b.T.contiguous()

workspace = torch.empty(workspace_size, dtype=torch.int8, device=device)

cublas = nvidia.cublas.CublasLt(workspace)
- cublas.fp8_matmul(a, b, c)
+ cublas.matmul(a, b, c)

ref = torch.matmul(a.to(torch.float16), b.to(torch.float16).T)

5 changes: 1 addition & 4 deletions python/triton/compiler/code_generator.py
@@ -9,6 +9,7 @@
from .. import language
from .._C.libtriton import ir
from ..language import constexpr, tensor, str_to_ty
+ from ..language.core import _unwrap_if_constexpr
from ..runtime.jit import _normalize_ty
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
@@ -62,10 +63,6 @@ def _is_list_like(o: Any) -> bool:
return isinstance(o, (list, tuple))


- def _unwrap_if_constexpr(o: Any):
- return o.value if isinstance(o, constexpr) else o


def _check_fn_args(node, fn, args):
if fn.noinline:
for idx, arg in enumerate(args):