Skip to content

Commit

Permalink
Refactor multilooking function. Remove dependency to dask-image libra…
Browse files Browse the repository at this point in the history
…ry used in the function before.
  • Loading branch information
Alexey Pechnikov committed Aug 26, 2024
1 parent 7303439 commit a47826d
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 83 deletions.
251 changes: 169 additions & 82 deletions pygmtsar/pygmtsar/Stack_multilooking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Licensed under the BSD 3-Clause License (see LICENSE for details)
# ----------------------------------------------------------------------------
from .Stack_phasediff import Stack_phasediff
from .utils import utils

class Stack_multilooking(Stack_phasediff):

Expand Down Expand Up @@ -92,65 +93,167 @@ def decimator(da):

# return callback function and set common chunk size
return lambda da: decimator(da)

# # coarsen = None disables downscaling and uses wavelength to filter
# # coarsen=1 disables downscaling and use coarsen/cutoff filter
# def multilooking(self, data, weight=None, wavelength=None, coarsen=None, debug=False):
# import xarray as xr
# import numpy as np
# import dask
# #from dask_image.ndfilters import gaussian_filter as dask_gaussian_filter
#
# # GMTSAR constant 5.3 defines half-gain at filter_wavelength
# # https://github.com/gmtsar/gmtsar/issues/411
# cutoff = 5.3
#
# # expand simplified definition
# if coarsen is not None and not isinstance(coarsen, (list, tuple, np.ndarray)):
# coarsen = (coarsen, coarsen)
#
# # allow this case to save the original grid resolution
# if wavelength is None and coarsen is None:
# return data
#
# # antialiasing (multi-looking) filter
# if wavelength is None:
# sigmas = [coarsen[0]/cutoff, coarsen[1]/cutoff]
# if debug:
# print (f'DEBUG: multilooking sigmas ({sigmas[0]:.2f}, {sigmas[1]:.2f}), for specified coarsen {coarsen}')
# else:
# dy, dx = self.get_spacing(data)
# #print ('DEBUG dy, dx', dy, dx)
# #sigmas = int(np.round(wavelength/dy/coarsen[0])), int(np.round(wavelength/dx))
# sigmas = [wavelength/cutoff/dy, wavelength/cutoff/dx]
# if debug:
# print (f'DEBUG: multilooking sigmas ({sigmas[0]:.2f}, {sigmas[1]:.2f}), for specified wavelength {wavelength:.1f}')
#
# # weighted and not weighted convolution on float and complex float data
# def apply_filter(data, weight, sigmas, truncate=2):
# if np.issubdtype(data.dtype, np.complexfloating):
# #print ('complexfloating')
# parts = []
# for part in [data.real, data.imag]:
# data_complex = ((1j + part) * (weight if weight is not None else 1)).fillna(0)
# conv_complex = dask_gaussian_filter(data_complex.data, sigmas, mode='reflect', truncate=truncate)
# #conv = conv_complex.real/conv_complex.imag
# # to prevent "RuntimeWarning: invalid value encountered in divide" even when warning filter is defined
# conv = dask.array.where(conv_complex.imag == 0, np.nan, conv_complex.real/(conv_complex.imag + 1e-17))
# del data_complex, conv_complex
# parts.append(conv)
# del conv
# conv = parts[0] + 1j*parts[1]
# del parts
# else:
# #print ('floating')
# # replace nan + 1j to to 0.+0.j
# data_complex = ((1j + data) * (weight if weight is not None else 1)).fillna(0)
# conv_complex = dask_gaussian_filter(data_complex.data, sigmas, mode='reflect', truncate=truncate)
# #conv = conv_complex.real/conv_complex.imag
# # to prevent "RuntimeWarning: invalid value encountered in divide" even when warning filter is defined
# conv = dask.array.where(conv_complex.imag == 0, np.nan, conv_complex.real/(conv_complex.imag + 1e-17))
# del data_complex, conv_complex
# return conv
#
# if isinstance(data, xr.Dataset):
# dims = data[list(data.data_vars)[0]].dims
# else:
# dims = data.dims
#
# if len(dims) == 2:
# stackvar = None
# else:
# stackvar = dims[0]
# #print ('stackvar', stackvar)
#
# if weight is not None:
# # for InSAR processing expect 2D weights
# assert isinstance(weight, xr.DataArray) and len(weight.dims)==2, \
# 'ERROR: multilooking weight should be 2D DataArray'
#
# if weight is not None and len(data.dims) == len(weight.dims):
# #print ('2D check shape weighted')
# # single 2D grid processing
# if isinstance(data, xr.Dataset):
# for varname in data.data_vars:
# assert data[varname].shape == weight.shape, \
# f'ERROR: multilooking data[{varname}] and weight variables have different shape'
# else:
# assert data.shape == weight.shape, 'ERROR: multilooking data and weight variables have different shape'
# elif weight is not None and len(data.dims) == len(weight.dims) + 1:
# #print ('3D check shape weighted')
# # stack of 2D grids processing
# if isinstance(data, xr.Dataset):
# for varname in data.data_vars:
# assert data[varname].shape[1:] == weight.shape, \
# f'ERROR: multilooking data[{varname}] slice and weight variables have different shape'
# else:
# assert data.shape[1:] == weight.shape, 'ERROR: multilooking data slice and weight variables have different shape'
#
# stack =[]
# for ind in range(len(data[stackvar]) if stackvar is not None else 1):
# if isinstance(data, xr.Dataset):
# #print (f'Dataset ind:{ind}')
# data_convs = []
# for key in data.data_vars:
# conv = utils.nanconvolve2d_gaussian(data[key][ind] if stackvar is not None else data[key],
# weight,
# sigmas)
# data_conv = xr.DataArray(conv, dims=data[key].dims[1:] if stackvar is not None else data[key].dims, name=data[key].name)
# del conv
# data_convs.append(data_conv)
# del data_conv
# stack.append(xr.merge(data_convs))
# del data_convs
# else:
# #print (f'DataArray ind:{ind}')
# conv = utils.nanconvolve2d_gaussian(data[ind] if stackvar is not None else data,
# weight,
# sigmas)
# data_conv = xr.DataArray(conv, dims=data.dims[1:] if stackvar is not None else data.dims, name=data.name)
# del conv
# stack.append(data_conv)
# del data_conv
#
# if stackvar is not None:
# #print ('3D')
# ds = xr.concat(stack, dim=stackvar).assign_coords(data.coords)
# else:
# #print ('2D')
# ds = stack[0].assign_coords(data.coords)
# del stack
#
# # it works faster when we prevent small output chunks
# chunksizes = {'y': self.chunksize, 'x': self.chunksize}
# if coarsen is not None:
# # coarse grid typically to square cells
# return ds.coarsen({'y': coarsen[0], 'x': coarsen[1]}, boundary='trim').mean().chunk(chunksizes)
# return ds.chunk(chunksizes)

# coarsen = None disables downscaling and uses wavelength to filter
# coarsen=1 disables downscaling and uses the coarsen/cutoff filter
def multilooking(self, data, weight=None, wavelength=None, coarsen=None, debug=False):
import xarray as xr
import numpy as np
import dask
from dask_image.ndfilters import gaussian_filter as dask_gaussian_filter

# GMTSAR constant 5.3 defines half-gain at filter_wavelength
# https://github.com/gmtsar/gmtsar/issues/411
cutoff = 5.3

# expand simplified definition
if coarsen is not None and not isinstance(coarsen, (list, tuple, np.ndarray)):
coarsen = (coarsen, coarsen)

# allow this case to save the original grid resolution

# Expand simplified definition of coarsen
coarsen = (coarsen, coarsen) if coarsen is not None and not isinstance(coarsen, (list, tuple, np.ndarray)) else coarsen

# no-op: no processing is needed
if wavelength is None and coarsen is None:
return data

# antialiasing (multi-looking) filter
if wavelength is None:
sigmas = [coarsen[0]/cutoff, coarsen[1]/cutoff]

# calculate sigmas based on wavelength or coarsen
if wavelength is not None:
dy, dx = self.get_spacing(data)
sigmas = [wavelength / cutoff / dy, wavelength / cutoff / dx]
if debug:
print (f'DEBUG: multilooking sigmas ({sigmas[0]:.2f}, {sigmas[1]:.2f}), for specified coarsen {coarsen}')
print(f'DEBUG: multilooking sigmas ({sigmas[0]:.2f}, {sigmas[1]:.2f}), wavelength {wavelength:.1f}')
else:
dy, dx = self.get_spacing(data)
#print ('DEBUG dy, dx', dy, dx)
#sigmas = int(np.round(wavelength/dy/coarsen[0])), int(np.round(wavelength/dx))
sigmas = [wavelength/cutoff/dy, wavelength/cutoff/dx]
sigmas = [coarsen[0] / cutoff, coarsen[1] / cutoff]
if debug:
print (f'DEBUG: multilooking sigmas ({sigmas[0]:.2f}, {sigmas[1]:.2f}), for specified wavelength {wavelength:.1f}')

# weighted and not weighted convolution on float and complex float data
def apply_filter(data, weight, sigmas, truncate=2):
if np.issubdtype(data.dtype, np.complexfloating):
#print ('complexfloating')
parts = []
for part in [data.real, data.imag]:
data_complex = ((1j + part) * (weight if weight is not None else 1)).fillna(0)
conv_complex = dask_gaussian_filter(data_complex.data, sigmas, mode='reflect', truncate=truncate)
#conv = conv_complex.real/conv_complex.imag
# to prevent "RuntimeWarning: invalid value encountered in divide" even when warning filter is defined
conv = dask.array.where(conv_complex.imag == 0, np.nan, conv_complex.real/(conv_complex.imag + 1e-17))
del data_complex, conv_complex
parts.append(conv)
del conv
conv = parts[0] + 1j*parts[1]
del parts
else:
#print ('floating')
# replace nan + 1j to to 0.+0.j
data_complex = ((1j + data) * (weight if weight is not None else 1)).fillna(0)
conv_complex = dask_gaussian_filter(data_complex.data, sigmas, mode='reflect', truncate=truncate)
#conv = conv_complex.real/conv_complex.imag
# to prevent "RuntimeWarning: invalid value encountered in divide" even when warning filter is defined
conv = dask.array.where(conv_complex.imag == 0, np.nan, conv_complex.real/(conv_complex.imag + 1e-17))
del data_complex, conv_complex
return conv
print(f'DEBUG: multilooking sigmas ({sigmas[0]:.2f}, {sigmas[1]:.2f}), coarsen {coarsen}')

if isinstance(data, xr.Dataset):
dims = data[list(data.data_vars)[0]].dims
Expand Down Expand Up @@ -187,43 +290,27 @@ def apply_filter(data, weight, sigmas, truncate=2):
else:
assert data.shape[1:] == weight.shape, 'ERROR: multilooking data slice and weight variables have different shape'

stack =[]
for ind in range(len(data[stackvar]) if stackvar is not None else 1):
if isinstance(data, xr.Dataset):
#print (f'Dataset ind:{ind}')
data_convs = []
for key in data.data_vars:
conv = apply_filter(data[key][ind] if stackvar is not None else data[key],
weight,
sigmas)
data_conv = xr.DataArray(conv, dims=data[key].dims[1:] if stackvar is not None else data[key].dims, name=data[key].name)
del conv
data_convs.append(data_conv)
del data_conv
stack.append(xr.merge(data_convs))
del data_convs
# process a slice of dataarray
def process_slice(slice_data):
conv = utils.nanconvolve2d_gaussian(slice_data, weight, sigmas)
return xr.DataArray(conv, dims=slice_data.dims, name=slice_data.name)

# process stack of dataarray slices
def process_slice_var(dataarray):
if stackvar:
stack = [process_slice(dataarray[ind]) for ind in range(len(dataarray[stackvar]))]
return xr.concat(stack, dim=stackvar).assign_coords(dataarray.coords)
else:
#print (f'DataArray ind:{ind}')
conv = apply_filter(data[ind] if stackvar is not None else data,
weight,
sigmas)
data_conv = xr.DataArray(conv, dims=data.dims[1:] if stackvar is not None else data.dims, name=data.name)
del conv
stack.append(data_conv)
del data_conv

if stackvar is not None:
#print ('3D')
ds = xr.concat(stack, dim=stackvar).assign_coords(data.coords)
else:
#print ('2D')
ds = stack[0].assign_coords(data.coords)
del stack
return process_slice(dataarray).assign_coords(dataarray.coords)

# it works faster when we prevent small output chunks
if isinstance(data, xr.Dataset):
ds = xr.Dataset({varname: process_slice_var(data[varname]) for varname in data.data_vars})
else:
ds = process_slice_var(data)

# Set chunk size
chunksizes = {'y': self.chunksize, 'x': self.chunksize}
if coarsen is not None:
# coarse grid typically to square cells
if coarsen:
return ds.coarsen({'y': coarsen[0], 'x': coarsen[1]}, boundary='trim').mean().chunk(chunksizes)

return ds.chunk(chunksizes)

71 changes: 71 additions & 0 deletions pygmtsar/pygmtsar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,77 @@ class utils():
# .predict(np.column_stack([topo_values])).reshape(phase.shape)
# return xr.DataArray(phase_topo, coords=phase.coords)

@staticmethod
def nanconvolve2d_gaussian(data,
                           weight=None,
                           sigma=None,
                           mode='reflect',
                           truncate=4.0):
    """
    Apply a NaN-aware, optionally weighted 2D Gaussian filter to a lazy grid.

    The filter is evaluated chunk-wise with ``dask.array.map_overlap`` and
    ``scipy.ndimage.gaussian_filter``. NaN samples (in the data or in the
    weight) contribute nothing to the result: both the weighted values and the
    weights themselves are filtered, and the output is their ratio. Where the
    filtered weight is exactly zero the output is NaN.

    Parameters
    ----------
    data : xarray.DataArray
        Real or complex floating-point grid; ``data.data`` is handed to
        ``dask.array.map_overlap``.
        NOTE(review): presumably always dask-backed and 2D - confirm with callers.
    weight : xarray.DataArray, optional
        Real floating-point weights, broadcast against ``data``. ``None``
        means unit weights.
    sigma : float or sequence of float, optional
        Gaussian standard deviation(s) in pixels. A scalar is used for both
        axes. ``None`` disables filtering and returns ``data`` itself.
    mode : str, optional
        Boundary mode forwarded to ``scipy.ndimage.gaussian_filter``.
    truncate : float, optional
        Kernel half-width in units of sigma, forwarded to
        ``scipy.ndimage.gaussian_filter``; it also defines the chunk overlap.

    Returns
    -------
    xarray.DataArray
        Filtered grid with the dimensions, coordinates and name of ``data``.
    """
    import numpy as np

    # pass-through: nothing to filter
    if sigma is None:
        return data

    import xarray as xr

    if not isinstance(sigma, (list, tuple, np.ndarray)):
        sigma = (sigma, sigma)
    # Per-axis halo size covering the truncated kernel. This must be a tuple
    # (not a list): dask.array.map_overlap interprets a *list* as one depth
    # specification per input array instead of one value per axis.
    depth = tuple(int(np.ceil(_sigma * truncate)) for _sigma in sigma)

    # weighted Gaussian filtering for real floats with NaNs (one numpy chunk)
    def nanconvolve2d_gaussian_floating_dask_chunk(data, weight=None, **kwargs):
        import numpy as np
        from scipy.ndimage import gaussian_filter
        assert not np.issubdtype(data.dtype, np.complexfloating)
        assert np.issubdtype(data.dtype, np.floating)
        if weight is not None:
            assert not np.issubdtype(weight.dtype, np.complexfloating)
            assert np.issubdtype(weight.dtype, np.floating)
        # pack "value * weight" into the real part and "weight" into the
        # imaginary part so that a single filter pass smooths both
        data_complex = (1j + data) * (weight if weight is not None else 1)
        # zero out invalid samples (NaN -> 0) in place on the temporary array;
        # the keyword makes the fill value explicit (a positional 0 would be
        # taken as the ``copy`` flag)
        conv_complex = gaussian_filter(np.nan_to_num(data_complex, copy=False, nan=0.0), **kwargs)
        # conv = conv_complex.real/conv_complex.imag
        # the small epsilon prevents "RuntimeWarning: invalid value encountered
        # in divide" even when a warning filter is defined; zero-weight pixels
        # are replaced by NaN explicitly
        conv = np.where(conv_complex.imag == 0, np.nan, conv_complex.real/(conv_complex.imag + 1e-17))
        del data_complex, conv_complex
        return conv

    # dispatch one chunk: complex input is filtered as two real grids
    def nanconvolve2d_gaussian_dask_chunk(data, weight=None, **kwargs):
        import numpy as np
        if np.issubdtype(data.dtype, np.complexfloating):
            real = nanconvolve2d_gaussian_floating_dask_chunk(data.real, weight, **kwargs)
            imag = nanconvolve2d_gaussian_floating_dask_chunk(data.imag, weight, **kwargs)
            conv = real + 1j*imag
            del real, imag
        else:
            conv = nanconvolve2d_gaussian_floating_dask_chunk(data.real, weight, **kwargs)
        return conv

    # weighted Gaussian filtering for real or complex floats (lazy arrays)
    def nanconvolve2d_gaussian_dask(data, weight, **kwargs):
        import dask.array as da
        # broadcast so that data and weight are passed to the chunk function
        # as arrays of the same shape, each extended by the same halo
        arrays = da.broadcast_arrays(data, weight) if weight is not None else [data]
        return da.map_overlap(
            nanconvolve2d_gaussian_dask_chunk,
            *arrays,
            depth=depth,
            boundary='none',
            dtype=data.dtype,
            meta=data._meta,
            **kwargs
        )

    filtered = nanconvolve2d_gaussian_dask(data.data,
                                           weight.data if weight is not None else None,
                                           sigma=sigma,
                                           mode=mode,
                                           truncate=truncate)
    # pass dims explicitly: inferring them from coords fails or mislabels axes
    # when the input carries extra (e.g. scalar) coordinates
    return xr.DataArray(filtered,
                        coords=data.coords,
                        dims=data.dims,
                        name=data.name)

@staticmethod
def histogram(data, bins, range):
"""
Expand Down
1 change: 0 additions & 1 deletion pygmtsar/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def get_version():
'geopandas',
'distributed>=2024.1.0',
'dask[complete]>=2024.4.1',
'dask-image',
'joblib',
'tqdm',
'scipy',
Expand Down

0 comments on commit a47826d

Please sign in to comment.