We need fast linear interpolation in arbitrary dimensions, compatible with `jax.jit`.
Suppose we have a function $f : \mathbb{R}^d \to \mathbb{R}$. Let the values of $f$ be known on a regular Cartesian grid in $\mathbb{R}^d$. We now want to be able to evaluate an approximation of $f(x)$ at an arbitrary point $x$ by linear interpolation of the grid values. We want to do this where the dimension $d$ is arbitrary. We also want to be able to do this in a vectorized, parallelized manner, so that we can pass a large number of evaluation points $x_1, \ldots, x_m$ at once.
This is already possible with Numba, using the EconForge `interpolation.py` library:
```python
import numpy as np
from interpolation.splines import UCGrid, nodes, eval_linear

# function to interpolate
f = lambda x, y: np.sin(np.sqrt(x**2 + y**2 + 0.00001)) / np.sqrt(x**2 + y**2 + 0.00001)

# uniform Cartesian grid with a small number of points
grid = UCGrid((-1.0, 1.0, 10), (-1.0, 1.0, 10))

# get grid points
gp = nodes(grid)   # 100x2 matrix

# compute values on grid points
values = f(gp[:, 0], gp[:, 1]).reshape((10, 10))

# interpolate at one point
point = np.array([0.1, 0.45])            # 1d array
val = eval_linear(grid, values, point)   # float

# interpolate at many points
points = np.random.random((10000, 2))
eval_linear(grid, values, points)        # vector of length 10000
```
The question is, what is the most efficient way to replicate this in JAX?
Using `jax.scipy.ndimage.map_coordinates`, @junnanZ has implemented evaluation of the interpolant at a single point $x$. We could then extend this from one point $x$ to many points $x_1, \ldots, x_m$ via `vmap`.
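To make this concrete, here is a minimal sketch of that approach, assuming a uniform Cartesian grid; the names `interp_point` / `interp_points` and the `(grid_mins, grid_maxs, values, point)` signature are illustrative choices, not @junnanZ's actual code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def interp_point(grid_mins, grid_maxs, values, point):
    """Multilinear interpolation of `values` (on a uniform grid) at one point."""
    # Convert physical coordinates to fractional index coordinates:
    # index i corresponds to grid_mins + i * (grid_maxs - grid_mins) / (n - 1).
    coords = tuple(
        (point[i] - grid_mins[i]) / (grid_maxs[i] - grid_mins[i]) * (values.shape[i] - 1)
        for i in range(values.ndim)
    )
    # order=1 gives (multi)linear interpolation.
    return map_coordinates(values, coords, order=1)


# Vectorize over a batch of evaluation points, then jit the whole thing.
interp_points = jax.jit(jax.vmap(interp_point, in_axes=(None, None, None, 0)))
```

Called with the grid bounds, the `(10, 10)` array of grid values, and an `(n, 2)` array of evaluation points, `interp_points` should return an `n`-vector of interpolated values, mirroring `eval_linear` above.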
This discussion suggests that the vmap approach will be efficient: jax-ml/jax#6312
Another alternative is TensorFlow Probability's `batch_interp_regular_nd_grid` in its JAX substrate: https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/math/batch_interp_regular_nd_grid
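For reference, here is a rough sketch of what the TFP call might look like for the 2-D example above; the argument names follow the TFP documentation, and `values` is the grid-value array from the Numba snippet (I have not benchmarked or verified this against that snippet).

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

y_ref = jnp.asarray(values)         # (10, 10) grid values from the snippet above
x = jnp.array([[0.1, 0.45]])        # shape (num_points, 2)

vals = tfp.math.batch_interp_regular_nd_grid(
    x=x,
    x_ref_min=jnp.array([-1.0, -1.0]),
    x_ref_max=jnp.array([1.0, 1.0]),
    y_ref=y_ref,
    axis=-2,   # the last two dimensions of y_ref index the 2-D grid
)
```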
Are these our best options? If so, which is best?
Once we have decided, let's
- test to make sure that the output agrees with the output from the Python code above (a rough sketch of such a test follows this list), and
- provide a nice interface, like the one in the Python code above.
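For the first item, here is a rough sketch of such an agreement test; it reuses `grid`, `values`, and `eval_linear` from the Numba snippet and the hypothetical `interp_points` from the `map_coordinates` sketch above.

```python
import numpy as np
import jax.numpy as jnp

# Draw test points strictly inside the grid.
rng = np.random.default_rng(1234)
points = rng.uniform(-1.0, 1.0, size=(10_000, 2))

expected = eval_linear(grid, values, points)                      # Numba reference
actual = interp_points(jnp.array([-1.0, -1.0]), jnp.array([1.0, 1.0]),
                       jnp.asarray(values), jnp.asarray(points))  # JAX sketch

np.testing.assert_allclose(np.asarray(actual), expected, rtol=1e-5, atol=1e-6)
```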