Skip to content

Commit

Permalink
Merge commit '5e3d85548928df11f3f7c96733df986905c57563'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Sep 5, 2024
2 parents 49358eb + 5e3d855 commit 156f014
Show file tree
Hide file tree
Showing 37 changed files with 1,057 additions and 182 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ ptxas

# Third-party include
third_party/nvidia/backend/include
third_party/nvidia/backend/lib/cupti

# Docs
docs/_build/
Expand Down
52 changes: 51 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
}

def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
MemoryEffects<[MemWrite<GlobalMemory>]>]> {
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
let summary = "store value based on descriptor";
let description = [{
This operation will be lowered to Nvidia TMA store operation on targets supporting it.
Expand All @@ -1156,4 +1156,54 @@ def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
}];
}

// Builds a TMA descriptor (tensormap) in device memory at `desc_ptr`,
// describing a tiled view of the tensor at `global_address`. The variadic
// per-dimension operand lists have differing lengths, hence
// AttrSizedOperandSegments. Reads and writes global memory (it materializes
// the descriptor bytes in place).
def TT_ExperimentalTensormapCreateOp: TT_Op<
    "experimental_tensormap_create",
    [
      MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
      AttrSizedOperandSegments,
    ]
> {
  let summary = "Create a new TMA descriptor on device";
  let arguments = (
    ins
    // Pointer to the memory that will hold the new descriptor.
    TT_PtrType:$desc_ptr,
    // Base pointer of the global tensor being described.
    TT_PtrType:$global_address,
    // Per-dimension box (tile) sizes; its length defines the rank (see
    // getRank() below and the verifier in Ops.cpp).
    Variadic<I32>:$box_dim,
    // Per-dimension global tensor sizes; must match the rank.
    Variadic<I32>:$global_dim,
    // Global strides; the verifier requires rank - 1 entries, so the
    // innermost stride is implicit.
    Variadic<I64>:$global_stride,
    // Per-dimension element strides; must match the rank.
    Variadic<I32>:$element_stride,
    // NOTE(review): the ranges below presumably mirror the CUDA driver's
    // CUtensorMapDataType / CUtensorMapInterleave / CUtensorMapSwizzle /
    // CUtensorMapFloatOOBfill enums — confirm against cuTensorMapEncodeTiled.
    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<12>]>:$elem_type,
    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
  );
  let extraClassDeclaration = [{
    // Descriptor rank, defined by the number of box_dim operands.
    int32_t getRank() {
      return getBoxDim().size();
    }
  }];
  let assemblyFormat = [{
    $desc_ptr `,` $global_address `,`
    `[` $box_dim `]` `,`
    `[` $global_dim `]` `,`
    `[` $global_stride `]` `,`
    `[` $element_stride `]`
    attr-dict `:` functional-type(operands, results)
  }];

  let hasVerifier = 1;
}

// Acquire fence for the tensormap at `desc_ptr`, making a descriptor written
// through the generic proxy (e.g. by experimental_tensormap_create) visible
// to subsequent TMA operations.
// NOTE(review): presumably lowers to PTX `fence.proxy.tensormap::generic
// .acquire` on NVIDIA targets — confirm in the NVGPU lowering.
def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op<
    "experimental_tensormap_fenceproxy_acquire",
    [MemoryEffects<[MemWrite<GlobalMemory>]>]
> {
  let summary = "Acquire fence on a tensormap object";
  let arguments = (ins TT_PtrType:$desc_ptr);
  let assemblyFormat = [{
    $desc_ptr attr-dict `:` qualified(type($desc_ptr))
  }];
}


#endif // Triton_OPS
16 changes: 15 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,27 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf

Explicitly deallocating a buffer is optional; see local_dealloc.
}];
let arguments = (ins Optional<TT_Tensor>:$src);
let arguments = (
ins
Optional<TT_Tensor>:$src,
OptionalAttr<I32Attr>:$alignment
);

let builders = [
OpBuilder<(ins "Type":$result),
[{ build($_builder, $_state, result, Value(), IntegerAttr()); }]>,
OpBuilder<(ins "Type":$result, "Value":$src),
[{ build($_builder, $_state, result, src, IntegerAttr()); }]>,
OpBuilder<(ins "Type":$result, "Value":$src, "int32_t":$alignment),
[{ build($_builder, $_state, result, src, $_builder.getI32IntegerAttr(alignment)); }]>
];

let extraClassDeclaration = [{
bool isSharedMemoryAlloc() {
return getType().getMemorySpace() &&
isa<SharedMemorySpaceAttr>(getType().getMemorySpace());
}
int32_t getAlignmentOrDefault();
}];
let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,5 +244,4 @@ def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
let assemblyFormat = "attr-dict";
}


#endif
19 changes: 9 additions & 10 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -179,8 +180,6 @@ class AllocationAnalysis {

/// Initializes explicitly defined shared memory values for a given operation.
void getExplicitValueSize(Operation *op) {
// XXX(Keren): Why this hard-coded alignment?
size_t kAlignment = 8;
for (Value result : op->getResults()) {
auto alloc = result.getDefiningOp<triton::gpu::LocalAllocOp>();
if (alloc && alloc.isSharedMemoryAlloc()) {
Expand All @@ -191,15 +190,9 @@ class AllocationAnalysis {
auto bytes = product<int64_t>(shapePerCTA) *
allocType.getElementTypeBitWidth() / 8;

// XXX(Keren): magic numbers 256 and 1024
// benzh@maybe alignment should be passed in.
// Software swizzling calculates phase based on offset, while hardware
// swizzling do that based on physical address. Thus only by setting the
// alignment to 1024 can ensure the correctness. 
if (bytes > 256)
kAlignment = 1024;
auto alignment = alloc.getAlignmentOrDefault();
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes,
kAlignment);
alignment);
}
}
}
Expand Down Expand Up @@ -285,6 +278,12 @@ class AllocationAnalysis {
auto bytes = funcAlloc->getSharedMemorySize();
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
scratchAlignment);
} else if (auto createTensormap =
dyn_cast<ExperimentalTensormapCreateOp>(op)) {
constexpr int32_t kTMASize = 128;
constexpr int32_t kTMAAlign = 128;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
kTMAAlign);
}
}

Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
GenericOpPattern<triton::ExperimentalDescriptorLoadOp>,
GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
}
Expand Down
18 changes: 18 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,5 +1016,23 @@ void ExternElementwiseOp::getEffects(
SideEffects::DefaultResource::get());
}

// -- ExperimentalTensormapCreateOp --

/// Verifies that the variadic dimension operands agree with the descriptor
/// rank implied by `box_dim`:
///   - `global_dim` and `element_stride` must each have exactly `rank`
///     entries;
///   - `global_stride` must have `rank - 1` entries (innermost stride is
///     implicit).
/// Fix: the original messages ended in "Got" with no trailing space, so the
/// streamed count fused into the word (e.g. "Got3 but expected 4").
LogicalResult ExperimentalTensormapCreateOp::verify() {
  auto rank = getBoxDim().size();
  if (getGlobalDim().size() != rank) {
    return emitError("Rank mismatch for global dim. Got ")
           << getGlobalDim().size() << " but expected " << rank;
  }
  if (getGlobalStride().size() + 1 != rank) {
    return emitError("Rank mismatch for global stride. Got ")
           << getGlobalStride().size() << " but expected " << rank - 1;
  }
  if (getElementStride().size() != rank) {
    return emitError("Rank mismatch for element stride. Got ")
           << getElementStride().size() << " but expected " << rank;
  }
  return success();
}

} // namespace triton
} // namespace mlir
20 changes: 20 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3030,6 +3030,26 @@ LogicalResult MemDescSubviewOp::verify() {
return success();
}

// -- LocalAllocOp --

/// Returns the explicit `alignment` attribute when present; otherwise derives
/// a default from the allocation size in bytes.
int32_t LocalAllocOp::getAlignmentOrDefault() {
  // An explicitly requested alignment always wins.
  if (auto explicitAlignment = getAlignment())
    return *explicitAlignment;

  // Compute the allocation size in bytes from the per-CTA shape.
  auto memDescTy = getType();
  auto shapePerCTA = triton::gpu::getShapePerCTA(memDescTy);
  int64_t sizeInBytes =
      product<int64_t>(shapePerCTA) * (memDescTy.getElementTypeBitWidth() / 8);

  // XXX(Keren): magic numbers 256 and 1024
  // Software swizzling calculates phase based on offset, while hardware
  // swizzling does it based on physical address, so only an alignment of
  // 1024 guarantees correctness for larger (swizzled) allocations.
  return sizeInBytes > 256 ? 1024 : 8;
}

//===----------------------------------------------------------------------===//
// Layout debug printing
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 156f014

Please sign in to comment.