diff --git a/third_party/triton/temporary/fix_InsertInstructionSchedHints.patch b/third_party/triton/temporary/fix_InsertInstructionSchedHints.patch new file mode 100644 index 0000000000000..d3c09f081bdfa --- /dev/null +++ b/third_party/triton/temporary/fix_InsertInstructionSchedHints.patch @@ -0,0 +1,11 @@ +--- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td ++++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +@@ -59,7 +59,7 @@ + let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; + let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + +- let dependentDialects = ["mlir::LLVM::LLVMDialect"]; ++ let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::triton::amdgpu::TritonAMDGPUDialect"]; + } + + def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 274faf600e048..829a8e8adb3ce 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -17,5 +17,6 @@ temporary_patch_list = [ "//third_party/triton:temporary/fix_left_shift_overflow.patch", "//third_party/triton:temporary/prefetch.patch", "//third_party/triton:temporary/i4_to_bf16.patch", + "//third_party/triton:temporary/fix_InsertInstructionSchedHints.patch", # Add new patches just above this line ] diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index e794a1afd330f..c9e6d553cd5d9 100644 --- a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -91,7 +91,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass()); pm.addPass(mlir::createCanonicalizerPass()); } - //pm.addPass(mt::createInsertInstructionSchedHintsPass()); + pm.addPass(mt::createInsertInstructionSchedHintsPass()); pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());