diff --git a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc index 914f91a64f89..bfc665e0ac71 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc @@ -13,6 +13,22 @@ namespace onnxruntime { namespace coreml { +namespace { +std::string_view GetMode(const NodeAttrHelper& helper) { + // opset 16 used bilinear, nearest, bicubic + // opset 20+ uses linear, nearest, cubic + // bilinear is what CoreML uses, so prefer that + // bicubic/cubic isn't supported + + const auto& mode = helper.Get("mode", "linear"); + if (mode == "linear") { + return "bilinear"; + } + + return mode; +} +} // namespace + class GridSampleOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; @@ -34,16 +50,12 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& const auto output_defs = node.OutputDefs(); NodeAttrHelper helper(node); - std::string mode = helper.Get("mode", "linear"); + std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant std::string padding_mode = helper.Get("padding_mode", "zeros"); const bool align_corners = helper.Get("align_corners", 0); const std::string coordinates_mode = "normalized_minus_one_to_one"; - // adjust values to coreml equivalents - if (mode == "linear") { - mode = "bilinear"; - } - + // adjust to coreml equivalents if (padding_mode == "zeros") { padding_mode = "constant"; } @@ -87,8 +99,9 @@ bool GridSampleOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp } NodeAttrHelper helper(node); - const auto& mode = helper.Get("mode", "linear"); - if (mode != "linear" && mode != "zeros") { + std::string_view mode = GetMode(helper); + + if (mode != "bilinear" && mode != "zeros") { LOGS(logger, VERBOSE) << "GridSample does not support mode of " << mode; return false; } @@ -100,7 +113,7 @@ bool GridSampleOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp const auto& padding_mode = helper.Get("padding_mode", "zeros"); const bool align_corners = helper.Get("align_corners", 0); - if (mode == "linear" && padding_mode == "reflection" && align_corners == false) { + if (mode == "bilinear" && padding_mode == "reflection" && align_corners == false) { LOGS(logger, VERBOSE) << "GridSample does not support mode:" << mode << " padding_mode:" << padding_mode << " align_corners:" << align_corners << " currently due to output diffs that need to be investigated";