Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay committed Jul 22, 2024
2 parents e09a5a9 + 8890dfb commit 2fcce7a
Showing 1 changed file with 22 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@
namespace onnxruntime {
namespace coreml {

namespace {
std::string_view GetMode(const NodeAttrHelper& helper) {
// opset 16 used bilinear, nearest, bicubic
// opset 20+ uses linear, nearest, cubic
// bilinear is what CoreML uses, so prefer that
// bicubic/cubic isn't supported

const auto& mode = helper.Get("mode", "linear");
if (mode == "linear") {
return "bilinear";
}

return mode;
}
} // namespace

class GridSampleOpBuilder : public BaseOpBuilder {
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override;
Expand All @@ -34,16 +50,12 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder&
const auto output_defs = node.OutputDefs();

NodeAttrHelper helper(node);
std::string mode = helper.Get("mode", "linear");
std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant
std::string padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);
const std::string coordinates_mode = "normalized_minus_one_to_one";

// adjust values to coreml equivalents
if (mode == "linear") {
mode = "bilinear";
}

// adjust to coreml equivalents
if (padding_mode == "zeros") {
padding_mode = "constant";
}
Expand Down Expand Up @@ -87,8 +99,9 @@ bool GridSampleOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
}

NodeAttrHelper helper(node);
const auto& mode = helper.Get("mode", "linear");
if (mode != "linear" && mode != "zeros") {
std::string_view mode = GetMode(helper);

if (mode != "bilinear" && mode != "zeros") {
LOGS(logger, VERBOSE) << "GridSample does not support mode of " << mode;
return false;
}
Expand All @@ -100,7 +113,7 @@ bool GridSampleOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
const auto& padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);

if (mode == "linear" && padding_mode == "reflection" && align_corners == false) {
if (mode == "bilinear" && padding_mode == "reflection" && align_corners == false) {
LOGS(logger, VERBOSE) << "GridSample does not support mode:" << mode << " padding_mode:" << padding_mode
<< " align_corners:" << align_corners
<< " currently due to output diffs that need to be investigated";
Expand Down

0 comments on commit 2fcce7a

Please sign in to comment.