
[GEMM-perf] matmul is slower when one input needs to be transposed #1795

Closed
mgrabban opened this issue Aug 7, 2024 · 16 comments · Fixed by #2347 or #2443

@mgrabban

mgrabban commented Aug 7, 2024

I find that matmul(X, Y) is ~4X slower when either X or Y needs to be transposed.

So I have a matmul kernel that is similar to the one in the Triton tutorial here.

That kernel is launched from this code:

def fused_mul_add(X, Y, b, transpose_x, transpose_y):
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel_with_block_pointers[grid](
        X, Y, b, Z,
        M, N, K,
        Xstride0, Xstride1,
        Wstride0, Wstride1,
        Z.stride(0), Z.stride(1),
        BIAS_REQD=b is not None,
    )

    return Z

Note that the strides of X or Y are switched (e.g. Xstride0, Xstride1 = X.stride(1), X.stride(0)) if it needs to be transposed.
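As a minimal illustration (assumed shapes, not the actual workload), swapping the two strides describes the transposed view of the same memory without copying it:

```
import torch

X = torch.randn(64, 32)          # row-major: X.stride() == (32, 1)
assert X.T.shape == (32, 64)
assert X.T.stride() == (1, 32)   # same buffer, strides swapped
# So passing (X.stride(1), X.stride(0)) to the kernel is equivalent to
# handing it X.T without materializing the transpose.
```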

I notice that if neither needs to be transposed, performance is similar to PyTorch's matmul perf, but when either needs to be transposed (so that the strides are switched for that input), performance is 4X slower.

This does not happen on CUDA devices. So can you please look into making it efficient for XPU devices as well?

@vlad-penkin
Contributor

@mgrabban thanks for the feedback. Could you please provide information on your runtime environment:

  • GPU HW Model. Please note that all matmul performance optimizations are only available for PVC as of now.
  • Agama Driver version. Please note that all matmul performance optimizations are only available with the latest Rolling Driver.
  • PyTorch or IPEX version or commit id. Please note that regular IPEX is not supported; we are at the final stages of deprecating the dependency on the "special IPEX test proxy" and switching fully to upstream PyTorch.
  • oneAPI Basekit or PyTorch Dependency bundle version. Please note that regular oneAPI Basekit is not supported as of now.

@mgrabban
Author

> @mgrabban thanks for the feedback. Could you please provide information on your runtime environment:
>
> • GPU HW Model. Please note that all matmul performance optimizations are only available for PVC as of now.

I am doing this on PVC (Intel GPU Max 1550).

> • Agama Driver version. Please note that all matmul performance optimizations are only available with the latest Rolling Driver.

My Agama version is 950.4.

> • PyTorch or IPEX version or commit id. Please note that regular IPEX is not supported; we are at the final stages of deprecating the dependency on the "special IPEX test proxy" and switching fully to upstream PyTorch.

I am using the PyTorch/IPEX installed with the script inside the scripts folder.

I am using oneAPI/2024.2.0.

@vlad-penkin vlad-penkin self-assigned this Aug 17, 2024
@vlad-penkin
Contributor

Could you please retest with upstream PyTorch built from source?

To build upstream PyTorch from source, run the following script:

./scripts/compile-pytorch-ipex.sh --pytorch --upstream-pytorch --source
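After the build, a quick sanity check (a minimal sketch using standard torch APIs; assumes an XPU device is visible to the driver):

```
import torch

print(torch.__version__)             # upstream dev/nightly version string
print(torch.xpu.is_available())      # expect True on a working XPU build
print(torch.xpu.get_device_name(0))  # e.g. the PVC card reported by the driver
```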

Our tutorials code still has an import intel_extension_for_pytorch line. You can either comment it out or install a dummy no-op IPEX using this script:

from os import chdir, makedirs
from tempfile import TemporaryDirectory
from subprocess import run

with TemporaryDirectory() as tmpdir:
    pkg = "intel_extension_for_pytorch"
    chdir(tmpdir)
    makedirs(pkg, exist_ok=True)
    files = {
        f"{pkg}/__init__.py": "",
        "setup.py": (
            "from setuptools import setup, find_packages\n"
            f"setup(name='{pkg}', version='2', packages=find_packages())"
        ),
        "project.toml": (
            "[build-system]\n"
            "requires = [\"setuptools\", \"wheel\"]\n"
            "build-backend = \"setuptools.build_meta\""
        )
    }
    for file, content in files.items():
        with open(file, "w") as f:
            f.write(content)
    cmds = [
        f"pip uninstall -y {pkg}",
        "pip install build",
        "python -m build .",
        f"pip install dist/{pkg}-2-py3-none-any.whl"
    ]
    for cmd in cmds:
        run(cmd.split(), check=True)
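Once installed, a quick check that the no-op stub is the package being picked up (a small sketch):

```
import intel_extension_for_pytorch as ipex  # empty stub; imports without side effects
print(ipex.__file__)                        # should point at the dummy package
```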

@mgrabban
Author

@vlad-penkin the pytorch-ipex installation script keeps changing.
Yesterday I tried your command; it installed, but the matmul run failed due to the ipex import, so I commented that out.

Today the install itself fails. I tried
./scripts/compile-pytorch-ipex.sh --upstream-pytorch --source --venv
and it gave this error:

CMake Error at third_party/kineto/libkineto/src/plugin/xpupti/CMakeLists.txt:23 (find_package):
  By not providing "FindPti.cmake" in CMAKE_MODULE_PATH this project has
  asked CMake to find a package configuration file provided by "Pti", but
  CMake did not find one.

  Could not find a package configuration file provided by "Pti" with any of
  the following names:

    PtiConfig.cmake
    pti-config.cmake

  Add the installation prefix of "Pti" to CMAKE_PREFIX_PATH or set "Pti_DIR"
  to a directory containing one of the above files.  If "Pti" provides a
  separate development package or SDK, be sure it has been installed.

Are you able to run matmul/triton benchmark.py from your end?

@mgrabban
Author

The installation issue is now fixed, but timing is now broken: the Triton perf time shows as 0.0. I think this is the reason:
WARNING:root:Wall time is used instead of elapsed_time (not supported). The timing measurements could be innacurate.
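Until elapsed_time works, one hedged workaround is plain wall-clock timing around explicit device synchronization (a sketch; assumes the xpu backend and a no-argument callable fn):

```
import time
import torch

def wall_time_ms(fn, warmup=5, rep=20):
    for _ in range(warmup):
        fn()
    torch.xpu.synchronize()   # make sure queued work is done before timing
    t0 = time.perf_counter()
    for _ in range(rep):
        fn()
    torch.xpu.synchronize()   # include all kernels launched in the loop
    return (time.perf_counter() - t0) * 1e3 / rep
```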

@vlad-penkin
Contributor

@mgrabban thanks for the update!

See my notes below:

  1. You are seeing the warning because the PyTorch you are using does not support the XPUEvent elapsed_time feature. To enable it you need to build PyTorch with the additional PRs recommended by us: ./scripts/compile-pytorch-ipex.sh --upstream-pytorch --venv (a quick check is sketched after these notes).
  2. To build upstream PyTorch you need to install and activate a matching PTI; it is no longer optional for the upstream PyTorch build.
  3. For more details, see the discussion on a similar topic in:
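A quick way to confirm whether a given build supports XPU event timing (a minimal sketch; assumes an XPU device is present):

```
import torch

start = torch.xpu.Event(enable_timing=True)
end = torch.xpu.Event(enable_timing=True)
start.record()
_ = torch.ones(1024, device="xpu") + 1  # any small kernel
end.record()
torch.xpu.synchronize()
# Prints milliseconds on a supported build; older builds may error out or
# be reported as unsupported (falling back to wall time in the benchmarks).
print(start.elapsed_time(end))
```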

@mgrabban
Author

@vlad-penkin I'm now able to run and get perf data as shown below

{'torch_inf': 0.15876160562038422,
 'torch_train': 0.42427361011505127,
 'triton_inf': 0.1633344143629074,
 'triton_train': 1.8272528648376465}

As you can see, the issue is not resolved: inference involving matmul(A, B) is performant while training that additionally involves matmul(A, B^T) is not.

@vlad-penkin vlad-penkin removed their assignment Aug 27, 2024
@arunjose696 arunjose696 self-assigned this Sep 5, 2024
@arunjose696
Contributor

arunjose696 commented Sep 5, 2024

@mgrabban, what are the sizes of the matrices you are using? I could not run triton_inf or triton_train as they were not shared. However, I tried running the matmul kernel from the Triton tutorials with and without transposing the inputs a and b for various matrix sizes.

I used this code to launch my kernel. It is just a slightly modified version of your code, except I do a plain multiply instead of fused_mul_add:

def matmul(X, Y, transpose_x, transpose_y, activation=""):
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Ystride0, Ystride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Ystride0, Ystride1 = Y.stride(0), Y.stride(1)
    
    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        X, Y, Z,  #
        M, N, K,  #
        Xstride0, Xstride1 ,  #
        Ystride0, Ystride1,  #
        Z.stride(0), Z.stride(1),  #
        ACTIVATION=activation  #
    )
    return Z

And below are my results for different matrix sizes

| M | N | K | A*B (timing) | A_transposed*B (timing) | A*B_transposed (timing) |
|---|---|---|---|---|---|
| 256 | 256 | 256 | 1.318964 | 0.907858 | 1.226405 |
| 384 | 384 | 384 | 3.118012 | 2.131894 | 2.900774 |
| 512 | 512 | 512 | 5.785247 | 3.892625 | 5.412005 |
| 640 | 640 | 640 | 9.077008 | 5.217835 | 7.710117 |
| 768 | 768 | 768 | 13.291809 | 7.600417 | 11.168265 |
| 896 | 896 | 896 | 18.055299 | 10.394843 | 15.239896 |
| 1024 | 1024 | 1024 | 15.391941 | 8.935934 | 14.143069 |
| 1152 | 1152 | 1152 | 20.20116 | 11.388735 | 18.375286 |
| 1280 | 1280 | 1280 | 23.831273 | 13.951251 | 21.836236 |
| 1408 | 1408 | 1408 | 17.213304 | 8.707603 | 14.907655 |
| 1536 | 1536 | 1536 | 20.06133 | 10.370532 | 17.148773 |
| 1664 | 1664 | 1664 | 24.393493 | 12.220038 | 20.981069 |
| 1792 | 1792 | 1792 | 27.560273 | 14.140419 | 23.638617 |
| 1920 | 1920 | 1920 | 22.512367 | 11.050912 | 18.981677 |
| 2048 | 2048 | 2048 | 24.232494 | 12.344698 | 20.752645 |
| 2176 | 2176 | 2176 | 26.653839 | 13.926401 | 23.222386 |
| 2304 | 2304 | 2304 | 24.732246 | 12.112852 | 20.724194 |
| 2432 | 2432 | 2432 | 25.813591 | 13.186987 | 22.266819 |
| 2560 | 2560 | 2560 | 28.185633 | 14.555469 | 24.545318 |
| 2688 | 2688 | 2688 | 27.394669 | 13.298907 | 23.200646 |
| 2816 | 2816 | 2816 | 29.298933 | 14.379667 | 24.988221 |
| 2944 | 2944 | 2944 | 28.007605 | 13.306519 | 23.660148 |
| 3072 | 3072 | 3072 | 30.752535 | 14.676327 | 25.870064 |
| 3200 | 3200 | 3200 | 29.507959 | 13.703122 | 24.627964 |
| 3328 | 3328 | 3328 | 29.311299 | 14.481528 | 24.880887 |
| 3456 | 3456 | 3456 | 29.818425 | 13.93255 | 24.705676 |
| 3584 | 3584 | 3584 | 31.112594 | 14.856676 | 26.305472 |
| 3712 | 3712 | 3712 | 31.907325 | 15.681895 | 27.125861 |
| 3840 | 3840 | 3840 | 33.08352 | 15.447832 | 27.721634 |
| 3968 | 3968 | 3968 | 30.94293 | 14.636378 | 25.89539 |
| 4096 | 4096 | 4096 | 32.431981 | 15.491316 | 27.195817 |

I also tried modifying the kernel in the tutorial to a fused multiply-add and still get similar numbers: I don't see a performance degradation when one of the inputs is transposed; instead I still see a slight performance increase. Could you recheck that you are using the latest Agama drivers and PyTorch from upstream, and run the kernel from the tutorial with the launch script above? Then let me know whether the performance degradation still exists when running the matrix multiplication alone, as there might be other functionality in triton_inf or triton_train with an unexpected effect.

These are my HW details:

LIBIGC1_VERSION=1.0.17193.16-950
LEVEL_ZERO_VERSION=1.3.30049.10-950
AGAMA_VERSION=950
GPU_DEVICE=Intel(R) Data Center GPU Max 1100

@alexbaden
Contributor

@mgrabban could you provide us with the cached Triton-generated code for both runs (transpose and without transpose)? The easiest way to do it is to delete your Triton cache (rm -rf ~/.triton/cache) and then run both kernels. You should see ~5 folders in the cache dir. Two will contain several files ending in .ttir, .llir, .spv, etc.: one for the transpose and one without. Can you copy both folders here so we can examine the IR? That will also let us run your generated code verbatim on our systems.
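If helpful, a small sketch to list which kernel folders and artifacts ended up in the cache after the two runs (assumes the default cache location, ~/.triton/cache):

```
from pathlib import Path

cache = Path.home() / ".triton" / "cache"
for kernel_dir in sorted(cache.iterdir()):
    artifacts = sorted(p.name for p in kernel_dir.iterdir())
    print(kernel_dir.name, artifacts)  # look for .ttir/.ttgir/.llir/.spv files
```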

@mgrabban
Author

mgrabban commented Sep 5, 2024

Just for reference: a single file reproducer was provided to @alexbaden

@mgrabban
Author

mgrabban commented Oct 2, 2024

Hi,
We are not seeing any change in perf when rebuilding Triton from the main branch (I assume the fix got merged).
Can you please provide details about how to build Triton and the associated stack so that we can get the fix? We need the specific driver and pytorch-dev-gpu versions and the build steps. Thanks.

@Egor-Krivov
Contributor

Looks like it currently works only in the A@B^t case. I don't see perf improvements for the A^t@B case.

I will add microbenchmarks for the A@B^t and A^t@B cases to this repo to track performance of both:
#2414
#2424

@Egor-Krivov
Contributor

Egor-Krivov commented Oct 4, 2024

Just for reference, here is the reproducer mentioned above and its output.

Output:

Compute A x B
Time for torch: 0.3025856018066406 ms
Time for triton: 0.34675198793411255 ms
Compute A.T x B
Time for torch: 0.3038911819458008 ms
Time for triton: 3.4030990600585938 ms
Compute A x B.T
Max diff is  tensor(1., device='xpu:0', dtype=torch.bfloat16)
Time for torch: 0.3247919976711273 ms
Time for triton: 0.6464511752128601 ms
Compute A.T x B.T
OpenCL API not available for this operation. Got %2472 = "triton_gen.2Dblockload"(%2454, %2471, %2457, %2470, %2469, %2468) <{cache_control = 0 : i32, elem_size_in_bits = 32 : i32, tile_height = 32 : i32, tile_width = 8 : i32, transpose = true, v_blocks = 1 : i32, vnni_transform = false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32>
OpenCL API not available for this operation. Got %2503 = "triton_gen.2Dblockload"(%2454, %2502, %2457, %2501, %2500, %2499) <{cache_control = 0 : i32, elem_size_in_bits = 32 : i32, tile_height = 32 : i32, tile_width = 8 : i32, transpose = true, v_blocks = 1 : i32, vnni_transform = false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32>
OpenCL API not available for this operation. Got %2534 = "triton_gen.2Dblockload"(%2454, %2533, %2457, %2532, %2531, %2530) <{cache_control = 0 : i32, elem_size_in_bits = 32 : i32, tile_height = 32 : i32, tile_width = 8 : i32, transpose = true, v_blocks = 1 : i32, vnni_transform = false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32>
OpenCL API not available for this operation. Got %2565 = "triton_gen.2Dblockload"(%2454, %2564, %2457, %2563, %2562, %2561) <{cache_control = 0 : i32, elem_size_in_bits = 32 : i32, tile_height = 32 : i32, tile_width = 8 : i32, transpose = true, v_blocks = 1 : i32, vnni_transform = false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32>
Time for torch: 0.4516463875770569 ms
Time for triton: 3.7456798553466797 ms

Reproducer:

import torch
import triton
import triton.language as tl
from functools import partial

device = 'xpu'
backend = getattr(torch, device)


def compute_time(
    fn,
    warmup=1,
    rep=5,
    grad_to_none=None,
    quantiles=None,
    fast_flush=True,
    return_mode="mean",
):
    assert return_mode in ["min", "max", "mean", "median"]

    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
    the 20-th and 80-th performance percentile.

    :param fn: Function to benchmark
    :type fn: Callable
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param rep: Repetition time (in ms)
    :type rep: int
    :param grad_to_none: Reset the gradient of the provided tensor to None
    :type grad_to_none: torch.tensor, optional
    :param quantiles: Performance percentile to return in addition to the median.
    :type quantiles: list[float]
    :param fast_flush: Use faster kernel to flush L2 between measurements
    :type fast_flush: bool
    """
    backend.synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device)
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device=device)

    # compute number of warmup and repeat

    start_event = [backend.Event(enable_timing=True) for i in range(rep)]
    end_event = [backend.Event(enable_timing=True) for i in range(rep)]
    # Warm-up
    for _ in range(warmup):
        fn()
    # Benchmark
    for i in range(rep):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                if hasattr(x, 'grad'):
                    x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    backend.synchronize()
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float
    )
    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
        if len(ret) == 1:
            ret = ret[0]
        return ret
    return getattr(torch, return_mode)(times).item()


@triton.autotune(
    configs=[
        triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=3, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M':  64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M':   8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_stages=2, num_warps=32),
    ],
    key=['M', 'N', 'K'],)
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, bias_ptr, c_ptr,
        # Matrix dimensions
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
        stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
        stride_cm: tl.constexpr, stride_cn: tl.constexpr,
        BIAS_REQD: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the matrix multiplication tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create block pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction and accumulate.
    # See above `Make a Block Pointer` section for details.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block.
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Load with boundary checks, no need to calculate the mask manually.
        # For better performance, you may remove some axis from the boundary
        # check, if you can guarantee that the access is always in-bound in
        # that axis.
        # See above `Load/Store a Block Pointer` section for details.
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the block pointer to the next K block.
        # See above `Advance a Block Pointer` section for details.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(tl.float32)
    # add bias to accumulator
    if BIAS_REQD:
        offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        bias = tl.load(bias_ptr + offs_yn, mask=offs_yn < N, other=0.0).to(tl.float32)
        c += bias[None, :]
    # ----------------------------------------------------------------
    # Write back the block of the output matrix C with boundary checks.
    # See above `Load/Store a Block Pointer` section for details.
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c.to(tl.bfloat16), boundary_check=(0, 1))


def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )

    matmul_kernel_with_block_pointers[grid](
        X, Y, b, Z,
        M, N, K,
        Xstride0, Xstride1,
        Wstride0, Wstride1,
        Z.stride(0), Z.stride(1),
        BIAS_REQD=b is not None,
    )

    return Z


M = 1024
K = 5120
N = 4096
dtype  = torch.bfloat16
torch.manual_seed(0)

print('Compute A x B')
X = torch.randn((M, K), device=device, dtype=dtype, requires_grad=False)
Y = torch.randn((K, N), device=device, dtype=dtype, requires_grad=False)

fn_tor = partial(torch.mm, X, Y)
fn_tri = partial(triton_mm, X, Y)

t_tor = compute_time(fn_tor, warmup=5, rep=100)
t_tri = compute_time(fn_tri, warmup=5, rep=100)
print(f"Time for torch: {t_tor} ms")
print(f"Time for triton: {t_tri} ms")

print('Compute A.T x B')
X = torch.randn((K, M), device=device, dtype=dtype, requires_grad=False)
Y = torch.randn((K, N), device=device, dtype=dtype, requires_grad=False)

fn_tor = partial(torch.mm, X.T, Y)
fn_tri = partial(triton_mm, X, Y, transpose_x=True)

t_tor = compute_time(fn_tor, warmup=5, rep=100)
t_tri = compute_time(fn_tri, warmup=5, rep=100)
print(f"Time for torch: {t_tor} ms")
print(f"Time for triton: {t_tri} ms")

torch.manual_seed(0)
print('Compute A x B.T')
X = torch.randn((M, K), device=device, dtype=dtype, requires_grad=False)
Y = torch.randn((N, K), device=device, dtype=dtype, requires_grad=False)

def f():
    return torch.mm(X, torch.transpose(Y, -2, -1))

fn_tor = f
fn_tri = partial(triton_mm, X, Y, transpose_y=True)
print("Max diff is ",(fn_tor() - fn_tri()).abs().max())

t_tor = compute_time(fn_tor, warmup=5, rep=100)
t_tri = compute_time(fn_tri, warmup=5, rep=100)
print(f"Time for torch: {t_tor} ms")
print(f"Time for triton: {t_tri} ms")

print('Compute A.T x B.T')
X = torch.randn((K, M), device=device, dtype=dtype, requires_grad=False)
Y = torch.randn((N, K), device=device, dtype=dtype, requires_grad=False)

fn_tor = partial(torch.mm, X.T, Y.T)
fn_tri = partial(triton_mm, X, Y, transpose_x=True, transpose_y=True)

t_tor = compute_time(fn_tor, warmup=5, rep=100)
t_tri = compute_time(fn_tri, warmup=5, rep=100)
print(f"Time for torch: {t_tor} ms")
print(f"Time for triton: {t_tri} ms")

Egor-Krivov added a commit that referenced this issue Oct 8, 2024
Based on this feedback:
#2408 (review)

Changed the GEMM benchmark to include the transposed-matrices cases.

Closes #2424
Relates to #1795

The A@B^t case is important because the weight matrix is often stored in [M, K]
format. For example, in
https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
Right now we are about 1.5 times slower on XPU than raw torch for that case.

The A^t@B case is important because it is part of matmul backprop. Right now
we are about 4 times slower on XPU than raw torch for that case.
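For reference, a small sketch of where the two transposed products show up in a linear layer (shapes here are illustrative assumptions):

```
import torch

x = torch.randn(8, 16)    # activations: [batch, in_features]
W = torch.randn(32, 16)   # nn.Linear weight: [out_features, in_features]

y = x @ W.T               # forward pass: the A @ B^t case
dy = torch.randn_like(y)  # upstream gradient
dW = dy.T @ x             # weight gradient: the A^t @ B case
assert dW.shape == W.shape
```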
@mgrabban
Author

mgrabban commented Oct 8, 2024

Hello @vlad-penkin ,

Here is my finding:
Previously (using latest in llvm-target branch), we had

matmul/triton $ rm -r ~/.triton/cache/*
matmul/triton $ python test1.py 
Compute A x B
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
Time for torch: 0.15951359272003174 ms
Time for triton: 0.16610080003738403 ms
Compute A.T x B
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
Time for torch: 0.1574448049068451 ms
Time for triton: 0.6618703603744507 ms
Compute A x B.T
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
Time for torch: 0.17146721482276917 ms
Time for triton: 0.6434800028800964 ms
Compute A.T x B.T
(I): Detected 8000 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 128 spills
Time for torch: 0.24118880927562714 ms
Time for triton: 1.1827216148376465 ms

Now (using latest in main branch), we have

matmul/triton $ rm -r ~/.triton/cache/*
matmul/triton $ python test1.py 
Compute A x B
Time for torch: 0.16101761162281036 ms
Time for triton: 0.16846241056919098 ms
Compute A.T x B
Time for torch: 0.15552160143852234 ms
Time for triton: 1.3799647092819214 ms
Compute A x B.T
Time for torch: 0.17011040449142456 ms
Time for triton: 0.29556477069854736 ms
Compute A.T x B.T
Time for torch: 0.2403343915939331 ms
Time for triton: 1.5219616889953613 ms

So while A x B.T performance has improved (from 3.8X slower to 1.7X slower vs PyTorch), A.T x B performance has deteriorated (from 4.2X slower to 8.9X slower vs PyTorch). (I see training perf for the matmul kernel in llmbench has deteriorated overall vs PyTorch.)
Obviously we would like improvement for both, as both are needed for matmul-related kernels/motifs (Linear layer, FFN, etc.), but it seems there might be both hardware and software limitations that we also need to consider.

What do you think? Do you plan to pursue this further?

@Egor-Krivov
Contributor

We now have the A@B.T and A.T@B cases in our internal microbenchmark set to track performance:
#2408
#2430

alexbaden added a commit that referenced this issue Oct 10, 2024
We cannot lower a transposed A matrix to a transposed 2D block load.
Instead, the load is lowered via the LLVM path introduced in #2181.
There appears to be a performance regression in this path, which is
slower than materializing the block in SLM and then reading it into
registers and computing the dot product from there. Using the work in
#2420 I am able to drop the block load attribute for this case and go
down the non-block-pointer path.

Performance on main:
```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.32444801926612854 ms
Time for triton: 0.44371041655540466 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.32708799839019775 ms
Time for triton: 0.634996771812439 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31204161047935486 ms
Time for triton: 3.4140689373016357 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45701122283935547 ms
Time for triton: 3.7463345527648926 ms
```

Performance on this PR:
```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.3081200122833252 ms
Time for triton: 0.44333598017692566 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.33799198269844055 ms
Time for triton: 0.6391856074333191 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31700319051742554 ms
Time for triton: 1.5733630657196045 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45083683729171753 ms
Time for triton: 1.8271965980529785 ms
```

Note that the important commit is
`31386ef1132c3f6cf9cb5f1063ecfab705f4c2a1`. Once #2420 is merged I will
rebase this.

Depends on #2420. Links to #1795.
@alexbaden alexbaden reopened this Oct 10, 2024
@alexbaden
Contributor

Confirmed with @mgrabban that performance has improved (and perhaps more importantly, not regressed). We will track additional improvements in #2354.
