Skip to content

Commit

Permalink
[hlo-translate] Tool : introduce hlo-translate tool. leaner, cleaner …
Browse files Browse the repository at this point in the history
…and simpler UX vs xla-translate.

PiperOrigin-RevId: 693563050
  • Loading branch information
abhigunj authored and tensorflower-gardener committed Nov 6, 2024
1 parent 9547023 commit f487e9b
Show file tree
Hide file tree
Showing 9 changed files with 396 additions and 0 deletions.
34 changes: 34 additions & 0 deletions third_party/xla/xla/hlo/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,37 @@ xla_cc_binary(
"@local_tsl//tsl/platform:status",
],
)

xla_cc_binary(
name = "hlo-translate",
testonly = True,
srcs = ["hlo_translate.cc"],
deps = [
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/translate:stablehlo",
"//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_proto_util",
"//xla/service/cpu:cpu_compiler",
"//xla/service/cpu:cpu_transfer_manager",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor/host:host_platform",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TranslateLib",
"@local_tsl//tsl/platform:protobuf",
"@stablehlo//:register",
],
)
216 changes: 216 additions & 0 deletions third_party/xla/xla/hlo/tools/hlo_translate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "stablehlo/dialect/Register.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_proto_util.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "xla/shape_util.h"
#include "tsl/platform/protobuf.h"

namespace {

// NOLINTNEXTLINE
llvm::cl::opt<bool> emit_mhlo("emit-mhlo",
llvm::cl::desc("Translate to MHLO instead of "
"default StableHLO"),
llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> emit_proto("emit-proto",
llvm::cl::desc("Emit HLO proto instead of text"),
llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> print_layouts(
"print-layouts", llvm::cl::desc("Print layouts in the generated HLO text"),
llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> print_large_constants(
"print-large-constants",
llvm::cl::desc("Print large constants in the generated HLO text"),
llvm::cl::init(false));

// NOLINTNEXTLINE
llvm::cl::opt<bool> print_sugar(
"print-sugar",
llvm::cl::desc(
"Print async ops using syntactic sugar in the generated HLO text"),
llvm::cl::init(true));

static void RegisterInputMlirDialects(mlir::DialectRegistry& registry) {
mlir::stablehlo::registerAllDialects(registry);
registry.insert<mlir::arith::ArithDialect, mlir::func::FuncDialect,
mlir::tensor::TensorDialect, mlir::mhlo::MhloDialect>();
}

// Error collector that simply ignores errors reported.
class NoOpErrorCollector : public tsl::protobuf::io::ErrorCollector {
public:
void AddError(int line, int column, const std::string& message) override {}
};

bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) {
tsl::protobuf::TextFormat::Parser parser;
NoOpErrorCollector collector;
parser.RecordErrorsTo(&collector);
return hlo_proto->ParseFromString(contents) ||
parser.ParseFromString(contents, hlo_proto) ||
hlo_proto->mutable_hlo_module()->ParseFromString(contents) ||
parser.ParseFromString(contents, hlo_proto->mutable_hlo_module());
}

mlir::OwningOpRef<mlir::ModuleOp> GetModuleFromHLOText(
std::string content, mlir::MLIRContext* context) {
auto hlo_text = xla::ParseAndReturnUnverifiedModule(content);
if (!hlo_text.ok()) {
return nullptr;
}

mlir::OwningOpRef<mlir::ModuleOp> module =
xla::llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context));
auto hlo_module = std::move(hlo_text.value());
auto status = ConvertHloToMlirHlo(*module, hlo_module.get(),
/*import_all_computations=*/true,
/*flatten_computation_args_result*/ true);
if (!status.ok()) {
LOG(INFO) << "Failed to parse input as HLO text" << status;
return nullptr;
}
return module;
}

mlir::OwningOpRef<mlir::ModuleOp> GetModuleFromHLOProto(
std::string content, mlir::MLIRContext* context) {
xla::HloProto hlo_proto;
if (!LoadHloProto(content, &hlo_proto)) {
return nullptr;
}

mlir::OwningOpRef<mlir::ModuleOp> module =
xla::llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context));
auto status =
ConvertHloToMlirHlo(module.get(), hlo_proto.mutable_hlo_module(),
/*import_all_computations=*/true,
/*flatten_computation_args_result=*/true);
if (!status.ok()) {
LOG(INFO) << "Failed to parse input as HLO proto" << status;
return nullptr;
}
return module;
}

} // namespace

static mlir::OwningOpRef<mlir::ModuleOp> HloToMlirTranslate(
llvm::StringRef input, mlir::MLIRContext* context) {
std::string content(input.data(), input.size());
mlir::OwningOpRef<mlir::ModuleOp> module =
GetModuleFromHLOText(content, context);

if (!module) {
module = GetModuleFromHLOProto(content, context);
}

if (!module) {
LOG(ERROR) << "Failed to parse input as HLO text or proto";
return nullptr;
}

if (emit_mhlo) return module;

mlir::PassManager pm(context);
pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
if (failed(pm.run(*module))) {
module->emitError("Failed to legalize to StableHLO");
return nullptr;
}

return module;
}

static mlir::LogicalResult MlirToHloTranslate(mlir::ModuleOp mlir_module,
llvm::raw_ostream& output) {
auto hlo_module_or_status = xla::ConvertStablehloToHlo(mlir_module);
if (!hlo_module_or_status.ok()) {
mlir_module->emitError(hlo_module_or_status.status().message());
LOG(ERROR) << "Module conversion failed: " << hlo_module_or_status.status();
return mlir::failure();
}
xla::HloModule* hlo_module = hlo_module_or_status.value().get();
if (emit_proto) {
// Print as HloProto with empty BufferAssignment for legacy compatibility.
output << MakeHloProto(*hlo_module).DebugString();
} else {
// Print as HLO text.
output << hlo_module->ToString(
xla::HloPrintOptions()
.set_include_layout_in_shapes(print_layouts)
.set_syntax_sugar_async_ops(print_sugar)
.set_print_large_constants(print_large_constants));

// Output alias information as comments in the HLO text.
hlo_module->input_output_alias_config().ForEachAlias(
[&](const xla::ShapeIndex& output_index,
const xla::HloInputOutputAliasConfig::Alias& alias) {
output << "// OutputIndex " << output_index.ToString()
<< " aliases with input " << alias.parameter_number << " at "
<< alias.parameter_index.ToString() << "\n";
});
}
return mlir::success();
}

static mlir::TranslateToMLIRRegistration HloToMlirTranslateRegistration(
"hlo-to-mlir", "hlo to mlir translation", HloToMlirTranslate);

static mlir::TranslateFromMLIRRegistration MlirToHloTranslateRegistration(
"mlir-to-hlo", "mlir to hlo translation", MlirToHloTranslate,
RegisterInputMlirDialects);

int main(int argc, char** argv) {
return failed(
mlir::mlirTranslateMain(argc, argv, "MLIR<->HLO translation driver\n"));
}
38 changes: 38 additions & 0 deletions third_party/xla/xla/hlo/translate/tests/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
load("//xla:lit.bzl", "enforce_glob", "lit_test_suite")
load("//xla/tsl:tsl.default.bzl", "filegroup")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)

lit_test_suite(
name = "all_tests",
srcs = enforce_glob(
[
"simple.mlir",
"emit_proto.mlir",
"print_large_constants.mlir",
"print_layouts.mlir",
"simple.hlo",
"emit_mhlo.hlo",
],
include = [
"*.mlir",
"*.hlo",
],
),
cfg = "//xla:lit.cfg.py",
data = [":test_utilities"],
tools = [
"//xla/hlo/tools:hlo-translate",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
)
13 changes: 13 additions & 0 deletions third_party/xla/xla/hlo/translate/tests/emit_mhlo.hlo
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: hlo-translate -hlo-to-mlir -emit-mhlo %s | FileCheck %s

// CHECK-LABEL: module @main
HloModule main, entry_computation_layout={(f32[4]{0}, f32[4]{0})->(f32[])}
ENTRY %main.6 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[]) {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK: mhlo.add
%add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
// CHECK: mhlo.dot
%dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
ROOT %tuple.5 = (f32[]) tuple(f32[] %dot.4)
}
20 changes: 20 additions & 0 deletions third_party/xla/xla/hlo/translate/tests/emit_proto.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: hlo-translate -mlir-to-hlo -emit-proto %s | FileCheck %s

// CHECK: name: "foobar
// CHECK: entry_computation_name: "main
// CHECK: computations {
// CHECK: name: "main
// CHECK: instructions {
// CHECK: name: "Arg_
// CHECK: opcode: "parameter"
// CHECK: name: "add
// CHECK: opcode: "add"
// CHECK: name: "dot
// CHECK: opcode: "dot"
module @foobar {
func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<4xf32>
%1 = stablehlo.dot %0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
return %1 : tensor<f32>
}
}
18 changes: 18 additions & 0 deletions third_party/xla/xla/hlo/translate/tests/print_large_constants.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: hlo-translate -split-input-file -mlir-to-hlo %s | FileCheck %s --check-prefix CHECK
// RUN: hlo-translate -split-input-file -mlir-to-hlo -print-large-constants %s | FileCheck %s --check-prefix CHECK-PRINT-LARGE

func.func @main(%arg0: tensor<10xi32>) -> tensor<10xi32> {
// CHECK: constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
// CHECK-PRINT-LARGE: constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
%0 = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>
func.return %0 : tensor<10xi32>
}

// -----

func.func @main(%arg0: tensor<11xi32>) -> tensor<11xi32> {
// CHECK: constant({...})
// CHECK-PRINT-LARGE: constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
%0 = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<11xi32>
func.return %0 : tensor<11xi32>
}
16 changes: 16 additions & 0 deletions third_party/xla/xla/hlo/translate/tests/print_layouts.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: hlo-translate -mlir-to-hlo -print-layouts %s | FileCheck %s --check-prefix CHECK

// CHECK-LABEL: main
// CHECK: ENTRY
// CHECK: [[ARG:%.*]] = token[] parameter(0)
// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}, pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar"
// CHECK: [[GTE1:%.*]] = (s32[3,3]{0,1}, pred[]) get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=0
// CHECK: [[GTE2:%.*]] = s32[3,3]{0,1} get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=0
// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=1
// CHECK: [[GTE4:%.*]] = token[] get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=1
func.func @main(%arg0: !stablehlo.token) -> tuple<tuple<tensor<3x3xi32>, tensor<i1>>, !stablehlo.token> {
%0:3 = "stablehlo.infeed"(%arg0) {infeed_config = "foobar", layout=[[0, 1], [0]]} : (!stablehlo.token) -> (tensor<3x3xi32>, tensor<i1>, !stablehlo.token)
%1 = "stablehlo.tuple"(%0#0, %0#1) : (tensor<3x3xi32>, tensor<i1>) -> tuple<tensor<3x3xi32>, tensor<i1>>
%2 = "stablehlo.tuple"(%1, %0#2) : (tuple<tensor<3x3xi32>, tensor<i1>>, !stablehlo.token) -> tuple<tuple<tensor<3x3xi32>, tensor<i1>>, !stablehlo.token>
func.return %2 : tuple<tuple<tensor<3x3xi32>, tensor<i1>>, !stablehlo.token>
}
13 changes: 13 additions & 0 deletions third_party/xla/xla/hlo/translate/tests/simple.hlo
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: hlo-translate -hlo-to-mlir %s | FileCheck %s

// CHECK-LABEL: module @main
HloModule main, entry_computation_layout={(f32[4]{0}, f32[4]{0})->(f32[])}
ENTRY %main.6 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[]) {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK: stablehlo.add
%add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
// CHECK: stablehlo.dot
%dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
ROOT %tuple.5 = (f32[]) tuple(f32[] %dot.4)
}
Loading

0 comments on commit f487e9b

Please sign in to comment.