Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CoreML: Add GridSample ML Program support #21431

Merged
merged 4 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/providers/shared/utils/utils.h"

namespace onnxruntime {
namespace coreml {

namespace {
std::string_view GetMode(const NodeAttrHelper& helper) {
// opset 16 used bilinear, nearest, bicubic
// opset 20+ uses linear, nearest, cubic
// bilinear is what CoreML uses, so prefer that
// bicubic/cubic isn't supported

const auto& mode = helper.Get("mode", "linear");
if (mode == "linear") {
return "bilinear";
}

return mode;
skottmckay marked this conversation as resolved.
Show resolved Hide resolved
}
} // namespace

class GridSampleOpBuilder : public BaseOpBuilder {
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override;

bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;

bool SupportsMLProgram() const override { return true; }
};

Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder,
[[maybe_unused]] const Node& node,
[[maybe_unused]] const logging::Logger& logger) const {
#if defined(COREML_ENABLE_MLPROGRAM)
using namespace CoreML::Specification::MILSpec; // NOLINT
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resample

const auto input_defs = node.InputDefs();
const auto output_defs = node.OutputDefs();

NodeAttrHelper helper(node);
std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant
std::string padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);
const std::string coordinates_mode = "normalized_minus_one_to_one";

// adjust to coreml equivalents
if (padding_mode == "zeros") {
padding_mode = "constant";
}

auto op = model_builder.CreateOperation(node, "resample");
AddOperationInput(*op, "x", input_defs[0]->Name());
AddOperationInput(*op, "coordinates", input_defs[1]->Name());
AddOperationInput(*op, "sampling_mode", model_builder.AddScalarConstant(op->type(), "sampling_mode", mode));
AddOperationInput(*op, "padding_mode", model_builder.AddScalarConstant(op->type(), "padding_mode", padding_mode));
AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f));
AddOperationInput(*op, "coordinates_mode",
model_builder.AddScalarConstant(op->type(), "coordinates_mode", coordinates_mode));
AddOperationInput(*op, "align_corners", model_builder.AddScalarConstant(op->type(), "align_corners", align_corners));

AddOperationOutput(*op, *output_defs[0]);

model_builder.AddOperation(std::move(op));
#endif
return Status::OK();
}

bool GridSampleOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
if (!input_params.create_mlprogram) {
LOGS(logger, VERBOSE) << "GridSample is not supported.";
return false;
}

const auto& input_defs = node.InputDefs();

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
LOGS(logger, VERBOSE) << "GridSample: failed to get input shape";
return false;
}

const auto input_rank = input_shape.size();
if (input_rank != 4) {
LOGS(logger, VERBOSE) << "GridSample only supports 4D input. Got:" << input_rank << "D";
return false;
}

NodeAttrHelper helper(node);
std::string_view mode = GetMode(helper);

if (mode != "bilinear" && mode != "zeros") {
LOGS(logger, VERBOSE) << "GridSample does not support mode of " << mode;
return false;
}

// there is one combination of settings where the unit test fails.
// The ORT unit test values are generated by pytorch so not clear if it's an issue with CoreML.
// CoreML output is consistent for CPU and non-CPU at least.
// Disabling until there's a use-case that requires this combination.
const auto& padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);

if (mode == "bilinear" && padding_mode == "reflection" && align_corners == false) {
LOGS(logger, VERBOSE) << "GridSample does not support mode:" << mode << " padding_mode:" << padding_mode
<< " align_corners:" << align_corners
<< " currently due to output diffs that need to be investigated";
return false;
}

return true;
}

void CreateGridSampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<GridSampleOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace coreml
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateSplitOpBuilder("Split", op_registrations);
}

CreateGridSampleOpBuilder("GridSample", op_registrations);

return op_registrations;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrati
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGridSampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down
Loading
Loading