diff --git a/third_party/xla/xla/hlo/tools/BUILD b/third_party/xla/xla/hlo/tools/BUILD index eb35cbc1a0ab19..e0d0e8c984b953 100644 --- a/third_party/xla/xla/hlo/tools/BUILD +++ b/third_party/xla/xla/hlo/tools/BUILD @@ -153,3 +153,37 @@ xla_cc_binary( "@local_tsl//tsl/platform:status", ], ) + +xla_cc_binary( + name = "hlo-translate", + testonly = True, + srcs = ["hlo_translate.cc"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/translate:stablehlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/mlir_hlo", + "//xla/mlir_hlo:mhlo_passes", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_proto_util", + "//xla/service/cpu:cpu_compiler", + "//xla/service/cpu:cpu_transfer_manager", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor/host:host_platform", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:protobuf", + "@stablehlo//:register", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_translate.cc b/third_party/xla/xla/hlo/tools/hlo_translate.cc new file mode 100644 index 00000000000000..06b6966dc50bee --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_translate.cc @@ -0,0 +1,216 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Tools/mlir-translate/MlirTranslateMain.h" +#include "mlir/Tools/mlir-translate/Translation.h" +#include "stablehlo/dialect/Register.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_proto_util.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape_util.h" +#include "tsl/platform/protobuf.h" + +namespace { + +// NOLINTNEXTLINE +llvm::cl::opt emit_mhlo("emit-mhlo", + llvm::cl::desc("Translate to MHLO instead of " + "default StableHLO"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt emit_proto("emit-proto", + llvm::cl::desc("Emit HLO proto instead of text"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt print_layouts( + "print-layouts", llvm::cl::desc("Print layouts in the generated HLO text"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt print_large_constants( + "print-large-constants", + llvm::cl::desc("Print large constants in the generated HLO text"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt print_sugar( + "print-sugar", + llvm::cl::desc( + "Print async ops using syntactic sugar in the generated HLO text"), + llvm::cl::init(true)); + +static void RegisterInputMlirDialects(mlir::DialectRegistry& registry) { + mlir::stablehlo::registerAllDialects(registry); + registry.insert(); +} + +// Error collector that simply ignores errors reported. +class NoOpErrorCollector : public tsl::protobuf::io::ErrorCollector { + public: + void AddError(int line, int column, const std::string& message) override {} +}; + +bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) { + tsl::protobuf::TextFormat::Parser parser; + NoOpErrorCollector collector; + parser.RecordErrorsTo(&collector); + return hlo_proto->ParseFromString(contents) || + parser.ParseFromString(contents, hlo_proto) || + hlo_proto->mutable_hlo_module()->ParseFromString(contents) || + parser.ParseFromString(contents, hlo_proto->mutable_hlo_module()); +} + +mlir::OwningOpRef GetModuleFromHLOText( + std::string content, mlir::MLIRContext* context) { + auto hlo_text = xla::ParseAndReturnUnverifiedModule(content); + if (!hlo_text.ok()) { + return nullptr; + } + + mlir::OwningOpRef module = + xla::llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context)); + auto hlo_module = std::move(hlo_text.value()); + auto status = ConvertHloToMlirHlo(*module, hlo_module.get(), + /*import_all_computations=*/true, + /*flatten_computation_args_result*/ true); + if (!status.ok()) { + LOG(INFO) << "Failed to parse input as HLO text" << status; + return nullptr; + } + return module; +} + +mlir::OwningOpRef GetModuleFromHLOProto( + std::string content, mlir::MLIRContext* context) { + xla::HloProto hlo_proto; + if (!LoadHloProto(content, &hlo_proto)) { + return nullptr; + } + + mlir::OwningOpRef module = + xla::llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(context)); + auto status = + ConvertHloToMlirHlo(module.get(), hlo_proto.mutable_hlo_module(), + /*import_all_computations=*/true, + /*flatten_computation_args_result=*/true); + if (!status.ok()) { + LOG(INFO) << "Failed to parse input as HLO proto" << status; + return nullptr; + } + return module; +} + +} // namespace + +static mlir::OwningOpRef HloToMlirTranslate( + llvm::StringRef input, mlir::MLIRContext* context) { + std::string content(input.data(), input.size()); + mlir::OwningOpRef module = + GetModuleFromHLOText(content, context); + + if (!module) { + module = GetModuleFromHLOProto(content, context); + } + + if (!module) { + LOG(ERROR) << "Failed to parse input as HLO text or proto"; + return nullptr; + } + + if (emit_mhlo) return module; + + mlir::PassManager pm(context); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (failed(pm.run(*module))) { + module->emitError("Failed to legalize to StableHLO"); + return nullptr; + } + + return module; +} + +static mlir::LogicalResult MlirToHloTranslate(mlir::ModuleOp mlir_module, + llvm::raw_ostream& output) { + auto hlo_module_or_status = xla::ConvertStablehloToHlo(mlir_module); + if (!hlo_module_or_status.ok()) { + mlir_module->emitError(hlo_module_or_status.status().message()); + LOG(ERROR) << "Module conversion failed: " << hlo_module_or_status.status(); + return mlir::failure(); + } + xla::HloModule* hlo_module = hlo_module_or_status.value().get(); + if (emit_proto) { + // Print as HloProto with empty BufferAssignment for legacy compatibility. + output << MakeHloProto(*hlo_module).DebugString(); + } else { + // Print as HLO text. + output << hlo_module->ToString( + xla::HloPrintOptions() + .set_include_layout_in_shapes(print_layouts) + .set_syntax_sugar_async_ops(print_sugar) + .set_print_large_constants(print_large_constants)); + + // Output alias information as comments in the HLO text. + hlo_module->input_output_alias_config().ForEachAlias( + [&](const xla::ShapeIndex& output_index, + const xla::HloInputOutputAliasConfig::Alias& alias) { + output << "// OutputIndex " << output_index.ToString() + << " aliases with input " << alias.parameter_number << " at " + << alias.parameter_index.ToString() << "\n"; + }); + } + return mlir::success(); +} + +static mlir::TranslateToMLIRRegistration HloToMlirTranslateRegistration( + "hlo-to-mlir", "hlo to mlir translation", HloToMlirTranslate); + +static mlir::TranslateFromMLIRRegistration MlirToHloTranslateRegistration( + "mlir-to-hlo", "mlir to hlo translation", MlirToHloTranslate, + RegisterInputMlirDialects); + +int main(int argc, char** argv) { + return failed( + mlir::mlirTranslateMain(argc, argv, "MLIR<->HLO translation driver\n")); +} diff --git a/third_party/xla/xla/hlo/translate/tests/BUILD b/third_party/xla/xla/hlo/translate/tests/BUILD new file mode 100644 index 00000000000000..67823abe115fbf --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/BUILD @@ -0,0 +1,38 @@ +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") +load("//xla/tsl:tsl.default.bzl", "filegroup") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "all_tests", + srcs = enforce_glob( + [ + "simple.mlir", + "emit_proto.mlir", + "print_large_constants.mlir", + "print_layouts.mlir", + "simple.hlo", + "emit_mhlo.hlo", + ], + include = [ + "*.mlir", + "*.hlo", + ], + ), + cfg = "//xla:lit.cfg.py", + data = [":test_utilities"], + tools = [ + "//xla/hlo/tools:hlo-translate", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, +) diff --git a/third_party/xla/xla/hlo/translate/tests/emit_mhlo.hlo b/third_party/xla/xla/hlo/translate/tests/emit_mhlo.hlo new file mode 100644 index 00000000000000..d85beeb9c209ac --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/emit_mhlo.hlo @@ -0,0 +1,13 @@ +// RUN: hlo-translate -hlo-to-mlir -emit-mhlo %s | FileCheck %s + +// CHECK-LABEL: module @main +HloModule main, entry_computation_layout={(f32[4]{0}, f32[4]{0})->(f32[])} +ENTRY %main.6 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[]) { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + // CHECK: mhlo.add + %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + // CHECK: mhlo.dot + %dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + ROOT %tuple.5 = (f32[]) tuple(f32[] %dot.4) +} diff --git a/third_party/xla/xla/hlo/translate/tests/emit_proto.mlir b/third_party/xla/xla/hlo/translate/tests/emit_proto.mlir new file mode 100644 index 00000000000000..243a4e3978b8fb --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/emit_proto.mlir @@ -0,0 +1,20 @@ +// RUN: hlo-translate -mlir-to-hlo -emit-proto %s | FileCheck %s + +// CHECK: name: "foobar +// CHECK: entry_computation_name: "main +// CHECK: computations { +// CHECK: name: "main +// CHECK: instructions { +// CHECK: name: "Arg_ +// CHECK: opcode: "parameter" +// CHECK: name: "add +// CHECK: opcode: "add" +// CHECK: name: "dot +// CHECK: opcode: "dot" +module @foobar { + func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { + %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32> + %1 = stablehlo.dot %0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor + return %1 : tensor + } +} diff --git a/third_party/xla/xla/hlo/translate/tests/print_large_constants.mlir b/third_party/xla/xla/hlo/translate/tests/print_large_constants.mlir new file mode 100644 index 00000000000000..0df3d6cfaa1f8d --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/print_large_constants.mlir @@ -0,0 +1,18 @@ +// RUN: hlo-translate -split-input-file -mlir-to-hlo %s | FileCheck %s --check-prefix CHECK +// RUN: hlo-translate -split-input-file -mlir-to-hlo -print-large-constants %s | FileCheck %s --check-prefix CHECK-PRINT-LARGE + +func.func @main(%arg0: tensor<10xi32>) -> tensor<10xi32> { + // CHECK: constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + // CHECK-PRINT-LARGE: constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + %0 = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32> + func.return %0 : tensor<10xi32> +} + +// ----- + +func.func @main(%arg0: tensor<11xi32>) -> tensor<11xi32> { + // CHECK: constant({...}) + // CHECK-PRINT-LARGE: constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + %0 = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<11xi32> + func.return %0 : tensor<11xi32> +} diff --git a/third_party/xla/xla/hlo/translate/tests/print_layouts.mlir b/third_party/xla/xla/hlo/translate/tests/print_layouts.mlir new file mode 100644 index 00000000000000..40528c15e92fc1 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/print_layouts.mlir @@ -0,0 +1,16 @@ +// RUN: hlo-translate -mlir-to-hlo -print-layouts %s | FileCheck %s --check-prefix CHECK + +// CHECK-LABEL: main +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = token[] parameter(0) +// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}, pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" +// CHECK: [[GTE1:%.*]] = (s32[3,3]{0,1}, pred[]) get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=0 +// CHECK: [[GTE2:%.*]] = s32[3,3]{0,1} get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=0 +// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=1 +// CHECK: [[GTE4:%.*]] = token[] get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=1 +func.func @main(%arg0: !stablehlo.token) -> tuple, tensor>, !stablehlo.token> { + %0:3 = "stablehlo.infeed"(%arg0) {infeed_config = "foobar", layout=[[0, 1], [0]]} : (!stablehlo.token) -> (tensor<3x3xi32>, tensor, !stablehlo.token) + %1 = "stablehlo.tuple"(%0#0, %0#1) : (tensor<3x3xi32>, tensor) -> tuple, tensor> + %2 = "stablehlo.tuple"(%1, %0#2) : (tuple, tensor>, !stablehlo.token) -> tuple, tensor>, !stablehlo.token> + func.return %2 : tuple, tensor>, !stablehlo.token> +} diff --git a/third_party/xla/xla/hlo/translate/tests/simple.hlo b/third_party/xla/xla/hlo/translate/tests/simple.hlo new file mode 100644 index 00000000000000..0899d0e459b90f --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/simple.hlo @@ -0,0 +1,13 @@ +// RUN: hlo-translate -hlo-to-mlir %s | FileCheck %s + +// CHECK-LABEL: module @main +HloModule main, entry_computation_layout={(f32[4]{0}, f32[4]{0})->(f32[])} +ENTRY %main.6 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[]) { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + // CHECK: stablehlo.add + %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + // CHECK: stablehlo.dot + %dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + ROOT %tuple.5 = (f32[]) tuple(f32[] %dot.4) +} diff --git a/third_party/xla/xla/hlo/translate/tests/simple.mlir b/third_party/xla/xla/hlo/translate/tests/simple.mlir new file mode 100644 index 00000000000000..f9bf4e698c3004 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/tests/simple.mlir @@ -0,0 +1,28 @@ +// RUN: hlo-translate -mlir-to-hlo -split-input-file %s | FileCheck %s + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[]) +func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { + // CHECK: %Arg_0.1 = f32[4] parameter(0) + // CHECK: %Arg_1.2 = f32[4] parameter(1) + // CHECK: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32> + // CHECK: %dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + %1 = stablehlo.dot %0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor + // CHECK: ROOT %tuple.5 = (f32[]) tuple(f32[] %dot.4) + func.return %1 : tensor +} + +// ----- +// MHLO to HLO + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[4]) +func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0) + // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) + // CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2) + %1 = "mhlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // ROOT: %tuple.5 = (f32[4]) tuple(f32[4] %add.4) + func.return %1 : tensor<4xf32> +}