Skip to content

Commit

Permalink
Wrap remaining unfused instructions in fusions before conversion to L…
Browse files Browse the repository at this point in the history
…HLO.

After this, `HloToLhloModule` no longer creates additional fusions, which
means we can easily produce a mapping from MLIR operation to corresponding
HLO instruction (and their types will match). After this, we can remove
the MLIR->HLO conversion step from `ir_emitter_unnested`, which means
codegen will again have access to valid (properly connected) HLO.

This should be an NFC.

PiperOrigin-RevId: 558699182
  • Loading branch information
jreiffers authored and tensorflower-gardener committed Aug 21, 2023
1 parent a0b91dd commit 7a024da
Show file tree
Hide file tree
Showing 16 changed files with 441 additions and 279 deletions.
28 changes: 27 additions & 1 deletion tensorflow/compiler/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2311,12 +2311,12 @@ cc_library(
],
deps = [
":executable_proto_cc",
":fusion_wrapper",
":gpu_constants",
":gpu_convert_async_collectives_to_sync",
":gpu_device_info",
":gpu_executable",
":gpu_hlo_schedule",
":ir_emitter",
":ir_emitter_context",
":ir_emitter_unnested",
":metrics",
Expand Down Expand Up @@ -4180,6 +4180,32 @@ cc_library(
],
)

# HLO pass that wraps remaining unfused instructions in single-instruction
# fusions ahead of HLO-to-LHLO conversion (see fusion_wrapper.h).
cc_library(
    name = "fusion_wrapper",
    srcs = ["fusion_wrapper.cc"],
    hdrs = ["fusion_wrapper.h"],
    deps = [
        ":gpu_fusible",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/hlo/ir:hlo",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/tsl/platform:errors",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
    ],
)

# Unit tests for the :fusion_wrapper pass.
xla_cc_test(
    name = "fusion_wrapper_test",
    srcs = ["fusion_wrapper_test.cc"],
    deps = [
        ":fusion_wrapper",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "@com_google_googletest//:gtest_main",
    ],
)

xla_cc_test(
name = "copy_fusion_test",
srcs = ["copy_fusion_test.cc"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_wrapper.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_convert_async_collectives_to_sync.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
Expand Down Expand Up @@ -386,6 +387,12 @@ Status CompileModuleToLlvmIrImpl(
}
}

// Wrap remaining unfused ops that have no LHLO equivalent in single-op
// fusions. This needs to happen after rematerialization, because it will
// insert additional copies.
// NOTE: the pass is run directly; the previously declared
// `HloPassPipeline pipeline("fusion-wrapper")` was never used, so the dead
// variable has been removed.
TF_RETURN_IF_ERROR(FusionWrapper().Run(hlo_module).status());

auto buffer_size_bytes_function =
[pointer_size](const BufferValue& buffer_value) -> int64_t {
return GetSizeOfShape(buffer_value.shape(), pointer_size);
Expand Down
147 changes: 147 additions & 0 deletions tensorflow/compiler/xla/service/gpu/fusion_wrapper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/fusion_wrapper.h"

#include <functional>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/tsl/platform/errors.h"

namespace xla {
namespace gpu {

// Wraps each supported un-fused instruction in the entry computation (and,
// recursively, in computations called by kWhile/kConditional) into a
// single-instruction fusion. Returns true iff at least one instruction was
// wrapped.
//
// NOTE(review): `execution_threads` is never consulted below — the pass
// always operates on the entry computation; confirm this is intentional.
StatusOr<bool> FusionWrapper::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // Snapshot the entry computation's instructions up front: wrapping mutates
  // the computation, so we must not iterate over a live instruction list.
  auto instructions = module->entry_computation()->MakeInstructionPostOrder();
  bool changed = false;

  // Recursive visitor. Declared as a std::function (not `auto`) so the
  // lambda can call itself for control-flow bodies.
  std::function<Status(HloInstruction*)> handle_instruction;
  handle_instruction = [&](HloInstruction* instruction) -> Status {
    switch (instruction->opcode()) {
      case HloOpcode::kConditional:
      case HloOpcode::kWhile:
        // Control flow: descend into every called computation so the
        // instructions inside branch/body/condition get wrapped too.
        for (auto* computation : instruction->called_computations()) {
          for (auto* inner_instruction :
               computation->MakeInstructionPostOrder()) {
            TF_RETURN_IF_ERROR(handle_instruction(inner_instruction));
          }
        }
        break;
      // The opcodes below are wrapped into single-op fusions.
      case HloOpcode::kAbs:
      case HloOpcode::kAdd:
      case HloOpcode::kAnd:
      case HloOpcode::kAtan2:
      case HloOpcode::kBitcastConvert:
      case HloOpcode::kBroadcast:
      case HloOpcode::kCeil:
      case HloOpcode::kCbrt:
      case HloOpcode::kClamp:
      case HloOpcode::kClz:
      case HloOpcode::kCompare:
      case HloOpcode::kComplex:
      case HloOpcode::kConcatenate:
      case HloOpcode::kConvert:
      case HloOpcode::kCos:
      case HloOpcode::kDivide:
      case HloOpcode::kDot:
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kDynamicUpdateSlice:
      case HloOpcode::kExp:
      case HloOpcode::kExpm1:
      case HloOpcode::kFloor:
      case HloOpcode::kGather:
      case HloOpcode::kImag:
      case HloOpcode::kIota:
      case HloOpcode::kIsFinite:
      case HloOpcode::kLog:
      case HloOpcode::kLog1p:
      case HloOpcode::kMap:
      case HloOpcode::kMaximum:
      case HloOpcode::kMinimum:
      case HloOpcode::kMultiply:
      case HloOpcode::kNegate:
      case HloOpcode::kNot:
      case HloOpcode::kOr:
      case HloOpcode::kPad:
      case HloOpcode::kPopulationCount:
      case HloOpcode::kPower:
      case HloOpcode::kReal:
      case HloOpcode::kReshape:
      case HloOpcode::kReducePrecision:
      case HloOpcode::kReduceWindow:
      case HloOpcode::kRemainder:
      case HloOpcode::kReverse:
      case HloOpcode::kRoundNearestAfz:
      case HloOpcode::kRoundNearestEven:
      case HloOpcode::kRsqrt:
      case HloOpcode::kSelect:
      case HloOpcode::kShiftLeft:
      case HloOpcode::kShiftRightLogical:
      case HloOpcode::kShiftRightArithmetic:
      case HloOpcode::kSign:
      case HloOpcode::kSin:
      case HloOpcode::kSlice:
      case HloOpcode::kSqrt:
      case HloOpcode::kSubtract:
      case HloOpcode::kStochasticConvert:
      case HloOpcode::kTan:
      case HloOpcode::kTanh:
      case HloOpcode::kTranspose:
      case HloOpcode::kXor:
      case HloOpcode::kCopy:
      case HloOpcode::kReduce: {
        auto* computation = instruction->parent();
        // Build a fusion whose only fused instruction is `instruction`;
        // the fusion inherits the instruction's shape.
        auto* fusion_instruction =
            computation->AddInstruction(HloInstruction::CreateFusion(
                instruction->shape(),
                ChooseFusionKind(*instruction /*unused but required*/,
                                 *instruction),
                instruction));
        // Name the fusion "wrapped_<original name>", uniquified module-wide.
        instruction->GetModule()->SetAndUniquifyInstrName(
            fusion_instruction, absl::StrCat("wrapped_", instruction->name()));
        // If the module is scheduled, the fusion must take the wrapped
        // instruction's slot to keep the schedule valid.
        if (module->has_schedule()) {
          module->schedule().replace_instruction(computation, instruction,
                                                 fusion_instruction);
        }
        // Order matters: move control deps onto the fusion, drop them from
        // the original, rewire data users, then delete the dead original.
        TF_RETURN_IF_ERROR(
            fusion_instruction->CopyAllControlDepsFrom(instruction));
        TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
        changed = true;
        break;
      }
      default:
        // All other opcodes (e.g. ops with LHLO equivalents, existing
        // fusions) are left untouched.
        break;
    }
    return OkStatus();
  };

  for (auto* instruction : instructions) {
    TF_RETURN_IF_ERROR(handle_instruction(instruction));
  }
  return changed;
}

} // namespace gpu
} // namespace xla
42 changes: 42 additions & 0 deletions tensorflow/compiler/xla/service/gpu/fusion_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_WRAPPER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_WRAPPER_H_

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace gpu {

// Wraps leftover unfused instructions in the entry computation (and in
// computations called by while/conditional instructions) that have no LHLO
// equivalent into fusions containing just that instruction.
class FusionWrapper : public HloModulePass {
 public:
  absl::string_view name() const override { return "fusion-wrapper"; }

  using HloPassInterface::Run;
  // Returns true iff at least one instruction was wrapped in a fusion.
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

} // namespace gpu
} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_WRAPPER_H_
Loading

0 comments on commit 7a024da

Please sign in to comment.