From 7a024da3da7f12eb98b8a8c13fb6b6f8d033cf93 Mon Sep 17 00:00:00 2001
From: Johannes Reifferscheid
Date: Mon, 21 Aug 2023 00:24:37 -0700
Subject: [PATCH] Wrap remaining unfused instructions in fusions before
 conversion to LHLO.

With this change, `HloToLhloModule` no longer creates additional fusions, so
we can easily produce a mapping from each MLIR operation to its corresponding
HLO instruction (and their types will match). It also lets us remove the
MLIR->HLO conversion step from `ir_emitter_unnested`, after which codegen will
again have access to valid (properly connected) HLO.

This should be an NFC.

PiperOrigin-RevId: 558699182
---
 tensorflow/compiler/xla/service/gpu/BUILD     |  28 +++-
 .../service/gpu/compile_module_to_llvm_ir.cc  |   7 +
 .../xla/service/gpu/fusion_wrapper.cc         | 147 +++++++++++++++++
 .../compiler/xla/service/gpu/fusion_wrapper.h |  42 +++++
 .../xla/service/gpu/fusion_wrapper_test.cc    | 151 ++++++++++++++++++
 .../xla/service/gpu/tests/copy_nested.hlo     |   4 +-
 .../tests/element_wise_row_vectorization.hlo  |   2 +-
 .../gpu/tests/gpu_kernel_tiling_test.cc       |  10 +-
 .../xla/service/gpu/tests/gpu_ldg_test.cc     |   2 +-
 .../service/gpu/tests/gpu_unrolling_test.cc   |   2 +-
 .../tests/reduction_vectorization_sm_all.hlo  |  32 ++--
 .../mhlo_to_lhlo_with_xla.cc                  | 121 +-------------
 .../mhlo_to_lhlo_with_xla.h                   |   7 -
 .../tests/hlo_text_to_lhlo_no_opt.hlotxt      | 123 +-------------
 .../tests/no_opt_ops.hlotxt                   |  33 +++-
 .../tests/non_identity_layouts.hlotxt         |   9 +-
 16 files changed, 441 insertions(+), 279 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/gpu/fusion_wrapper.cc
 create mode 100644 tensorflow/compiler/xla/service/gpu/fusion_wrapper.h
 create mode 100644 tensorflow/compiler/xla/service/gpu/fusion_wrapper_test.cc

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 14504d38e6fd8d..5ba8d085c79291 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -2311,12 +2311,12 @@ cc_library(
     ],
     deps = [
         ":executable_proto_cc",
+        ":fusion_wrapper",
         ":gpu_constants",
         ":gpu_convert_async_collectives_to_sync",
         ":gpu_device_info",
         ":gpu_executable",
         ":gpu_hlo_schedule",
-        ":ir_emitter",
         ":ir_emitter_context",
         ":ir_emitter_unnested",
         ":metrics",
@@ -4180,6 +4180,32 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "fusion_wrapper",
+    srcs = ["fusion_wrapper.cc"],
+    hdrs = ["fusion_wrapper.h"],
+    deps = [
+        ":gpu_fusible",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/hlo/ir:hlo",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/tsl/platform:errors",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+xla_cc_test(
+    name = "fusion_wrapper_test",
+    srcs = ["fusion_wrapper_test.cc"],
+    deps = [
+        ":fusion_wrapper",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 xla_cc_test(
     name = "copy_fusion_test",
     srcs = ["copy_fusion_test.cc"],
diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc
index 7d8f7e10ee0413..598c7c69d0f454 100644
--- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc
+++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc
@@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/fusion_wrapper.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_convert_async_collectives_to_sync.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" @@ -386,6 +387,12 @@ Status CompileModuleToLlvmIrImpl( } } + HloPassPipeline pipeline("fusion-wrapper"); + // Wrap remaining unfused ops that have no LHLO equivalent in single-op + // fusions. This needs to happen after rematerialization, because it will + // insert additional copies. + TF_RETURN_IF_ERROR(FusionWrapper().Run(hlo_module).status()); + auto buffer_size_bytes_function = [pointer_size](const BufferValue& buffer_value) -> int64_t { return GetSizeOfShape(buffer_value.shape(), pointer_size); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_wrapper.cc b/tensorflow/compiler/xla/service/gpu/fusion_wrapper.cc new file mode 100644 index 00000000000000..5e1bbe08105779 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusion_wrapper.cc @@ -0,0 +1,147 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/gpu/fusion_wrapper.h"
+
+#include <functional>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
+#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/tsl/platform/errors.h"
+
+namespace xla {
+namespace gpu {
+
+StatusOr<bool> FusionWrapper::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  auto instructions = module->entry_computation()->MakeInstructionPostOrder();
+  bool changed = false;
+
+  std::function<Status(HloInstruction*)> handle_instruction;
+  handle_instruction = [&](HloInstruction* instruction) -> Status {
+    switch (instruction->opcode()) {
+      case HloOpcode::kConditional:
+      case HloOpcode::kWhile:
+        for (auto* computation : instruction->called_computations()) {
+          for (auto* inner_instruction :
+               computation->MakeInstructionPostOrder()) {
+            TF_RETURN_IF_ERROR(handle_instruction(inner_instruction));
+          }
+        }
+        break;
+      case HloOpcode::kAbs:
+      case HloOpcode::kAdd:
+      case HloOpcode::kAnd:
+      case HloOpcode::kAtan2:
+      case HloOpcode::kBitcastConvert:
+      case HloOpcode::kBroadcast:
+      case HloOpcode::kCeil:
+      case HloOpcode::kCbrt:
+      case HloOpcode::kClamp:
+      case HloOpcode::kClz:
+      case HloOpcode::kCompare:
+      case HloOpcode::kComplex:
+      case HloOpcode::kConcatenate:
+      case HloOpcode::kConvert:
+      case HloOpcode::kCos:
+      case HloOpcode::kDivide:
+      case HloOpcode::kDot:
+      case HloOpcode::kDynamicSlice:
+      case HloOpcode::kDynamicUpdateSlice:
+      case HloOpcode::kExp:
+      case HloOpcode::kExpm1:
+      case HloOpcode::kFloor:
+      case HloOpcode::kGather:
+      case HloOpcode::kImag:
+      case HloOpcode::kIota:
+      case HloOpcode::kIsFinite:
+      case HloOpcode::kLog:
+      case HloOpcode::kLog1p:
+      case HloOpcode::kMap:
+      case HloOpcode::kMaximum:
+      case HloOpcode::kMinimum:
+      case HloOpcode::kMultiply:
+      case HloOpcode::kNegate:
+      case HloOpcode::kNot:
+      case HloOpcode::kOr:
+      case HloOpcode::kPad:
+      case HloOpcode::kPopulationCount:
+      case HloOpcode::kPower:
+      case HloOpcode::kReal:
+      case HloOpcode::kReshape:
+      case HloOpcode::kReducePrecision:
+      case HloOpcode::kReduceWindow:
+      case HloOpcode::kRemainder:
+      case HloOpcode::kReverse:
+      case HloOpcode::kRoundNearestAfz:
+      case HloOpcode::kRoundNearestEven:
+      case HloOpcode::kRsqrt:
+      case HloOpcode::kSelect:
+      case HloOpcode::kShiftLeft:
+      case HloOpcode::kShiftRightLogical:
+      case HloOpcode::kShiftRightArithmetic:
+      case HloOpcode::kSign:
+      case HloOpcode::kSin:
+      case HloOpcode::kSlice:
+      case HloOpcode::kSqrt:
+      case HloOpcode::kSubtract:
+      case HloOpcode::kStochasticConvert:
+      case HloOpcode::kTan:
+      case HloOpcode::kTanh:
+      case HloOpcode::kTranspose:
+      case HloOpcode::kXor:
+      case HloOpcode::kCopy:
+      case HloOpcode::kReduce: {
+        auto* computation = instruction->parent();
+        auto* fusion_instruction =
+            computation->AddInstruction(HloInstruction::CreateFusion(
+                instruction->shape(),
+                ChooseFusionKind(*instruction /*unused but required*/,
+                                 *instruction),
+                instruction));
+        instruction->GetModule()->SetAndUniquifyInstrName(
+            fusion_instruction, absl::StrCat("wrapped_", instruction->name()));
+        if (module->has_schedule()) {
+          module->schedule().replace_instruction(computation, instruction,
+                                                 fusion_instruction);
+        }
+        TF_RETURN_IF_ERROR(
+            fusion_instruction->CopyAllControlDepsFrom(instruction));
+        TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
+        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+        changed = true;
+        break;
+      }
+      default:
+        break;
+    }
+    return OkStatus();
+  };
+
+  for (auto* instruction : instructions) {
+    TF_RETURN_IF_ERROR(handle_instruction(instruction));
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_wrapper.h b/tensorflow/compiler/xla/service/gpu/fusion_wrapper.h
new file mode 100644
index 00000000000000..95f28e4ae3e4d9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/fusion_wrapper.h
@@ -0,0 +1,42 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_WRAPPER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_WRAPPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+// Wraps leftover unfused instructions in the entry computation that have no
+// LHLO equivalent in fusions containing just that instruction.
+class FusionWrapper : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "fusion-wrapper"; }
+
+  using HloPassInterface::Run;
+  StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_WRAPPER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_wrapper_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_wrapper_test.cc
new file mode 100644
index 00000000000000..4a32e0049e0f18
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/fusion_wrapper_test.cc
@@ -0,0 +1,151 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/gpu/fusion_wrapper.h"
+
+#include <optional>
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class FusionWrapperTest : public HloTestBase {};
+
+TEST_F(FusionWrapperTest, SimpleOp) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule TestModule
+
+      ENTRY TestComputation {
+        p0 = f16[30,41] parameter(0)
+        p1 = f16[30,41] parameter(1)
+        ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
+      })",
+                            FusionWrapper(), R"(
+// CHECK: %fused_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] {
+// CHECK: %param_0 = f16[30,41]{1,0} parameter(0)
+// CHECK: %param_1 = f16[30,41]{1,0} parameter(1)
+// CHECK: ROOT %result.1 = f16[60,41]{1,0} concatenate(%param_0, %param_1), dimensions={0}
+// CHECK: }
+
+// CHECK: ENTRY %TestComputation (p0: f16[30,41], p1: f16[30,41]) -> f16[60,41] {
+// CHECK: %p0 = f16[30,41]{1,0} parameter(0)
+// CHECK: %p1 = f16[30,41]{1,0} parameter(1)
+// CHECK: ROOT %wrapped_result = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%fused_computation
+// CHECK: })");
+}
+
+TEST_F(FusionWrapperTest, ControlDependency) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule TestModule
+
+      fusion {
+        ROOT param = f32[] parameter(0)
+      }
+
+      ENTRY main {
+        param = f32[] parameter(0)
+        fusion = f32[] fusion(param), kind=kLoop, calls=fusion
+        constant_one = f32[] constant(1)
+        ROOT add = f32[] add(param, constant_one), control-predecessors={fusion}
+      })",
+                            FusionWrapper(), R"(
+// CHECK: ROOT %wrapped_add = f32[] fusion(%param.1, %constant_one),
+// CHECK-SAME: control-predecessors={%fusion})");
+}
+
+TEST_F(FusionWrapperTest, While) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule While
+
+      %body {
+        %parameter.5 = (f32[5]{0}) parameter(0)
+        %constant_8 = f32[] constant(0)
+        %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
+        ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
+      }
+
+      %cond {
+        %parameter.12 = (f32[5]{0}) parameter(0)
+        ROOT %constant_1 = pred[] constant(false)
+      }
+
+      ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
+        %parameter.1 = f32[5]{0} parameter(0)
+        %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
+        %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
+        ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
+      })",
+                            FusionWrapper(), R"(
+// CHECK: %fused_computation.1 {{.*}} {
+// CHECK: %param_0.1 = f32[] parameter(0)
+// CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={}
+// CHECK: }
+// CHECK: %body {{.*}} {
+// CHECK: %parameter.5 = (f32[5]{0}) parameter(0)
+// CHECK: %constant_8 = f32[] constant(0)
+// CHECK: %wrapped_broadcast.9 = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%fused_computation.1
+// CHECK: ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast.9)
+// CHECK: }
+// CHECK: %cond {{.*}} {
+// CHECK: %parameter.12 = (f32[5]{0}) parameter(0)
+// CHECK: ROOT %constant_1 = pred[] constant(false)
+// CHECK: }
+// CHECK: %fused_computation {{.*}} {
+// CHECK: %param_0 = f32[5]{0} parameter(0)
+// CHECK: ROOT %copy.0 = f32[5]{0} copy(%param_0)
+// CHECK: }
+// CHECK: ENTRY %main {{.*}} {
+// CHECK: %parameter.1 = f32[5]{0} parameter(0)
+// CHECK: %wrapped_copy.3 = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%fused_computation
+// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy.3)
+// CHECK: ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body
+// CHECK: })");
+}
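+
+// The wrapper only wraps unfused instructions in the entry computation (and
+// in control-flow computations reachable from it): a while loop that already
+// lives inside a fusion is expected to be left unchanged.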
+TEST_F(FusionWrapperTest, WhileInFusion) { + RunAndFilecheckHloRewrite(R"( + HloModule While + + %body { + %parameter.5 = (f32[5]{0}) parameter(0) + %constant_8 = f32[] constant(0) + %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={} + ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9) + } + + %cond { + %parameter.12 = (f32[5]{0}) parameter(0) + ROOT %constant_1 = pred[] constant(false) + } + + %fusion { + %parameter.1 = f32[5]{0} parameter(0) + %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1) + %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3) + ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body + } + + ENTRY %main (parameter.1: f32[5]) -> (f32[5]) { + %parameter.1 = f32[5]{0} parameter(0) + ROOT %fusion = (f32[5]{0}) fusion(f32[5]{0} %parameter.1), kind=kLoop, calls=%fusion + })", + FusionWrapper(), + // No change + std::nullopt); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/copy_nested.hlo b/tensorflow/compiler/xla/service/gpu/tests/copy_nested.hlo index de201e7dc43764..f1753617439674 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/copy_nested.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/copy_nested.hlo @@ -49,11 +49,11 @@ // CHECK: %[[VAL_39:.*]] = udiv i32 %[[VAL_34]], 30000 // CHECK: %[[VAL_40:.*]] = icmp ult i32 %[[VAL_16]], 6000000 // CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_2]] -// CHECK: b.in_bounds-after: ; preds = %[[VAL_41]], %[[VAL_7]] +// CHECK: wrapped_b.in_bounds-after: // CHECK: br label %[[VAL_1]], !llvm.loop // CHECK: loop.loop_exit: ; preds = %[[VAL_1]] // CHECK: ret void -// CHECK: b.in_bounds-true: ; preds = %[[VAL_7]] +// CHECK: wrapped_b.in_bounds-true: // CHECK: %[[VAL_42:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_43:.*]], i32 0, i32 %[[VAL_20]], i32 %[[VAL_21]], i32 %[[VAL_18]] // CHECK: %[[VAL_44:.*]] = load float, ptr %[[VAL_42]], align 4, !invariant.load // CHECK: %[[VAL_45:.*]] = getelementptr inbounds float, ptr %[[VAL_46:.*]], i32 %[[VAL_16]] diff --git a/tensorflow/compiler/xla/service/gpu/tests/element_wise_row_vectorization.hlo b/tensorflow/compiler/xla/service/gpu/tests/element_wise_row_vectorization.hlo index f83227bb551c77..26ff69bde73d3e 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/element_wise_row_vectorization.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/element_wise_row_vectorization.hlo @@ -284,7 +284,7 @@ ENTRY main { ROOT %broadcastRowToLong = f32[3025,2025]{1,0} broadcast(%param_0), dimensions={1} } // Check that we didn't emit the simpler row broadcasting. 
-// CHECK-LLVM-LABEL: @broadcastRowToLong +// CHECK-LLVM-LABEL: @wrapped_broadcastRowToLong // CHECK-LLVM-NOT: row_index // ----- diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 03c8074ec8cd8f..d2c9536010351c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -448,7 +448,7 @@ TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce ; CHECK: call SHUFFLE ; CHECK: } )"; @@ -482,7 +482,7 @@ TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 15 ; CHECK: call SHUFFLE @@ -521,7 +521,7 @@ TEST_F(GpuKernelTilingTest, RowReductionFourRowsPerWarp) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 7 ; CHECK: call SHUFFLE @@ -561,7 +561,7 @@ TEST_F(GpuKernelTilingTest, ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); const char *expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce ; CHECK: store float %{{.*}}, ptr addrspace(1) ; CHECK: } )"; @@ -682,7 +682,7 @@ TEST_F(GpuKernelTilingTest, ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce ; CHECK-NOT: call SHUFFLE ; CHECK: } )"; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 91ba7705802222..138e77452b8933 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -145,7 +145,7 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { hlo_module->AddEntryComputation(std::move(computation)); CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"( - CHECK-LABEL: .entry add + CHECK-LABEL: .entry wrapped_add CHECK: { CHECK-NOT: ld.global.nc.f32 CHECK: ld.global.f32 diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 5415c139fbea9d..9b9c72279b17fa 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -90,7 +90,7 @@ TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { CompileAndVerifyIr(std::move(hlo_module), R"( -; CHECK-LABEL: @add +; CHECK-LABEL: @wrapped_add ; CHECK: load float ; CHECK: load float ; CHECK: fadd diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo 
b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo index f84edff7098846..a1c5a315b15de6 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo @@ -4,7 +4,7 @@ // RUN: hlo_to_llvm_ir --ptx --sm=70 %s | FileCheck %s --check-prefix=CHECK-SM70 // RUN: hlo_to_llvm_ir --ptx --sm=86 %s | FileCheck %s --check-prefix=CHECK-SM86 -// CHECK-LABEL: .entry reduce_odd_row +// CHECK-LABEL: .entry wrapped_reduce_odd_row // CHECK-NOT: ld.global.nc.v2.f32 // CHECK-NOT: ld.global.nc.v4.f32 // CHECK-NOT: ld.global.nc.u64 @@ -26,7 +26,7 @@ ENTRY %main { // ----- -// CHECK-SM86-LABEL: .entry reduce_small_row +// CHECK-SM86-LABEL: .entry wrapped_reduce_small_row // CHECK-SM86: .reqntid 96, 1, 1 HloModule ReduceSmallRow @@ -45,7 +45,7 @@ ENTRY main { // ----- -// CHECK-LABEL: .entry reduce_sine +// CHECK-LABEL: .entry wrapped_reduce_sine // CHECK-COUNT-7: ld.global.nc.v2.f32 HloModule DisableSin @@ -67,16 +67,16 @@ ENTRY %main { // SM dependent tests -// CHECK-SM50-LABEL: .entry reduce_exp +// CHECK-SM50-LABEL: .entry wrapped_reduce_exp // CHECK-SM50-NOT: ld.global.nc.v2.f32 // CHECK-SM50-COUNT-8: ld.global.nc.f32 -// CHECK-SM60: .entry exp -// CHECK-SM60-LABEL: .entry reduce_exp +// CHECK-SM60: .entry wrapped_exp +// CHECK-SM60-LABEL: .entry wrapped_reduce_exp // CHECK-SM60-COUNT-8: ld.global.nc.v2.f32 -// CHECK-SM70: .entry exp -// CHECK-SM70-LABEL: .entry reduce_exp +// CHECK-SM70: .entry wrapped_exp +// CHECK-SM70-LABEL: .entry wrapped_reduce_exp // CHECK-SM70-COUNT-8: ld.global.nc.v2.f32 HloModule Exp @@ -98,14 +98,14 @@ ENTRY %main { HloModule ReduceTileFit -// CHECK-SM50-LABEL: .entry reduce_tile_fit +// CHECK-SM50-LABEL: .entry wrapped_reduce_tile_fit // CHECK-SM50-NOT: ld.global.nc.v2.f32 // CHECK-SM50-COUNT-8: ld.global.nc.f32 -// CHECK-SM60-LABEL: .entry reduce_tile_fit +// CHECK-SM60-LABEL: .entry wrapped_reduce_tile_fit // CHECK-SM60-COUNT-8: ld.global.nc.v2.f32 -// CHECK-SM70-LABEL: .entry reduce_tile_fit +// CHECK-SM70-LABEL: .entry wrapped_reduce_tile_fit // CHECK-SM70-COUNT-4: ld.global.nc.v2.f32 %max_ { @@ -124,14 +124,14 @@ ENTRY %main { HloModule ReducePower2 -// CHECK-SM50-LABEL: .entry reduce_pow_2 +// CHECK-SM50-LABEL: .entry wrapped_reduce_pow_2 // CHECK-SM50-NOT: ld.global.nc.v2.f32 // CHECK-SM50-COUNT-8: ld.global.nc.f32 -// CHECK-SM60-LABEL: .entry reduce_pow_2 +// CHECK-SM60-LABEL: .entry wrapped_reduce_pow_2 // CHECK-SM60-COUNT-4: ld.global.nc.v2.f32 -// CHECK-SM70-LABEL: .entry reduce_pow_2 +// CHECK-SM70-LABEL: .entry wrapped_reduce_pow_2 // CHECK-SM70-COUNT-4: ld.global.nc.v2.f32 %max_ { @@ -150,11 +150,11 @@ ENTRY %main { HloModule ReduceEvenColumns -// CHECK-SM60-LABEL: .entry reduce_even_col +// CHECK-SM60-LABEL: .entry wrapped_reduce_even_col // CHECK-SM60-NOT: ld.global.nc.f32 // CHECK-SM60-COUNT-8: ld.global.nc.f32 -// CHECK-SM70-LABEL: .entry reduce_even_col +// CHECK-SM70-LABEL: .entry wrapped_reduce_even_col // CHECK-SM70-COUNT-2: ld.global.nc.v2.f32 // CHECK-SM70-COUNT-2: ld.global.nc.f32 diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index a8842fc6688fe7..91b2e4a13a892e 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -53,7 +53,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/mlir/utils/error_util.h" -#include "tensorflow/compiler/xla/mlir_hlo/_virtual_includes/lhlo_gpu/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -201,59 +200,6 @@ tsl::StatusOr LhloDialectEmitter::CreateOpWithoutAttrs( return CreateOpWithoutAttrs(instr, operands); } -tsl::StatusOr LhloDialectEmitter::CreateOpInFusion( - const HloInstruction* instr, ValueRange buffer_operands, - size_t num_arguments, size_t num_results) { - Location loc = getLocation(instr); - std::vector buffers(buffer_operands.begin(), buffer_operands.end()); - absl::Span arguments = - absl::MakeSpan(buffers).subspan(0, num_arguments); - absl::Span results = - absl::MakeSpan(buffers).subspan(num_arguments, num_results); - - mlir::lmhlo::FusionOp fusion = builder_.create(loc); - mlir::OpBuilder b(&fusion.getRegion()); - - llvm::SmallVector loads; - for (Value arg : arguments) { - auto load = b.create(loc, arg); - Shape shape = xla::TypeToShape(arg.getType()); - TF_RET_CHECK(shape.IsArray()); - if (shape.layout() != - xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) { - load->setAttr("xla_shape", - b.getStringAttr(shape.ToString(/*print_layout=*/true))); - } - loads.push_back(load); - } - TF_ASSIGN_OR_RETURN(mlir::Operation * op, - xla::HloFunctionImporter::ImportInstruction( - instr, loads, symbol_table_, &b, - xla::DynamicShapeHandlingMode::kConvertToStatic)); - if (llvm::isa(op)) { - auto underlyingOp = op->getPrevNode(); - op->erase(); - op = underlyingOp; - } - TF_RET_CHECK(op->getNumResults() == num_results); - for (int i = 0; i < results.size(); i++) { - b.create(loc, op->getResult(i), results[i]); - } - return op; -} - -tsl::StatusOr LhloDialectEmitter::CreateOpInFusion( - const HloInstruction* instr) { - llvm::SmallVector operands; - size_t num_arguments, num_results; - TF_RETURN_IF_ERROR(CreateOperands(instr, std::nullopt, - TokenLoweringMode::kFailToLower, operands, - num_arguments, num_results)); - TF_ASSIGN_OR_RETURN( - auto op, CreateOpInFusion(instr, operands, num_arguments, num_results)); - return op->getParentOp(); -} - tsl::StatusOr LhloDialectEmitter::EmitOp( const HloInstruction* instr) { using xla::HloOpcode; @@ -324,73 +270,10 @@ tsl::StatusOr LhloDialectEmitter::EmitOp( return EmitRecvOp(instr); case HloOpcode::kRecvDone: return EmitRecvDoneOp(instr); - - case HloOpcode::kAbs: - case HloOpcode::kAdd: - case HloOpcode::kAnd: - case HloOpcode::kAtan2: - case HloOpcode::kBitcastConvert: - case HloOpcode::kBroadcast: - case HloOpcode::kCeil: - case HloOpcode::kCbrt: - case HloOpcode::kClamp: - case HloOpcode::kClz: - case HloOpcode::kCompare: - case HloOpcode::kComplex: - case HloOpcode::kConcatenate: - case HloOpcode::kConvert: - case HloOpcode::kCos: - case HloOpcode::kDivide: - case HloOpcode::kDot: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kExp: - case HloOpcode::kExpm1: - case HloOpcode::kFloor: - case HloOpcode::kGather: - case HloOpcode::kImag: - case HloOpcode::kIota: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kLog1p: - case HloOpcode::kMap: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNegate: - case HloOpcode::kNot: - case 
-    case HloOpcode::kPad:
-    case HloOpcode::kPopulationCount:
-    case HloOpcode::kPower:
-    case HloOpcode::kReal:
-    case HloOpcode::kReshape:
-    case HloOpcode::kReducePrecision:
-    case HloOpcode::kReduceWindow:
-    case HloOpcode::kRemainder:
-    case HloOpcode::kReverse:
-    case HloOpcode::kRoundNearestAfz:
-    case HloOpcode::kRoundNearestEven:
-    case HloOpcode::kRsqrt:
-    case HloOpcode::kSelect:
-    case HloOpcode::kShiftLeft:
-    case HloOpcode::kShiftRightLogical:
-    case HloOpcode::kShiftRightArithmetic:
-    case HloOpcode::kSign:
-    case HloOpcode::kSin:
-    case HloOpcode::kSlice:
-    case HloOpcode::kSqrt:
-    case HloOpcode::kSubtract:
-    case HloOpcode::kStochasticConvert:
-    case HloOpcode::kTan:
-    case HloOpcode::kTanh:
-    case HloOpcode::kTranspose:
-    case HloOpcode::kXor:
-    case HloOpcode::kCopy:
-    case HloOpcode::kReduce:
-      return CreateOpInFusion(instr);
     default:
       llvm::errs() << instr->ToString();
+      llvm::errs() << "\n\nModule:\n"
+                   << instr->GetModule()->ToString() << "\n\n";
       return tsl::errors::Internal(
           absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
                        " is not supported."));
diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h
index 070a3d4f2add76..29ea27340edec6 100644
--- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h
@@ -180,13 +180,6 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
   OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
                               ValueRange operands);
 
-  tsl::StatusOr<mlir::Operation*> CreateOpInFusion(
-      const xla::HloInstruction* instr, ValueRange buffer_operands,
-      size_t num_arguments, size_t num_results);
-
-  tsl::StatusOr<mlir::Operation*> CreateOpInFusion(
-      const xla::HloInstruction* instr);
-
   template <typename T>
   DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
     return builder_.getI64TensorAttr(
diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt
index 1397cb2aba330f..13f66f42284311 100644
--- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt
@@ -385,107 +385,6 @@ ENTRY main {
 
 // -----
 
-HloModule Test
-
-// CHECK: func @main
-// CHECK: "lmhlo.fusion"()
-// CHECK: "mhlo.dot_general"(%{{.*}}, %{{.*}}) {
-// CHECK-SAME: dot_dimension_numbers =
-// CHECK-SAME: lhs_batching_dimensions = [0]
-// CHECK-SAME: rhs_batching_dimensions = [0]
-// CHECK-SAME: lhs_contracting_dimensions = [2]
-// CHECK-SAME: rhs_contracting_dimensions = [1]
-// CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
-// CHECK-SAME: : (tensor<1x3x4xf32>, tensor<1x4x5xf32>) -> tensor<1x3x5xf32>
-ENTRY main {
-  %arg0 = f32[1,3,4]{2,1,0} parameter(0)
-  %arg1 = f32[1,4,5]{2,1,0} parameter(1)
-  ROOT %out = f32[1,3,5]{2,1,0} dot(%arg0, %arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-}
-
-// -----
-
-HloModule Test
-
-// CHECK: func @main
-// CHECK: "lmhlo.fusion"()
-// CHECK: mhlo.reshape %{{.*}} : (tensor<2xf32>) -> tensor<1x2xf32>
-ENTRY main {
-  %arg0 = f32[2]{0} parameter(0)
-  ROOT %out = f32[1,2]{1,0} reshape(%arg0)
-}
-
-// -----
-
-HloModule Test
-
-max {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %c = f32[] maximum(%a, %b)
-}
-
-// CHECK: func @main
-// CHECK: "lmhlo.fusion"()
-// CHECK: "mhlo.reduce_window"(%{{.*}}, %{{.*}}) ({
-// CHECK: ^bb0(%[[ARG6:.*]]: tensor<f32>, %[[ARG7:.*]]: tensor<f32>):
-// CHECK: %[[RET:.*]] = mhlo.maximum %[[ARG6]], %[[ARG7]] : tensor<f32>
-// CHECK: mhlo.return %[[RET]] : tensor<f32>
-// CHECK: }) {
-// CHECK-SAME: padding = dense<{{\[}}[0, 0], [2, 0], [0, 2], [0, 0]{{\]}}> : tensor<4x2xi64>,
-// CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>,
-// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
-// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}
-// CHECK-SAME: : (tensor<2x17x31x7xf32>, tensor<f32>) -> tensor<2x5x8x7xf32>
-ENTRY main {
-  %arg0 = f32[2,17,31,7] parameter(0)
-  %c = f32[] constant(0)
-  ROOT %out = reduce-window(%arg0, %c), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 lhs_dilate=1x1x1x1 rhs_dilate=1x2x2x1}, to_apply=max
-}
-
-// -----
-
-HloModule Test
-
-// CHECK: func @main
-// CHECK: "lmhlo.fusion"()
-// CHECK: "mhlo.pad"(%{{.*}}, %{{.*}}) {
-// CHECK-SAME: edge_padding_high = dense<[4, 5]> : tensor<2xi64>,
-// CHECK-SAME: edge_padding_low = dense<[2, 3]> : tensor<2xi64>,
-// CHECK-SAME: interior_padding = dense<1> : tensor<2xi64>}
-// CHECK-SAME: : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
-ENTRY main {
-  %arg0 = f32[4,6] parameter(0)
-  %arg1 = f32[] parameter(1)
-  %out = f32[13,19] pad(%arg0, %arg1), padding=2_4_1x3_5_1
-}
-
-// -----
-
-HloModule Test
-
-// CHECK: func @main
-// CHECK: "lmhlo.fusion"()
-// CHECK: "mhlo.transpose"(%{{.*}}) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xf32>) -> tensor<2x1x4x3xf32>
-ENTRY main {
-  %arg0 = f32[1,2,3,4] parameter(0)
-  %out = f32[2,1,4,3] transpose(%arg0), dimensions={1,0,3,2}
-}
-
-// -----
-
-HloModule Test
-
-// CHECK: func @main
-// CHECK: "lmhlo.fusion"()
-// CHECK: "mhlo.broadcast_in_dim"(%{{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<10xf32>
-ENTRY main {
-  %arg0 = f32[1] parameter(0)
-  %out = f32[10] broadcast(%arg0), dimensions={0}
-}
-
-// -----
-
 HloModule TestModule
 
 // CHECK: func @main
@@ -634,10 +533,7 @@ ENTRY main {
 HloModule WhileConstantCondition
 
 %body {
-  %parameter.5 = (f32[5]{0}) parameter(0)
-  %constant_8 = f32[] constant(0)
-  %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
-  ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
+  ROOT %parameter.5 = (f32[5]{0}) parameter(0)
 }
 
 %cond {
@@ -647,27 +543,12 @@ HloModule WhileConstantCondition
 
 ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
   %parameter.1 = f32[5]{0} parameter(0)
-  %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
-  %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
+  %tuple = (f32[5]{0}) tuple(f32[5]{0} %parameter.1)
   ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
 }
 
 // -----
 
-// CHECK: func @main
-// CHECK: mhlo.copy
-// CHECK-SAME: tensor<2xi32>
-HloModule CopyTest
-
-ENTRY main {
-  %parameter.1 = s32[2]{0} parameter(0)
-  %parameter.2 = s32[] parameter(1)
-  %custom-call = s32[<=2]{0} custom-call(s32[2]{0} %parameter.1, s32[] %parameter.2), custom_call_target="SliceToDynamic"
-  ROOT %copy = s32[<=2]{0} copy(s32[<=2]{0} %custom-call)
-}
-
-// -----
-
 HloModule CustomCallNoComputation
 
 // CHECK: "lmhlo.custom_call"
diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt
index bdea76d3ed7d26..0c7fc220f73afc 100644
--- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt
+++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt
@@ -7,16 +7,31 @@ HloModule indexed_conditional
     ROOT %negate = f32[] negate(f32[] %x)
 }
 
+%NegateCond (x: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  ROOT %negate = f32[] fusion(f32[] %x), kind=kLoop, calls=%Negate
+}
+
 %Identity (y: f32[]) -> f32[] {
   %y = f32[] parameter(0)
   ROOT %copy = f32[] copy(f32[] %y)
 }
 
+%IdentityCond (y: f32[]) -> f32[] {
+  %y = f32[] parameter(0)
+  ROOT %copy = f32[] fusion(f32[] %y), kind=kLoop, calls=%Identity
+}
+
 %Floor (z: f32[]) -> f32[] {
   %z = f32[] parameter(0)
   ROOT %floor = f32[] floor(f32[] %z)
 }
 
+%FloorCond (z: f32[]) -> f32[] {
+  %z = f32[] parameter(0)
+  ROOT %floor = f32[] fusion(f32[] %z), kind=kLoop, calls=%Floor
+}
+
 // CHECK: %{{.*}} = memref.view
 // CHECK: "lmhlo.case"(%{{.*}}) ({
 // CHECK: mhlo.negate
@@ -34,7 +49,7 @@ ENTRY %Parameters1.v4 () -> (f32[]) {
   %constant.1 = f32[] parameter(1)
   %constant.2 = f32[] parameter(2)
   %constant.3 = f32[] parameter(3)
-  %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
+  %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%NegateCond, %IdentityCond, %FloorCond}
   ROOT %t = (f32[]) tuple(%conditional)
 }
 
@@ -42,16 +57,28 @@
 
 HloModule WhileWithScalarS32Result_module
 
+%Add (a: s32[], b: s32[]) -> s32[] {
+  %a = s32[] parameter(0)
+  %b = s32[] parameter(1)
+  ROOT %add = s32[] add(s32[] %a, s32[] %b)
+}
+
 %body.v3 (prev.1: s32[]) -> s32[] {
   %constant = s32[] constant(1)
   %prev.1 = s32[] parameter(0)
-  ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
+  ROOT %add = s32[] fusion(s32[] %constant, s32[] %prev.1), kind=kLoop, calls=%Add
+}
+
+%Compare (a: s32[], b: s32[]) -> pred[] {
+  %a = s32[] parameter(0)
+  %b = s32[] parameter(1)
+  ROOT %greater-than = pred[] compare(s32[] %a, s32[] %b), direction=GT
 }
 
 %condition.v3 (prev.2: s32[]) -> pred[] {
   %constant.1 = s32[] constant(5)
   %prev.2 = s32[] parameter(0)
-  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
+  ROOT %greater-than = pred[] fusion(s32[] %constant.1, s32[] %prev.2), kind=kLoop, calls=%Compare
 }
 
 // CHECK: %{{.*}} = memref.view
diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt
index 5d2fd661919caa..56a60ec9aec907 100644
--- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt
+++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt
@@ -4,6 +4,11 @@ HloModule TestModule
 
 // CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
 
+Fusion {
+  x = f32[3, 2]{1,0} parameter(0)
+  ROOT x.copy = f32[3, 2]{0,1} copy(x)
+}
+
 // CHECK: func @TestComputation
 ENTRY TestComputation {
   x = f32[3, 2]{1,0} parameter(0)
@@ -17,6 +22,6 @@ ENTRY TestComputation {
   // CHECK-SAME: } : tensor<3x2xf32>
   // CHECK: memref.tensor_store %[[VAL3:.*]], %{{.*}} : memref<3x2xf32, #[[MAP]]>
   // CHECK: "lmhlo.terminator"() : () -> ()
-  // CHECK: }) : () -> ()
-  ROOT x.copy = f32[3, 2]{0,1} copy(x)
+  // CHECK: }) {backend_config = "{{.*}}"} : () -> ()
+  ROOT fusion = f32[3, 2]{0,1} fusion(f32[3, 2]{1,0} x), kind=kLoop, calls=Fusion
 }