Skip to content

Commit

Permalink
lite:stablehlo: add serialization for stablehlo add, multiply, divide…
Browse files Browse the repository at this point in the history
… and maximum op

PiperOrigin-RevId: 555305030
  • Loading branch information
zichuan-wei authored and tensorflower-gardener committed Aug 9, 2023
1 parent b6f9258 commit 9b5e827
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 18 deletions.
46 changes: 40 additions & 6 deletions tensorflow/compiler/mlir/lite/flatbuffer_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,13 @@ class Translator {
return data_size > flatbuffer_size_max - builder_.GetSize();
}

// Helper for serializing StableHLO ops that take no builtin options.
std::optional<BufferOffset<tflite::Operator>>
BuildStablehloOperatorwithoutOptions(Operation* inst,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results,
tflite::BuiltinOperator op_code);

ModuleOp module_;

tensorflow::OpOrArgNameMapper& name_mapper_;
Expand Down Expand Up @@ -1363,6 +1370,19 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
return it.first->second;
}

std::optional<BufferOffset<tflite::Operator>>
Translator::BuildStablehloOperatorwithoutOptions(
    Operation* inst, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results,
    const tflite::BuiltinOperator op_code) {
  // Resolve (registering on first use) the opcode index for this op's name,
  // then serialize an Operator table entry that carries no builtin options.
  const uint32_t opcode_index =
      GetOpcodeIndex(inst->getName().getStringRef().str(), op_code);

  return tflite::CreateOperator(builder_, opcode_index,
                                builder_.CreateVector(operands),
                                builder_.CreateVector(results),
                                tflite::BuiltinOptions_NONE,
                                /*builtin_options=*/0);
}

std::optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
Operation* inst, std::vector<int32_t> operands,
const std::vector<int32_t>& results,
Expand Down Expand Up @@ -1434,13 +1454,27 @@ std::optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
// builtin ops
if (dialect == stablehlo_dialect_) {
if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::LogisticOp>(inst)) {
std::string op_name = inst->getName().getStringRef().str();
uint32_t opcode_index =
GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_LOGISTIC);
return BuildStablehloOperatorwithoutOptions(
inst, operands, results, tflite::BuiltinOperator_STABLEHLO_LOGISTIC);
}

if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::AddOp>(inst)) {
return BuildStablehloOperatorwithoutOptions(
inst, operands, results, tflite::BuiltinOperator_STABLEHLO_ADD);
}

return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE, 0);
if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MulOp>(inst)) {
return BuildStablehloOperatorwithoutOptions(
inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MULTIPLY);
}

if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::DivOp>(inst)) {
return BuildStablehloOperatorwithoutOptions(
inst, operands, results, tflite::BuiltinOperator_STABLEHLO_DIVIDE);
}
if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MaxOp>(inst)) {
return BuildStablehloOperatorwithoutOptions(
inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MAXIMUM);
}
}

Expand Down
48 changes: 42 additions & 6 deletions tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir
Original file line number Diff line number Diff line change
@@ -1,16 +1,52 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// test stablehlo roundtrip

module {
func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> {
%0 = stablehlo.logistic %arg0 : tensor<1x1x1x96xf32>
func.return %0 : tensor<1x1x1x96xf32>
}

// CHECK:func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "stablehlo.logistic"}} {
// CHECK: %0 = stablehlo.logistic %arg0 : tensor<1x1x1x96xf32>
// CHECK: return %0 : tensor<1x1x1x96xf32>
// CHECK:}

func.func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>
func.return %0 : tensor<1xf32>
}

// CHECK:func.func private @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>
// CHECK: return %0 : tensor<1xf32>
// CHECK:}

func.func @multiply(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = stablehlo.multiply %arg0, %arg1 : tensor<1xf32>
func.return %0 : tensor<1xf32>
}

// CHECK:func.func private @multiply(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: %0 = stablehlo.multiply %arg0, %arg1 : tensor<1xf32>
// CHECK: return %0 : tensor<1xf32>
// CHECK:}

func.func @divide(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = stablehlo.divide %arg0, %arg1 : tensor<1xf32>
func.return %0 : tensor<1xf32>
}

// CHECK:func.func private @divide(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: %0 = stablehlo.divide %arg0, %arg1 : tensor<1xf32>
// CHECK: return %0 : tensor<1xf32>
// CHECK:}

func.func @maximum(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = stablehlo.maximum %arg0, %arg1 : tensor<1xf32>
func.return %0 : tensor<1xf32>
}

// CHECK:module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
// CHECK: func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "stablehlo.logistic"}} {
// CHECK: %0 = stablehlo.logistic %arg0 : tensor<1x1x1x96xf32>
// CHECK: return %0 : tensor<1x1x1x96xf32>
// CHECK: }
// CHECK:func.func private @maximum(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: %0 = stablehlo.maximum %arg0, %arg1 : tensor<1xf32>
// CHECK: return %0 : tensor<1xf32>
// CHECK:}
4 changes: 4 additions & 0 deletions tensorflow/lite/builtin_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ typedef enum {
kTfLiteBuiltinBitwiseXor = 160,
kTfLiteBuiltinRightShift = 161,
kTfLiteBuiltinStablehloLogistic = 162,
kTfLiteBuiltinStablehloAdd = 163,
kTfLiteBuiltinStablehloDivide = 164,
kTfLiteBuiltinStablehloMultiply = 165,
kTfLiteBuiltinStablehloMaximum = 166,
} TfLiteBuiltinOperator;

#ifdef __cplusplus
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_BITCAST:
case BuiltinOperator_WHERE:
case BuiltinOperator_STABLEHLO_LOGISTIC:
case BuiltinOperator_STABLEHLO_ADD:
case BuiltinOperator_STABLEHLO_DIVIDE:
case BuiltinOperator_STABLEHLO_MULTIPLY:
case BuiltinOperator_STABLEHLO_MAXIMUM:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
return kTfLiteError;
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/lite/core/kernels/builtin_op_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ TfLiteRegistration* Register_RIGHT_SHIFT();
TfLiteRegistration*
Register_STABLEHLO_LOGISTIC(); // WARNING: not implemented, using this op will
// crash the runtime
TfLiteRegistration*
Register_STABLEHLO_ADD(); // WARNING: not implemented, using this op will crash
// the runtime
TfLiteRegistration*
Register_STABLEHLO_DIVIDE(); // WARNING: not implemented, using this op will
// crash the runtime
TfLiteRegistration*
Register_STABLEHLO_MULTIPLY(); // WARNING: not implemented, using this op will
// crash the runtime
TfLiteRegistration*
Register_STABLEHLO_MAXIMUM(); // WARNING: not implemented, using this op will
// crash the runtime
} // namespace builtin
} // namespace ops
} // namespace tflite
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/lite/kernels/builtin_ops_list.inc
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,7 @@ TFLITE_OP(Register_BITCAST)
TFLITE_OP(Register_BITWISE_XOR)
TFLITE_OP(Register_RIGHT_SHIFT)
TFLITE_OP(Register_STABLEHLO_LOGISTIC)
TFLITE_OP(Register_STABLEHLO_ADD)
TFLITE_OP(Register_STABLEHLO_DIVIDE)
TFLITE_OP(Register_STABLEHLO_MULTIPLY)
TFLITE_OP(Register_STABLEHLO_MAXIMUM)
4 changes: 4 additions & 0 deletions tensorflow/lite/schema/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,10 @@ enum BuiltinOperator : int32 {
// All operators starting with the STABLEHLO_ prefix are subject to change.
// Many of the ops below cannot be executed by the TFLite runtime.
STABLEHLO_LOGISTIC = 162, // WARNING: Do not have runtime support
STABLEHLO_ADD = 163, // WARNING: No runtime support yet
STABLEHLO_DIVIDE = 164, // WARNING: No runtime support yet
STABLEHLO_MULTIPLY = 165, // WARNING: No runtime support yet
STABLEHLO_MAXIMUM = 166, // WARNING: No runtime support yet
}
// LINT.ThenChange(nnapi_linter/linter.proto)

Expand Down
22 changes: 17 additions & 5 deletions tensorflow/lite/schema/schema_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -1089,11 +1089,15 @@ enum BuiltinOperator : int32_t {
BuiltinOperator_BITWISE_XOR = 160,
BuiltinOperator_RIGHT_SHIFT = 161,
BuiltinOperator_STABLEHLO_LOGISTIC = 162,
BuiltinOperator_STABLEHLO_ADD = 163,
BuiltinOperator_STABLEHLO_DIVIDE = 164,
BuiltinOperator_STABLEHLO_MULTIPLY = 165,
BuiltinOperator_STABLEHLO_MAXIMUM = 166,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_STABLEHLO_LOGISTIC
BuiltinOperator_MAX = BuiltinOperator_STABLEHLO_MAXIMUM
};

inline const BuiltinOperator (&EnumValuesBuiltinOperator())[163] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[167] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
Expand Down Expand Up @@ -1257,13 +1261,17 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[163] {
BuiltinOperator_BITCAST,
BuiltinOperator_BITWISE_XOR,
BuiltinOperator_RIGHT_SHIFT,
BuiltinOperator_STABLEHLO_LOGISTIC
BuiltinOperator_STABLEHLO_LOGISTIC,
BuiltinOperator_STABLEHLO_ADD,
BuiltinOperator_STABLEHLO_DIVIDE,
BuiltinOperator_STABLEHLO_MULTIPLY,
BuiltinOperator_STABLEHLO_MAXIMUM
};
return values;
}

inline const char * const *EnumNamesBuiltinOperator() {
static const char * const names[164] = {
static const char * const names[168] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
Expand Down Expand Up @@ -1427,13 +1435,17 @@ inline const char * const *EnumNamesBuiltinOperator() {
"BITWISE_XOR",
"RIGHT_SHIFT",
"STABLEHLO_LOGISTIC",
"STABLEHLO_ADD",
"STABLEHLO_DIVIDE",
"STABLEHLO_MULTIPLY",
"STABLEHLO_MAXIMUM",
nullptr
};
return names;
}

inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_STABLEHLO_LOGISTIC)) return "";
if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_STABLEHLO_MAXIMUM)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,15 @@ class OpOptionData {
op_to_option_["BITCAST"] = "";
op_to_option_["BITWISE_XOR"] = "";
op_to_option_["RIGHT_SHIFT"] = "";
// HACK(b/293937201): currently we're hitting the Flatbuffer Java API limit
// HACK(b/294399204): currently we're hitting the Flatbuffer Java API limit
// for union structs
// For all new ops that use the none option, manually map them here instead of
// adding a new option.
op_to_option_["STABLEHLO_LOGISTIC"] = "";
op_to_option_["STABLEHLO_ADD"] = "";
op_to_option_["STABLEHLO_DIVIDE"] = "";
op_to_option_["STABLEHLO_MULTIPLY"] = "";
op_to_option_["STABLEHLO_MAXIMUM"] = "";

// TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
Expand Down

0 comments on commit 9b5e827

Please sign in to comment.