Skip to content

Commit

Permalink
Realize utils.interp2d_like() using OpenCV interpolation methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Pechnikov committed Sep 15, 2024
1 parent b45b73b commit 37859e5
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 27 deletions.
160 changes: 133 additions & 27 deletions pygmtsar/pygmtsar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,90 @@ class utils():
# this custom function handles the task more effectively.
@staticmethod
def interp2d_like(data, grid, method='cubic', **kwargs):
"""
Efficiently interpolate a 2D array using OpenCV interpolation methods.
Args:
data (xarray.DataArray): The input data array.
grid (xarray.DataArray): The grid to interpolate onto.
method (str): Interpolation method ('nearest', 'linear', 'cubic' or 'lanczos').
**kwargs: Additional arguments for interpolation.
Returns:
xarray.DataArray: The interpolated data.
"""
import cv2
import numpy as np
import xarray as xr
import dask.array as da
import os
import warnings
# suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', module='dask')
warnings.filterwarnings('ignore', module='dask.core')

# detect dimensions and coordinates for 2D or 3D grid
dims = grid.dims[-2:]
dim1, dim2 = dims
coords = {dim1: grid[dim1], dim2: grid[dim2]}
#print (f'dims: {dims}, coords: {coords}')
#print ('coords', coords)

# Define interpolation method
if method == 'nearest':
interpolation = cv2.INTER_NEAREST
elif method == 'linear':
interpolation = cv2.INTER_LINEAR
elif method == 'cubic':
interpolation = cv2.INTER_CUBIC
elif method == 'lanczos':
interpolation = cv2.INTER_LANCZOS4
else:
raise ValueError(f"Unsupported interpolation {method}. Should be 'nearest', 'linear', 'cubic' or 'lanczos'")

# use outer variable data
def interpolate_chunk(out_chunk1, out_chunk2, dim1, dim2, method, **kwargs):
d1, d2 = float(data[dim1].diff(dim1)[0]), float(data[dim2].diff(dim2)[0])
#print ('d1, d2', d1, d2)
# TBD: can be added to the function parameters
borderMode = cv2.BORDER_REFLECT

# define interpolation function using outer variable data
def interpolate_chunk(out_chunk1, out_chunk2, dim1, dim2, interpolation, borderMode, **kwargs):
d1 = float(data[dim1].diff(dim1)[0])
d2 = float(data[dim2].diff(dim2)[0])

# select the chunk from data with some padding
chunk = data.sel({
dim1: slice(out_chunk1[0]-2*d1, out_chunk1[-1]+2*d1),
dim2: slice(out_chunk2[0]-2*d2, out_chunk2[-1]+2*d2)
}).compute(n_workers=1)
#print ('chunk', chunk)
out = chunk.interp({dim1: out_chunk1, dim2: out_chunk2}, method=method, **kwargs)
del chunk
return out

chunk_sizes = grid.chunks[-2:] if hasattr(grid, 'chunks') else (self.chunksize, self.chunksize)
# coordinates are numpy arrays
dim1: slice(out_chunk1[0] - 3 * d1, out_chunk1[-1] + 3 * d1),
dim2: slice(out_chunk2[0] - 3 * d2, out_chunk2[-1] + 3 * d2)
}).compute(n_workers=1)

# Create grid for interpolation
dst_grid_x, dst_grid_y = np.meshgrid(out_chunk2, out_chunk1)

# map destination grid coordinates to source pixel indices
src_x_coords = np.interp(
dst_grid_x.ravel(),
chunk[dim2].values,
np.arange(len(chunk[dim2]))
)
src_y_coords = np.interp(
dst_grid_y.ravel(),
chunk[dim1].values,
np.arange(len(chunk[dim1]))
)

# reshape the coordinates for remap
src_x_coords = src_x_coords.reshape(dst_grid_x.shape).astype(np.float32)
src_y_coords = src_y_coords.reshape(dst_grid_y.shape).astype(np.float32)

# interpolate using OpenCV
dst_grid = cv2.remap(
chunk.values.astype(np.float32),
src_x_coords,
src_y_coords,
interpolation=interpolation,
borderMode=borderMode
)
return dst_grid

# define chunk sizes
chunk_sizes = grid.chunks[-2:] if hasattr(grid, 'chunks') else (data.sizes[dim1], data.sizes[dim2])

# create dask array for parallel processing
grid_y = da.from_array(grid[dim1].values, chunks=chunk_sizes[0])
grid_x = da.from_array(grid[dim2].values, chunks=chunk_sizes[1])


# Perform interpolation
dask_out = da.blockwise(
interpolate_chunk,
'yx',
Expand All @@ -78,14 +129,69 @@ def interpolate_chunk(out_chunk1, out_chunk2, dim1, dim2, method, **kwargs):
dtype=data.dtype,
dim1=dim1,
dim2=dim2,
method=method,
interpolation=interpolation,
borderMode=borderMode,
**kwargs
)

da_out = xr.DataArray(dask_out, coords=coords, dims=dims).rename(data.name)
del dask_out
# append all the input coordinates

# Append all the input coordinates
return da_out.assign_coords({k: v for k, v in data.coords.items() if k not in coords})

# # Xarray's interpolation can be inefficient for large grids;
# # this custom function handles the task more effectively.
# @staticmethod
# def interp2d_like(data, grid, method='cubic', **kwargs):
# import xarray as xr
# import dask.array as da
# import os
# import warnings
# # suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
# warnings.filterwarnings('ignore')
# warnings.filterwarnings('ignore', module='dask')
# warnings.filterwarnings('ignore', module='dask.core')
#
# # detect dimensions and coordinates for 2D or 3D grid
# dims = grid.dims[-2:]
# dim1, dim2 = dims
# coords = {dim1: grid[dim1], dim2: grid[dim2]}
# #print (f'dims: {dims}, coords: {coords}')
#
# # use outer variable data
# def interpolate_chunk(out_chunk1, out_chunk2, dim1, dim2, method, **kwargs):
# d1, d2 = float(data[dim1].diff(dim1)[0]), float(data[dim2].diff(dim2)[0])
# #print ('d1, d2', d1, d2)
# chunk = data.sel({
# dim1: slice(out_chunk1[0]-2*d1, out_chunk1[-1]+2*d1),
# dim2: slice(out_chunk2[0]-2*d2, out_chunk2[-1]+2*d2)
# }).compute(n_workers=1)
# #print ('chunk', chunk)
# out = chunk.interp({dim1: out_chunk1, dim2: out_chunk2}, method=method, **kwargs)
# del chunk
# return out
#
# chunk_sizes = grid.chunks[-2:] if hasattr(grid, 'chunks') else (self.chunksize, self.chunksize)
# # coordinates are numpy arrays
# grid_y = da.from_array(grid[dim1].values, chunks=chunk_sizes[0])
# grid_x = da.from_array(grid[dim2].values, chunks=chunk_sizes[1])
#
# dask_out = da.blockwise(
# interpolate_chunk,
# 'yx',
# grid_y, 'y',
# grid_x, 'x',
# dtype=data.dtype,
# dim1=dim1,
# dim2=dim2,
# method=method,
# **kwargs
# )
# da_out = xr.DataArray(dask_out, coords=coords, dims=dims).rename(data.name)
# del dask_out
# # append all the input coordinates
# return da_out.assign_coords({k: v for k, v in data.coords.items() if k not in coords})

@staticmethod
def nanconvolve2d_gaussian(data,
weight=None,
Expand Down
1 change: 1 addition & 0 deletions pygmtsar/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_version():
'geopandas',
'distributed>=2024.1.0',
'dask[complete]>=2024.4.1',
'opencv-python',
'joblib',
'tqdm',
'scipy',
Expand Down

0 comments on commit 37859e5

Please sign in to comment.