From 991dd48fe8e096963e535036cc5790752863e088 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 1 May 2024 14:44:50 +0000 Subject: [PATCH] Address code review comments Signed-off-by: Tiotto, Ettore --- .../Pipeliner/MatmulLoopPipeline.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp index c9cf27bf44..e56acf52c1 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp @@ -6,6 +6,7 @@ #include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; namespace tt = mlir::triton; @@ -14,10 +15,11 @@ namespace ttgi = mlir::triton::gpu::intel; namespace { +/// Represent a candidate load operation which is used by operations that +/// convert its layout to a 'dot' layout (e.g. triton_gpu.convert_layout). struct LoadDotOperand { LoadDotOperand(tt::LoadOp load, - ttg::DotOperandEncodingAttr dotOperandEncoding, - bool needTrans = false) + ttg::DotOperandEncodingAttr dotOperandEncoding) : load(load), dotOperandEncoding(dotOperandEncoding) {} tt::LoadOp load; ttg::DotOperandEncodingAttr dotOperandEncoding; @@ -41,7 +43,7 @@ static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val); -static ttg::DotOperandEncodingAttr getEncodingFromUser(Operation *user) { +static ttg::DotOperandEncodingAttr getDotEncodingFromUser(Operation *user) { if (user->getNumResults() != 1) return nullptr; @@ -71,7 +73,7 @@ static ttg::DotOperandEncodingAttr getEncodingFromUser(Operation *user) { static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val) { ttg::DotOperandEncodingAttr attr{nullptr}; for (Operation *user : val.getUsers()) { - ttg::DotOperandEncodingAttr dotAttr = getEncodingFromUser(user); + ttg::DotOperandEncodingAttr dotAttr = getDotEncodingFromUser(user); if (!dotAttr || (attr != nullptr && attr != dotAttr)) return nullptr; attr = dotAttr;