
Commit

Merge commit 'a5b3783491a42a61ed5b8cb32a1178eb08e7b085'
whitneywhtsang committed Jun 22, 2024
2 parents 61042a1 + a5b3783 commit 8f1a253
Showing 44 changed files with 626 additions and 3,677 deletions.
9 changes: 1 addition & 8 deletions .github/workflows/build-test-reusable.yml
@@ -164,14 +164,7 @@ jobs:
           cd python/test/unit
           TRITON_INTERPRET=1 TRITON_TEST_SUITE=interpreter \
             pytest -vvv -n 16 -m interpreter language/test_core.py language/test_standard.py \
-              language/test_random.py operators/test_flash_attention.py::test_op --device cpu
-      - name: Run partial operators tests
-        run: |
-          source ./scripts/pytest-utils.sh
-          cd python/test/unit
-          TRITON_TEST_SUITE=operators \
-            pytest -vvv -n 8 --device xpu operators
+              language/test_random.py --device cpu
       - name: Regression tests
         run: |
3 changes: 1 addition & 2 deletions .github/workflows/integration-tests.yml
@@ -257,7 +257,6 @@ jobs:
           cd python/test/unit
           python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \
             language/test_random.py language/test_block_pointer.py language/test_subprocess.py \
-            operators/test_flash_attention.py::test_op \
             ../../tutorials/06-fused-attention.py::test_op --device cpu
       - name: Run C++ unittests
         run: |
@@ -384,7 +383,7 @@ jobs:
           cd python/test/unit
           ## test_subprocess.py is flaky on the AMD CI.
           ## TODO (lixun) find a solution and re-enable it.
-          pytest --capture=tee-sys -rfs -n 32 language operators \
+          pytest --capture=tee-sys -rfs -n 32 language \
             hopper/test_mixed_io.py \
             hopper/test_gemm.py \
             hopper/test_tma_store_gemm.py \
3 changes: 1 addition & 2 deletions .github/workflows/integration-tests.yml.in
@@ -293,7 +293,6 @@ jobs:
           cd python/test/unit
           python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \
             language/test_random.py language/test_block_pointer.py language/test_subprocess.py \
-            operators/test_flash_attention.py::test_op \
             ../../tutorials/06-fused-attention.py::test_op --device cpu

       - &run-cpp-unittests-step
@@ -388,7 +387,7 @@ jobs:
           cd python/test/unit
           ## test_subprocess.py is flaky on the AMD CI.
           ## TODO (lixun) find a solution and re-enable it.
-          pytest --capture=tee-sys -rfs -n 32 language operators \
+          pytest --capture=tee-sys -rfs -n 32 language \
             hopper/test_mixed_io.py \
             hopper/test_gemm.py \
             hopper/test_tma_store_gemm.py \
1 change: 1 addition & 0 deletions README.md
@@ -165,6 +165,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
 - `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
 - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
 - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
+- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.

# Usage Guide

107 changes: 107 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Schedule.h
@@ -0,0 +1,107 @@
#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
#include "llvm/ADT/ArrayRef.h"
#include <list>
#include <vector>

namespace mlir {
namespace triton {

/// This fills out the pipelining options, including the schedule and
/// annotations for wait ops. It also does pre-processing by converting some
/// of the loads into async loads so that the IR is ready to be pipelined.
bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
                                  mlir::triton::PipeliningOption &options);

/// Fills out pipelining options for an outer loop pipelining case. This
/// schedules async copies to overlap with the epilogue of a loop.
bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages,
                          mlir::triton::PipeliningOption &options);

/// Pipeline the TMA stores in the loop.
bool pipelineTMAStores(scf::ForOp forOp);

/// This does post-processing on the pipelined loop to try to pipeline wgmma
/// ops.
// TODO: this should be included as part of the pipeline but currently the wgmma
// wait modeling is problematic.
void asyncLaunchDots(scf::ForOp forOp);

/// Post process the pipelined loop by updating the wait ops with the right
/// number of groups in flight.
void updateWaits(ModuleOp module);

class CoarseSchedule {
public:
  class ClusterList {
    std::list<int> orderClusters;

  public:
    using iterator = decltype(orderClusters)::iterator;
    ClusterList() = default;
    iterator begin() { return orderClusters.begin(); }
    iterator end() { return orderClusters.end(); }
    size_t size() { return orderClusters.size(); }
    iterator newAtBack() {
      orderClusters.push_back(orderClusters.size());
      return std::prev(orderClusters.end());
    }
    iterator newAtFront() {
      orderClusters.push_front(-1);
      for (auto &clusterId : orderClusters) {
        clusterId++;
      }
      return orderClusters.begin();
    }
    iterator newBefore(iterator cluster) {
      auto ret = orderClusters.insert(cluster, *cluster);
      for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) {
        clusterId++;
      }
      return ret;
    }
  };

  CoarseSchedule(int numStages) : numStages(numStages) {}
  int numStages;
  ClusterList clusters;
  using Cluster = decltype(clusters)::iterator;

  DenseMap<Operation *, std::pair<int, Cluster>> opToStageAndCluster;

  void insert(Operation *op, int stage, Cluster cluster) {
    opToStageAndCluster[op] = {stage, cluster};
  }

  bool insertIfAbsent(Operation *op, int stage, Cluster cluster) {
    if (opToStageAndCluster.count(op))
      return false;
    insert(op, stage, cluster);
    return true;
  }

  void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
                      bool includeArg);

  void erase(Operation *op) { opToStageAndCluster.erase(op); }

  int count(Operation *op) { return opToStageAndCluster.count(op); }

  std::pair<int, Cluster> operator[](Operation *op) {
    return opToStageAndCluster[op];
  }

  SmallVector<std::tuple<Operation *, int, Cluster>>
  getOpsInOrder(scf::ForOp forOp);
  std::vector<std::pair<Operation *, unsigned>>
  createFinalSchedule(scf::ForOp forOp);
  void dump();
};

} // namespace triton
} // namespace mlir
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
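
For orientation, the new `CoarseSchedule` pairs each operation with a pipeline stage and an ordering cluster. Below is a minimal, hypothetical usage sketch; the function name `scheduleLoadsAndUses` and the two-stage load/compute split are illustrative assumptions, not part of this commit, and the real scheduling done in `preProcessLoopAndGetSchedule` is considerably more involved.

```cpp
// Hypothetical sketch only: assign the top-level ops of a loop to two stages
// and two clusters, then flatten into the order the pipeline expander expects.
#include "triton/Dialect/Triton/IR/Dialect.h"             // assumed, for triton::LoadOp
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" // the header added above

using namespace mlir;

static void scheduleLoadsAndUses(scf::ForOp forOp) {
  triton::CoarseSchedule schedule(/*numStages=*/2);

  // Clusters order ops that share a stage; loads go ahead of their uses.
  triton::CoarseSchedule::Cluster loads = schedule.clusters.newAtBack();
  triton::CoarseSchedule::Cluster compute = schedule.clusters.newAtBack();

  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (isa<triton::LoadOp>(op))
      schedule.insertIfAbsent(&op, /*stage=*/0, loads);
    else
      schedule.insertIfAbsent(&op, /*stage=*/1, compute);
  }

  // (op, stage) pairs in cluster order, ready for the pipeline expander.
  std::vector<std::pair<Operation *, unsigned>> finalSchedule =
      schedule.createFinalSchedule(forOp);
  (void)finalSchedule;
}
```

Note that `ClusterList::newAtFront` and `newBefore` renumber the existing clusters on insertion, so clusters can be created lazily while keeping a stable total order; `dump()` is handy when inspecting a schedule.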
52 changes: 0 additions & 52 deletions lib/Analysis/Utility.cpp
@@ -482,58 +482,6 @@ bool supportMFMA(triton::DotOp op) {
  return true;
}

static bool supportWMMAGranularity(int m, int n, int k) {
  return m % 16 == 0 && n % 16 == 0 && k % 16 == 0;
}

static bool supportWMMATypes(Type a, Type b, Type c, Type d) {
  if (a != b || c != d)
    return false;
  auto aWidth = a.getIntOrFloatBitWidth();
  auto cWidth = c.getIntOrFloatBitWidth();
  if (a.isIntOrIndex()) {
    if (!c.isIntOrIndex())
      return false;
    bool aValid = aWidth <= 8;
    bool cValid = cWidth <= 32;
    return aValid && cValid;
  } else if (isa<FloatType>(a) && isa<FloatType>(c)) {
    if (a.isBF16())
      return c.isBF16() || c.isF32();
    if (a.isF16())
      return c.isF16() || c.isF32();
    return aWidth <= cWidth && aWidth <= 16;
  }
  return false;
}

bool supportWMMA(triton::DotOp op) {
  auto aTy = cast<RankedTensorType>(op.getA().getType());
  auto bTy = cast<RankedTensorType>(op.getB().getType());
  auto cTy = cast<RankedTensorType>(op.getC().getType());
  auto dTy = cast<RankedTensorType>(op.getResult().getType());

  auto aElemTy = aTy.getElementType();
  auto bElemTy = bTy.getElementType();
  auto cElemTy = cTy.getElementType();
  auto dElemTy = dTy.getElementType();

  if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy))
    return false;

  auto aShape = aTy.getShape();
  auto bShape = bTy.getShape();

  auto rank = aShape.size();
  assert(bShape.size() == rank);
  assert(aShape[rank - 1] == bShape[rank - 2]);
  if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1],
                              aShape[rank - 1]))
    return false;

  return true;
}

bool supportMMA(triton::DotOp op, int version) {
  // Refer to mma section for the data type supported by Volta and Hopper
  // Tensor Core in
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -33,6 +33,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   for (int baseVersion : versionsSupported) {
     if (supportMMA(op, baseVersion))
       return baseVersion;
+    if (baseVersion == 3)
+      op.emitRemark() << "Warning: can't use MMA V3 for the dot op";
   }
   return 0;
 }
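
The new branch uses MLIR's standard diagnostic machinery (`emitRemark()`), and, per the README entry added above, such performance remarks are only surfaced when `MLIR_ENABLE_REMARK` is set. As a rough sketch of how remarks like this one can also be observed programmatically, here is a hypothetical handler; it is not part of this commit or of Triton's actual remark plumbing.

```cpp
// Hypothetical sketch: print only remark-severity diagnostics, such as the
// MMA V3 fallback message emitted above, and let default handling continue.
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

static void printPerformanceRemarks(mlir::MLIRContext &ctx) {
  ctx.getDiagEngine().registerHandler([](mlir::Diagnostic &diag) {
    if (diag.getSeverity() == mlir::DiagnosticSeverity::Remark)
      llvm::errs() << "[remark] " << diag.str() << "\n";
    return mlir::failure(); // not consumed: propagate to other handlers
  });
}
```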
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_triton_library(TritonGPUTransforms
   Pipeliner/SoftwarePipeliner.cpp
   Pipeliner/TMAStoresPipeline.cpp
   Pipeliner/PipeliningUtility.cpp
+  Pipeliner/Schedule.cpp
   Prefetch.cpp
   RemoveLayoutConversions.cpp
   ReorderInstructions.cpp

