diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 2f74eaf498abd3..15ba868a1b12c4 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -1654,6 +1654,7 @@ xla_test(
         "//xla/service:hlo_cost_analysis",
         "//xla/service:hlo_memory_scheduler",
         "//xla/service:hlo_rematerialization",
+        "//xla/service/gpu/transforms:stream_attribute_annotator",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index 1a17c3217f7ef3..2350d0d61510aa 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -2523,6 +2523,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
         main_pipeline.AddPass("remat-pipeline");
 
     pipeline.AddPass(remat_opts, sizes);
+    pipeline.AddPass<StreamAttributeAnnotator>();
     pipeline.AddPass();
   }
 
diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
index e40cde05596c72..3aa9d79977bd81 100644
--- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
@@ -30,6 +30,8 @@ limitations under the License.
 #include "xla/hlo/utils/hlo_matchers.h"
 #include "xla/layout.h"
 #include "xla/service/buffer_value.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_memory_scheduler.h"
 #include "xla/service/hlo_rematerialization.h"
@@ -216,6 +218,19 @@ TEST_F(GpuOffloadingTest, CopyIRCreationTest) {
       RunHloRematerialization(
           /*memory_limit_bytes=*/10 * 1024, module.get()));
   ASSERT_TRUE(changed);
+  StreamAttributeAnnotator attr_annotator;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed_attr, attr_annotator.Run(module.get()));
+  EXPECT_TRUE(changed_attr);
+  // Verify that every copy-start carries a positive operation queue id.
+  for (std::string i : {"", ".1", ".2", ".3"}) {
+    const HloInstruction* cp_start =
+        FindInstruction(module.get(), "copy-start" + i);
+    ASSERT_NE(cp_start, nullptr);
+    EXPECT_TRUE(cp_start->has_backend_config());
+    TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                            cp_start->backend_config<GpuBackendConfig>());
+    EXPECT_GT(gpu_config.operation_queue_id(), 0);
+  }
 
   // The module should still have a schedule.
   ASSERT_TRUE(module->has_schedule());
diff --git a/third_party/xla/xla/service/gpu/runtime/copy_thunk.cc b/third_party/xla/xla/service/gpu/runtime/copy_thunk.cc
index 093324fd5bde58..7bc2874706b865 100644
--- a/third_party/xla/xla/service/gpu/runtime/copy_thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime/copy_thunk.cc
@@ -130,8 +130,7 @@ absl::Status DeviceToHostCopyThunk::ExecuteOnStream(
     VLOG(2) << "Memcpy D2H from the main stream";
     return absl::OkStatus();
   }
-  VLOG(2) << absl::StreamFormat("Memcpy D2Hfrom the other stream %d",
-                                Thunk::execution_stream_id().value());
+  VLOG(2) << "Memcpy D2H from the other stream";
   se::StreamExecutor* executor = params.stream->parent();
   TF_ASSIGN_OR_RETURN(auto event, executor->CreateEvent());
   // Record memcpy operation completion.
@@ -169,8 +168,7 @@ absl::Status HostToDeviceCopyThunk::ExecuteOnStream(
     VLOG(2) << "Memcpy H2D from the main stream";
     return absl::OkStatus();
   }
-  VLOG(2) << absl::StreamFormat("Memcpy H2D from the other stream %d",
-                                Thunk::execution_stream_id().value());
+  VLOG(2) << "Memcpy H2D from the other stream";
   se::StreamExecutor* executor = params.stream->parent();
   TF_ASSIGN_OR_RETURN(auto event, executor->CreateEvent());
   // Record memcpy operation completion.
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
index a3f46b69291101..68805b1ddc3c0c 100644
--- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
@@ -89,6 +89,27 @@ absl::StatusOr AnnotateStreamAttributesForInstruction(
   return true;
 }
 
+// Annotates a copy-start instruction (the caller checks the opcode) so that
+// its operation queue id is `channel_id`.
+//
+// `instr_gpu_config` is modified in place and written back to `instr`.
+// Returns false, leaving `instr` untouched, if the operation queue id already
+// differs from Thunk::kDefaultExecutionStreamId; returns true once the new id
+// has been stored. A set_backend_config() failure is propagated.
+absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
+    HloInstruction* instr, int64_t channel_id,
+    GpuBackendConfig& instr_gpu_config) {
+  // Do nothing if copy-start has already been annotated.
+  if (instr_gpu_config.operation_queue_id() !=
+      Thunk::kDefaultExecutionStreamId.value()) {
+    return false;
+  }
+  instr_gpu_config.set_operation_queue_id(channel_id);
+  TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
+  VLOG(3) << "Add copy-start's backend config: " << channel_id;
+  return true;
+}
+
 absl::StatusOr WrapIntoFusionAndAnnotateStreamAttributes(
     HloInstruction* instruction, int64_t channel_id,
     GpuBackendConfig& instr_gpu_config) {
@@ -181,6 +202,12 @@ absl::StatusOr StreamAttributeAnnotator::Run(
                             AnnotateStreamAttributesForInstruction(
                                 instr, instr_gpu_config.value()));
         changed |= comp_result;
+      } else if (instr->opcode() == HloOpcode::kCopyStart) {
+        TF_ASSIGN_OR_RETURN(bool comp_result,
+                            AnnotateStreamAttributesForCopyStart(
+                                instr, channel_id, instr_gpu_config.value()));
+        changed |= comp_result;
+        continue;
       } else if (comp->IsAsyncComputation() &&
                  (instr->opcode() == HloOpcode::kDynamicSlice ||
                   instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc
index 48ca59818ac560..c7d2ca59cff0e9 100644
--- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc
@@ -166,6 +166,53 @@ TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) {
   EXPECT_EQ(gpu_config.operation_queue_id(), 1);
 }
 
+// Verifies that running the annotator assigns operation queue id 1 to each of
+// the four copy-start instructions in the module.
+TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule offloading
+    ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] {
+    %param_1 = f32[1024]{0} parameter(1)
+    %param_0 = f32[1024]{0} parameter(0)
+    %res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1)
+    %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3)
+    %res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3)
+    %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4)
+    %res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4)
+    %copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start)
+    %res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5)
+    %copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2)
+    %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2)
+    %res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6)
+    %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done)
+    %res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5)
+    %copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3)
+    %res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3)
+    %copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1)
+    %res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1)
+    ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  StreamAttributeAnnotator attr_annotator;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  for (std::string i : {"", ".1", ".2", ".3"}) {
+    const HloInstruction* cp_start =
+        FindInstruction(module.get(), "copy-start" + i);
+    ASSERT_NE(cp_start, nullptr);
+    EXPECT_TRUE(cp_start->has_backend_config());
+    TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                            cp_start->backend_config<GpuBackendConfig>());
+    EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+  }
+}
+
 TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) {
   constexpr absl::string_view kHloString = R"(
   HloModule ModuleWithAsyncDynamicUpdateSlice, is_scheduled=true