From 2bd6f7eb52229050ed1347889e0408bfe58542cd Mon Sep 17 00:00:00 2001 From: Zoran Jovanovic Date: Sun, 15 Dec 2024 11:30:46 +0000 Subject: [PATCH] [ROCm] Fixed an issue with InstructionSchedHintsPass --- .../temporary/fix_InsertInstructionSchedHints.patch | 11 +++++++++++ third_party/triton/temporary/series.bzl | 1 + .../gpu/fusions/triton/compilation_pipeline_rocm.cc | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 third_party/triton/temporary/fix_InsertInstructionSchedHints.patch diff --git a/third_party/triton/temporary/fix_InsertInstructionSchedHints.patch b/third_party/triton/temporary/fix_InsertInstructionSchedHints.patch new file mode 100644 index 0000000000000..d3c09f081bdfa --- /dev/null +++ b/third_party/triton/temporary/fix_InsertInstructionSchedHints.patch @@ -0,0 +1,11 @@ +--- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td ++++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +@@ -59,7 +59,7 @@ + let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; + let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + +- let dependentDialects = ["mlir::LLVM::LLVMDialect"]; ++ let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::triton::amdgpu::TritonAMDGPUDialect"]; + } + + def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 274faf600e048..829a8e8adb3ce 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -17,5 +17,6 @@ temporary_patch_list = [ "//third_party/triton:temporary/fix_left_shift_overflow.patch", "//third_party/triton:temporary/prefetch.patch", "//third_party/triton:temporary/i4_to_bf16.patch", + "//third_party/triton:temporary/fix_InsertInstructionSchedHints.patch", # Add new patches just above this line ] diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index e794a1afd330f..c9e6d553cd5d9 100644 --- a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -91,7 +91,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass()); pm.addPass(mlir::createCanonicalizerPass()); } - //pm.addPass(mt::createInsertInstructionSchedHintsPass()); + pm.addPass(mt::createInsertInstructionSchedHintsPass()); pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());