Skip to content

Commit

Permalink
Add GridSample test with data for failing unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay committed Jul 20, 2024
1 parent 5d2b2ca commit 94f1c48
Showing 1 changed file with 55 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.
import coremltools as ct
from coremltools.converters.mil import Builder as mb

Check warning

Code scanning / lintrunner

RUFF/N813 Warning

Camelcase Builder imported as lowercase mb.
See https://docs.astral.sh/ruff/rules/camelcase-imported-as-lowercase

target = ct.target.iOS15

x_shape = (2, 2, 3, 2)
grid_shape = (2, 3, 2, 2)

@mb.program(input_specs=[mb.TensorSpec(shape=x_shape),
mb.TensorSpec(shape=grid_shape)],
opset_version=target)
def prog(x, grid):
sampling = mb.const(name="sampling_mode", val="bilinear")
padding_mode = mb.const(name="pmode", val="reflection")
pad = mb.const(name="pval", val=np.float32(0))
coord_mode = mb.const(name="coord_mode", val="normalized_minus_one_to_one")
align_corners = mb.const(name="align_corners", val=False)
z = mb.resample(x=x, coordinates=grid, sampling_mode=sampling,
padding_mode=padding_mode, padding_value=pad, coordinates_mode=coord_mode,
align_corners=align_corners)

return z

# print(prog)

# Convert to ML program
m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32)

# spec = m.get_spec()
# print(spec)

m.save("GridSample.mlpackage")
# construct MLModel with compute_units=ComputeUnit.CPU and run predict
m_cpu = ct.models.MLModel('GridSample.mlpackage', compute_units=ct.ComputeUnit.CPU_ONLY)
m_all = ct.models.MLModel('GridSample.mlpackage', compute_units=ct.ComputeUnit.ALL)

# GridSampleTest.test_grid_sample_20_4D_bilinear_reflection_no_align_corners
# ORT produces different output for this test. ORT output is generated by pytorch
x = np.array([-0.173652, -1.513725, -0.704586, -1.952375, -0.699404, -0.806298,
1.640852, -0.138969, -0.695411, -1.352111, 0.568797, -0.564294,
-0.056468, 0.641604, -0.438370, 0.450167, -1.091401, 1.669729,
-0.908544, 0.244467, 0.172109, 1.156741, -0.617128, 1.155460]
).astype(np.float32).reshape(x_shape)

grid = np.array([
0.252250, -0.151452, 0.824706, -0.588292, -0.591147, -0.155082,
-0.732938, 0.457493, -0.439559, 0.492330, 0.696447, 0.700722,
-0.220298, 0.654884, -0.635434, -1.195619, -0.114204, -0.870080,
-0.929674, 0.305035, 1.025429, -0.472240, -0.067881, -0.869393]
).astype(np.float32).reshape(grid_shape)


print(m_cpu.predict({'x': x, 'grid': grid}))
print(m_all.predict({'x': x, 'grid': grid}))

0 comments on commit 94f1c48

Please sign in to comment.