Skip to content

Commit

Permalink
lite: flatbuffer_import: convert stablehlo attribute to tensor type i…
Browse files Browse the repository at this point in the history
…nstead of vector type

PiperOrigin-RevId: 559606029
  • Loading branch information
zichuan-wei authored and tensorflower-gardener committed Aug 24, 2023
1 parent 21f553c commit e46b690
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,23 +263,18 @@ static mlir::Attribute BuildI64ArrayAttr(std::vector<int32_t> value,
return builder.getI64ArrayAttr(typecast);
}

static mlir::Attribute BuildDenseElementAttr(std::vector<int64_t> value,
mlir::Builder builder) {
return builder.getI64VectorAttr(value);
}

static mlir::Attribute BuildDenseElementAttr(std::vector<bool> value,
static mlir::Attribute BuildRankedTensorAttr(std::vector<int64_t> shape,
std::vector<bool> value,
mlir::Builder builder) {
// The implementation of getBoolVectorAttr is flawed, so we bypass it here
std::vector<int8_t> extendVec;
extendVec.reserve(value.size());
for (size_t i = 0; i < value.size(); ++i) {
extendVec[i] = static_cast<int8_t>(value[i]);
}
return mlir::DenseIntElementsAttr::get(
mlir::VectorType::get(static_cast<int64_t>(extendVec.size()),
builder.getI1Type()),
extendVec);
mlir::RankedTensorType ty =
tensorflow::GetTypeFromTFTensorShape(shape, builder.getIntegerType(1));
return mlir::DenseIntElementsAttr::get(ty, extendVec);
}

static mlir::Attribute BuildRankedTensorAttr(std::vector<int64_t> shape,
Expand Down Expand Up @@ -396,16 +391,22 @@ void mlir::BuiltinOptions2ToAttributes(
if (const auto* op = op_union.AsStablehloBroadcastInDimOptions()) {
attributes.emplace_back(builder.getNamedAttr(
"broadcast_dimensions",
BuildDenseElementAttr(op->broadcast_dimensions, builder)));
BuildRankedTensorAttr(
{static_cast<int64_t>(op->broadcast_dimensions.size())},
op->broadcast_dimensions, builder)));
return;
}
if (const auto* op = op_union.AsStablehloSliceOptions()) {
std::vector<int64_t> shape = {
static_cast<int64_t>(op->start_indices.size())};
attributes.emplace_back(builder.getNamedAttr(
"start_indices", BuildDenseElementAttr(op->start_indices, builder)));
"start_indices",
BuildRankedTensorAttr(shape, op->start_indices, builder)));
attributes.emplace_back(builder.getNamedAttr(
"limit_indices", BuildDenseElementAttr(op->limit_indices, builder)));
"limit_indices",
BuildRankedTensorAttr(shape, op->limit_indices, builder)));
attributes.emplace_back(builder.getNamedAttr(
"strides", BuildDenseElementAttr(op->strides, builder)));
"strides", BuildRankedTensorAttr(shape, op->strides, builder)));
return;
}
if (const auto* op = op_union.AsStablehloConvolutionOptions()) {
Expand Down Expand Up @@ -440,7 +441,9 @@ void mlir::BuiltinOptions2ToAttributes(
if (!(op->window_reversal.empty()))
attributes.emplace_back(builder.getNamedAttr(
"window_reversal",
BuildDenseElementAttr(op->window_reversal, builder)));
BuildRankedTensorAttr(
{static_cast<int64_t>(op->window_reversal.size())},
op->window_reversal, builder)));
attributes.emplace_back(builder.getNamedAttr(
"dimension_numbers",
mlir::stablehlo::ConvDimensionNumbersAttr::get(
Expand Down Expand Up @@ -490,7 +493,9 @@ void mlir::BuiltinOptions2ToAttributes(
}
if (const auto* op = op_union.AsStablehloReduceOptions()) {
attributes.emplace_back(builder.getNamedAttr(
"dimensions", BuildDenseElementAttr(op->dimensions, builder)));
"dimensions",
BuildRankedTensorAttr({static_cast<int64_t>(op->dimensions.size())},
op->dimensions, builder)));
return;
}
}
Expand Down

0 comments on commit e46b690

Please sign in to comment.