LinearLayout conversion interface (#1875)
hwnam831 authored and whitneywhtsang committed Aug 20, 2024
1 parent 013d9ef commit 59539b0
Showing 10 changed files with 75 additions and 243 deletions.
3 changes: 1 addition & 2 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -632,8 +632,7 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
std::optional<LinearLayout> parentLL =
triton::gpu::toLinearLayout(parentShape, getParent());
if (!parentLL.has_value())
llvm::report_fatal_error(
"Failed to compute parent layout for slice layout.");
return std::nullopt;

// Remove dimension getDim() from the parent layout.
//
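The hunk above changes the error-handling contract: instead of aborting compilation with `llvm::report_fatal_error`, a missing parent layout now propagates as `std::nullopt`, and the caller picks a fallback path. A minimal standalone sketch of that contract, using plain C++ stubs in place of the real MLIR attribute and layout types:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Stub stand-ins for the real types (the actual LinearLayout lives in
// triton/Tools/LinearLayout.h).
struct LinearLayout {};
struct Attribute {};

// Hypothetical parent conversion that may not support a given encoding.
std::optional<LinearLayout> toLinearLayout(const std::vector<int64_t> &shape,
                                           const Attribute &layout) {
  (void)shape;
  (void)layout;
  return std::nullopt; // pretend the parent encoding is unsupported
}

// The pattern this hunk adopts: propagate nullopt instead of fatally
// erroring, so callers (e.g. the ConvertLayoutOp lowering later in this
// commit) can fall back to legacy codegen.
std::optional<LinearLayout>
sliceToLinearLayout(const std::vector<int64_t> &parentShape,
                    const Attribute &parent) {
  std::optional<LinearLayout> parentLL = toLinearLayout(parentShape, parent);
  if (!parentLL.has_value())
    return std::nullopt; // caller decides how to recover
  // ... here the real code removes the sliced dimension from *parentLL ...
  return parentLL;
}
```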
58 changes: 22 additions & 36 deletions test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir
@@ -11,46 +11,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32
// CHECK-DAG: %[[CST_18:.*]] = llvm.mlir.constant(18 : i32) : i32
// CHECK-DAG: %[[CST_20:.*]] = llvm.mlir.constant(20 : i32) : i32
// CHECK-DAG: %[[CST_22:.*]] = llvm.mlir.constant(22 : i32) : i32

// CHECK: %[[THREAD_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
// CHECK: %[[THREAD_ID_32:.*]] = llvm.trunc %[[THREAD_ID]] : i64 to i32
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_32]], %[[CST_16]] : i32
// CHECK: %[[WARP_ID_Y:.*]] = llvm.urem %[[WARP_ID]], %[[CST_2]] : i32
// CHECK: %[[VAL_23:.*]] = llvm.udiv %[[WARP_ID]], %[[CST_2]] : i32
// CHECK: %[[WARP_ID_X:.*]] = llvm.urem %[[VAL_23]], %[[CST_2]] : i32
// CHECK: %[[ROUNDED_WARP_ID_X:.*]] = llvm.urem %[[WARP_ID_X]], %[[CST_4]] : i32
// CHECK: %[[ROUNDED_WARP_ID_Y:.*]] = llvm.urem %[[WARP_ID_Y]], %[[CST_4]] : i32
// CHECK: %[[WARP_OFFSET_X:.*]] = llvm.mul %[[ROUNDED_WARP_ID_X]], %[[CST_8]] : i32
// CHECK: %[[WARP_OFFSET_Y:.*]] = llvm.mul %[[ROUNDED_WARP_ID_Y]], %[[CST_8]] : i32
// CHECK: %[[LANE_OFFSET_X:.*]] = llvm.udiv %[[LANE_ID]], %[[CST_8]] : i32
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[LANE_OFFSET_X]], %[[WARP_OFFSET_X]] : i32
// CHECK: %[[LANE_OFFSET_Y:.*]] = llvm.urem %[[LANE_ID]], %[[CST_8]] : i32
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[LANE_OFFSET_Y]], %[[WARP_OFFSET_Y]] : i32
// CHECK: %[[VAL_33:.*]] = llvm.urem %[[CST_0]], %[[CST_1]] : i32
// CHECK: %[[VAL_34:.*]] = llvm.udiv %[[CST_0]], %[[CST_1]] : i32
// CHECK: %[[VAL_35:.*]] = llvm.urem %[[VAL_34]], %[[CST_1]] : i32
// CHECK: %[[VAL_36:.*]] = llvm.urem %[[VAL_35]], %[[CST_1]] : i32
// CHECK: %[[VAL_37:.*]] = llvm.urem %[[VAL_33]], %[[CST_1]] : i32
// CHECK: %[[CTA_OFFSET_Y:.*]] = llvm.mul %[[VAL_36]], %[[CST_32]] : i32
// CHECK: %[[CTA_OFFSET_X:.*]] = llvm.mul %[[VAL_37]], %[[CST_32]] : i32
// CHECK: %[[VAL_40:.*]] = llvm.add %[[OFFSET_X]], %[[CTA_OFFSET_Y]] : i32
// CHECK: %[[VAL_41:.*]] = llvm.add %[[OFFSET_Y]], %[[CTA_OFFSET_X]] : i32
// CHECK: %[[OFFSET_X_0:.*]] = llvm.add %[[VAL_40]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_Y_0:.*]] = llvm.add %[[VAL_41]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_X_1:.*]] = llvm.add %[[VAL_40]], %[[CST_2]] : i32
// CHECK: %[[OFFSET_X_2:.*]] = llvm.add %[[VAL_40]], %[[CST_4]] : i32
// CHECK: %[[OFFSET_X_3:.*]] = llvm.add %[[VAL_40]], %[[CST_6]] : i32
// CHECK: %[[OFFSET_Y_1:.*]] = llvm.add %[[VAL_41]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_4:.*]] = llvm.add %[[VAL_40]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_5:.*]] = llvm.add %[[VAL_40]], %[[CST_18]] : i32
// CHECK: %[[OFFSET_X_6:.*]] = llvm.add %[[VAL_40]], %[[CST_20]] : i32
// CHECK: %[[OFFSET_X_7:.*]] = llvm.add %[[VAL_40]], %[[CST_22]] : i32
// CHECK-DAG: %[[THREAD_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
// CHECK-DAG: %[[THREAD_ID_32:.*]] = llvm.trunc %[[THREAD_ID]] : i64 to i32
// CHECK-DAG: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32
// CHECK-DAG: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_32]], %[[CST_16]] : i32
// CHECK: %[[VAL_37:.*]] = llvm.and %[[WARP_ID]], %[[CST_1]] : i32
// CHECK: %[[VAL_38:.*]] = llvm.icmp "eq" %[[VAL_37]], %[[CST_0]] : i32
// CHECK: %[[VAL_39:.*]] = llvm.select %[[VAL_38]], %[[CST_0]], %[[CST_8]] : i1, i32
// CHECK: %[[VAL_40:.*]] = llvm.xor %{{.*}}, %[[VAL_39]] : i32
// CHECK: %[[VAL_41:.*]] = llvm.and %[[WARP_ID]], %[[CST_2]] : i32
// CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
// CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8]] : i1, i32
// CHECK: %[[VAL_44:.*]] = llvm.xor %{{.*}}, %[[VAL_43]] : i32
// CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_44]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_Y_0:.*]] = llvm.xor %[[VAL_40]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_44]], %[[CST_2]] : i32
// CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_44]], %[[CST_4]] : i32
// CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_44]], %[[CST_6]] : i32
// CHECK: %[[OFFSET_Y_1:.*]] = llvm.xor %[[VAL_40]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_44]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_44]], %[[CST_18]] : i32
// CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_44]], %[[CST_20]] : i32
// CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_44]], %[[CST_22]] : i32
// CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
// CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_1]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
// CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_2]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
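The updated checks reflect how linear layouts lower: offsets are no longer built from `udiv`/`urem`/`mul`/`add` chains but from `and`/`icmp`/`select`/`xor`, because a LinearLayout is linear over GF(2) — each bit of an input index selects a basis vector, and the selected bases combine with XOR. A standalone illustration of that lowering shape (the basis values below are hypothetical, chosen to mirror the `CST_8` selects in the checks):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// One output coordinate of a linear layout: every set bit of `index`
// contributes a basis value, combined with XOR. This matches the lowered IR
// above: `llvm.and` extracts the bit, `llvm.icmp` + `llvm.select` pick the
// basis or zero, and `llvm.xor` accumulates the result.
int32_t applyBases(const std::vector<int32_t> &bases, int32_t index) {
  int32_t out = 0;
  for (int bit = 0; bit < (int)bases.size(); ++bit) {
    int32_t masked = index & (int32_t(1) << bit);   // llvm.and
    int32_t basis = (masked == 0) ? 0 : bases[bit]; // llvm.icmp + llvm.select
    out ^= basis;                                   // llvm.xor
  }
  return out;
}

int main() {
  // Hypothetical warp bases: bit 0 of the warp id hops by 8 in Y,
  // bit 1 hops by 8 in X.
  std::vector<int32_t> yBases = {8, 0};
  std::vector<int32_t> xBases = {0, 8};
  for (int32_t warp = 0; warp < 4; ++warp)
    std::printf("warp %d -> (x=%d, y=%d)\n", warp, applyBases(xBases, warp),
                applyBases(yBases, warp));
  return 0;
}
```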
@@ -1,6 +1,8 @@
#ifndef TRITON_DIALECT_TRITON_INTEL_GPU_IR_ATTRIBUTES_H
#define TRITON_DIALECT_TRITON_INTEL_GPU_IR_ATTRIBUTES_H

#include "triton/Dialect/TritonGPU/IR/Attributes.h"

#define GET_ATTRDEF_CLASSES
#include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc"

@@ -15,9 +15,8 @@ namespace mlir::triton::gpu {
// DPAS operand B: opIdx=1
// DPAS operand C (default): opIdx=2
// Operand A and B conversion are not used yet
std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
Attribute layout,
unsigned opIdx = 2);
LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
unsigned opIdx = 2);

} // namespace mlir::triton::gpu

6 changes: 6 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -2,6 +2,7 @@

#include <numeric>

#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -415,6 +416,11 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
<< "}>";
}

std::optional<LinearLayout>
DpasEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return DPAStoLinearLayout(shape, *this);
}

//===----------------------------------------------------------------------===//
// WarpEncodingAttr
//===----------------------------------------------------------------------===//
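The new `DpasEncodingAttr::toLinearLayout` override is what lets the rest of this commit delete its `DpasEncodingAttr` special cases: the generic `triton::gpu::toLinearLayout` entry point can now reach the DPAS conversion through the attribute interface. Sketched below with plain virtual dispatch standing in for MLIR's attribute interfaces; `DPAStoLinearLayout` is modeled as always succeeding, matching its new non-optional return type:

```cpp
#include <optional>

struct LinearLayout {};

// Stand-in for the attribute interface that declares toLinearLayout.
struct Encoding {
  virtual ~Encoding() = default;
  virtual std::optional<LinearLayout> toLinearLayout() const {
    return std::nullopt; // encodings without a conversion report failure
  }
};

// Stand-in for DpasEncodingAttr: its conversion cannot fail, so the
// optional is only added at the interface boundary.
struct DpasEncoding : Encoding {
  LinearLayout dpasToLinearLayout() const { return LinearLayout{}; }
  std::optional<LinearLayout> toLinearLayout() const override {
    return dpasToLinearLayout();
  }
};

// Callers no longer write `if (auto dpas = dyn_cast<DpasEncodingAttr>(...))`;
// they just ask the encoding.
std::optional<LinearLayout> toLinearLayout(const Encoding &enc) {
  return enc.toLinearLayout();
}

int main() {
  DpasEncoding dpas;
  return toLinearLayout(dpas).has_value() ? 0 : 1;
}
```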
@@ -483,8 +483,8 @@ DPASLaneBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
return laneBases;
}

std::optional<LinearLayout>
DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
unsigned opIdx) {
assert(opIdx < 3 && "opIdx must be 0, 1, or 2");
auto dpas = dyn_cast<DpasEncodingAttr>(layout);
assert(dpas && "Must be DPAS layout");
@@ -497,13 +497,12 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {

StringAttr kRegister = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");

const SmallVector<unsigned> warpsPerCTA = dpas.getWarpsPerCTA();
int threadsPerWarp = triton::gpu::getWarpSize(dpas);
unsigned opsPerChannel = dpas.getOpsPerChannel();
auto repCluster = dpas.getRepCluster();
SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);

auto tileLayout = LinearLayout::empty();
int systolicDepth = dpas.getSystolicDepth();
int repeatCount = dpas.getRepeatCount();
@@ -520,8 +519,14 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
// A only repeats by repCluster[0]
tileLayout *=
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);

nonKDim = 0;
KDim = 1;
// K-dimension is shared among warps
tileLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
tileLayout *=
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);

} else if (opIdx == 1) { // Operand B
auto regBasesB = DPASRegBasesB(opsPerChannel, executionSize, threadsPerWarp,
systolicDepth);
@@ -532,8 +537,14 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
// B only repeats by repCluster[1]
tileLayout *=
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);

nonKDim = 1;
KDim = 0;

// K-dimension is shared among warps
tileLayout *=
LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
tileLayout *= LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
} else { // opIdx=2 -> Operand C
auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
auto laneBasesC =
@@ -547,36 +558,26 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
tileLayout *=
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);

// The identical layout is repeated among warps
tileLayout *=
LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
tileLayout *=
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
nonKDim = 0;
KDim = 1;
}

// Lastly, the layout repeats to match the shape.
// Operand A/B repeats through the K-dimension first then repeats
// through non-K dimension.
// through the non-K dimension.
SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
tileLayout *=
LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
outDimNames[nonKDim]);

// For Operand C, warps split the tensor identically.
// For Operand A and B, warps in the K-dimension share the same data.
// In these cases, the warp hops for K-dimensions are zero.
LinearLayout warpLayout = LinearLayout::empty();
StringAttr kWarp = S("warp");
if (opIdx == 0) {
warpLayout =
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
warpLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
} else if (opIdx == 1) {
warpLayout = LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
warpLayout *=
LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
} else { /* opIdx == 2 */
warpLayout = identityND(kWarp, warpsPerCTA, {0, 1}, outDimNames);
}
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(std::move(ctaLayout),
return combineCtaCgaWithShape(tileLayout,
CTALayoutAttr::getDefault(ctx, rank), shape);
}

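The warp handling folded into each operand branch hinges on two basis constructors: `identity1D(n, in, out)` gives each of the `n` warps its own position along `out`, while `zeros1D(n, in, out)` maps every warp to 0, i.e. warps along that dimension share the same data — which is why operands A and B use zeros along their K-dimension while operand C tiles both dimensions. A toy standalone rendering of the resulting warp offsets, assuming a 2x2 warp grid; the `tile` factor below stands in for the register/lane bases that the product with `tileLayout` would contribute:

```cpp
#include <cstdio>

int main() {
  // 2x2 warps per CTA, 8-element warp tile: illustrative numbers only.
  const int warpsPerCTA[2] = {2, 2};
  const int tile = 8;
  for (int w = 0; w < warpsPerCTA[0] * warpsPerCTA[1]; ++w) {
    int w0 = w % warpsPerCTA[0]; // warp coordinate along dim 0
    int w1 = w / warpsPerCTA[0]; // warp coordinate along dim 1
    // Operand A (K along dim 1): identity1D in dim 0, zeros1D in dim 1,
    // so every warp in a row reads the same K-slice.
    int aOff[2] = {w0 * tile, 0};
    // Operand C: identity1D in both dims — warps tile the whole tensor.
    int cOff[2] = {w0 * tile, w1 * tile};
    std::printf("warp %d: A offsets (%d,%d), C offsets (%d,%d)\n", w,
                aOff[0], aOff[1], cOff[0], cOff[1]);
  }
  return 0;
}
```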
@@ -459,25 +459,16 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
const auto &shape = op.getType().getShape();
std::optional<LinearLayout> srcLayout;
auto srcTy = op.getSrc().getType();

if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(srcTy.getEncoding())) {
srcLayout = gpu::DPAStoLinearLayout(shape, dpasLayout);
} else {
srcLayout = gpu::toLinearLayout(shape, srcTy.getEncoding());
}
srcLayout = gpu::toLinearLayout(shape, srcTy.getEncoding());

std::optional<LinearLayout> dstLayout;
auto dstTy = op.getType();
if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(dstTy.getEncoding())) {
dstLayout = gpu::DPAStoLinearLayout(shape, dpasLayout);
} else {
dstLayout = gpu::toLinearLayout(shape, dstTy.getEncoding());
}

dstLayout = gpu::toLinearLayout(shape, dstTy.getEncoding());

if (!srcLayout.has_value() || !dstLayout.has_value()) {
return failure();
}

// There are four cases to handle.
//
// 1. Transfer between values in the same thread, in which case we simply
125 changes: 0 additions & 125 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
@@ -113,128 +113,3 @@ LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter) {
}

} // namespace mlir::LLVM::intel

namespace mlir::triton::intel {
bool emitTransferBetweenDPASAndShared(
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
MLIRContext *ctx = rewriter.getContext();

auto shape = registerTy.getShape();
int rank = shape.size();

StringAttr kBlock = str_attr("block");
StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");

std::optional<LinearLayout> regLayout;
if (auto dpas = dyn_cast<DpasEncodingAttr>(registerTy.getEncoding())) {
// Default is operandC (opidx == 2)
regLayout = triton::gpu::DPAStoLinearLayout(shape, dpas);
} else {
regLayout = triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
}

std::optional<LinearLayout> sharedLayout;
if (auto dpas = dyn_cast<DpasEncodingAttr>(sharedTy.getEncoding())) {
sharedLayout = triton::gpu::DPAStoLinearLayout(shape, dpas);
} else {
sharedLayout = triton::gpu::toLinearLayout(
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
}

if (!regLayout.has_value() || !sharedLayout.has_value()) {
return false;
}
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());

// sharedLayout's in-dims are currently (offset, block). Reshape to
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
// shmem strides. (The offsetX's appear in minor-to-major order.)
auto sharedLegacy =
cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
for (int i = 0; i < rank; i++) {
int dim = sharedOrder[i];
int64_t size = std::max(
int64_t{1},
shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
multiDimSharedSize.push_back(
{str_attr("offset" + std::to_string(dim)), size});
}
multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);

// regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
// ..., offsetXN, block), where the offsetX's are in minor-to-major order.
LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);

// TODO(jlebar): We don't currently support loading from shared memory in a
// different CTA. We'd need to emit `mapa.shared::cluster` instructions.
for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
inBlock *= 2) {
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
{{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
// offsetX1, ..., offsetXN must all be 0.
if (!llvm::all_of(ArrayRef(idx).drop_back(1),
[&](auto offset) { return offset == 0; })) {
return false;
}
int32_t outBlock = idx.back();
if (outBlock != inBlock) {
return false;
}
}

// Determine how many consecutive registers map to consecutive shmem elements
// in out-dimension offsetN. This is our load instruction's vector width.
//
// It's OK if the vector width we choose here is wider than the hardware
// supports; LLVM will legalize it.
//
// TODO(jlebar): shmemStrides are Values, but most of them are usually integer
// constants. We could add those constant strides to the LL, and then before
// calling getNumConsecutiveInOut(), we could flatten consecutive out-dims
// which have known strides. This would allow us to vectorize across multiple
// shmem out dimensions where possible.
const int vecElems =
std::min(regToSharedLayout.getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);

int numElems = regToSharedLayout.getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
auto ptrTy = ptr_ty(ctx, /*addressSpace=*/3);
Value zero = i32_val(0);
SmallVector<Value> ret;
for (int i = 0; i < numElems / vecElems; i++) {
// Get the address to load/store. The multi-dim address is (offsetX1, ...,
// offsetXN, block), where the offsets appear in minor-to-major order, and
// we drop_end to drop block, which we know from above will be 0.
auto multiDimShmemOffset =
llvm::to_vector(llvm::drop_end(llvm::make_second_range(
applyLinearLayout(loc, rewriter, regToSharedLayout,
{{kRegister, i32_val(i * vecElems)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, zero}}))));

// Reorder strides according to `order`. This way they match the
// multi-dimensional offsets in regToSharedLayout.
Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
applyPermutation(shmemStrides, sharedOrder));
auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
vecAddr.setInbounds(true);
perVectorCallback(vecTy, vecAddr);
}
return true;
}
} // namespace mlir::triton::intel
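The deleted helper was a DPAS-aware clone of the upstream register-to-shared-memory transfer; with `DpasEncodingAttr::toLinearLayout` in place, the generic path covers it. Its core step, `regLayout->invertAndCompose(*sharedLayout)`, composes "register → tensor coordinate" with the inverse of "shared-memory offset → tensor coordinate" to obtain "register → shared-memory offset". A brute-force standalone illustration over toy one-dimensional maps (values are hypothetical):

```cpp
#include <cstdio>
#include <map>
#include <vector>

int main() {
  // Toy layouts over a 4-element tensor.
  std::vector<int> regToCoord = {0, 2, 1, 3}; // register r holds this element
  std::vector<int> offToCoord = {0, 1, 2, 3}; // shmem offset o holds this one

  // invertAndCompose, spelled out: first invert the shared layout...
  std::map<int, int> coordToOff;
  for (int o = 0; o < (int)offToCoord.size(); ++o)
    coordToOff[offToCoord[o]] = o;

  // ...then compose, yielding register -> shared-memory offset.
  for (int r = 0; r < (int)regToCoord.size(); ++r)
    std::printf("register %d -> shmem offset %d\n", r,
                coordToOff[regToCoord[r]]);
  return 0;
}
```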