Skip to content

Commit

Permalink
Add function topo_phase()
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Pechnikov committed Jan 27, 2024
1 parent 8927ba9 commit 163f040
Showing 1 changed file with 221 additions and 0 deletions.
221 changes: 221 additions & 0 deletions pygmtsar/pygmtsar/Stack_topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,224 @@ def plot_topo(self, topo='auto', caption='Topography on WGS84 ellipsoid, [m]',
plt.ylabel('Azimuth')
plt.title(caption)

def topo_phase(self, pairs, topo='auto', debug=False):
"""
decimator = sbas.decimator(resolution=15, grid=(1,1))
topophase = sbas.topo_phase(pairs, topo=decimator(sbas.get_topo()))
"""
import pandas as pd
import dask
import dask.dataframe
import xarray as xr
import numpy as np
#from tqdm.auto import tqdm
import joblib
import warnings
# suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', module='dask')
warnings.filterwarnings('ignore', module='dask.core')

if debug:
print ('DEBUG: topo_phase')

# convert pairs (list, array, dataframe) to 2D numpy array
pairs, dates = self.get_pairs(pairs, dates=True)
pairs = pairs[['ref', 'rep']].astype(str).values

if isinstance(topo, str) and topo == 'auto':
topo = self.get_topo()

# calculate the combined earth curvature and topography correction
def calc_drho(rho, topo, earth_radius, height, b, alpha, Bx):
sina = np.sin(alpha)
cosa = np.cos(alpha)
c = earth_radius + height
# compute the look angle using equation (C26) in Appendix C
# GMTSAR uses long double here
#ret = earth_radius + topo.astype(np.longdouble)
ret = earth_radius + topo
cost = ((rho**2 + c**2 - ret**2) / (2. * rho * c))
#if (cost >= 1.)
# die("calc_drho", "cost >= 0");
sint = np.sqrt(1. - cost**2)
# Compute the offset effect from non-parallel orbit
term1 = rho**2 + b**2 - 2 * rho * b * (sint * cosa - cost * sina) - Bx**2
drho = -rho + np.sqrt(term1)
del term1, sint, cost, ret, c, cosa, sina
return drho

def block_phase(prm1, prm2, ylim, xlim):
# use outer variables date, stack_prm
# disable "distributed.utils_perf - WARNING - full garbage collections ..."
from dask.distributed import utils_perf
utils_perf.disable_gc_diagnosis()
import warnings
# suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', module='dask')
warnings.filterwarnings('ignore', module='dask.core')

# for lazy Dask ddataframes
#prm1 = PRM(prm1)
#prm2 = PRM(prm2)
#prm1, prm2 = stack_prm[stack_idx]
#data1, data2 = stack_data[stack_idx]

# use outer variables topo, data1, data2, prm1, prm2
# build topo block
block_topo = topo.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1]))\
.compute(n_workers=1)\
.fillna(0)

# set dimensions
xdim = prm1.get('num_rng_bins')
ydim = prm1.get('num_patches') * prm1.get('num_valid_az')

# set heights
htc = prm1.get('SC_height')
ht0 = prm1.get('SC_height_start')
htf = prm1.get('SC_height_end')

# compute the time span and the time spacing
tspan = 86400 * abs(prm2.get('SC_clock_stop') - prm2.get('SC_clock_start'))
assert (tspan >= 0.01) and (prm2.get('PRF') >= 0.01), 'Check sc_clock_start, sc_clock_end, or PRF'

from scipy import constants
# setup the default parameters
# constant from GMTSAR code for consistency
#SOL = 299792456.0
drange = constants.speed_of_light / (2 * prm2.get('rng_samp_rate'))
#drange = SOL / (2 * prm2.get('rng_samp_rate'))
alpha = prm2.get('alpha_start') * np.pi / 180
cnst = -4 * np.pi / prm2.get('radar_wavelength')

# calculate initial baselines
Bh0 = prm2.get('baseline_start') * np.cos(prm2.get('alpha_start') * np.pi / 180)
Bv0 = prm2.get('baseline_start') * np.sin(prm2.get('alpha_start') * np.pi / 180)
Bhf = prm2.get('baseline_end') * np.cos(prm2.get('alpha_end') * np.pi / 180)
Bvf = prm2.get('baseline_end') * np.sin(prm2.get('alpha_end') * np.pi / 180)
Bx0 = prm2.get('B_offset_start')
Bxf = prm2.get('B_offset_end')

# first case is quadratic baseline model, second case is default linear model
if prm2.get('baseline_center') != 0 or prm2.get('alpha_center') != 0 or prm2.get('B_offset_center') != 0:
Bhc = prm2.get('baseline_center') * np.cos(prm2.get('alpha_center') * np.pi / 180)
Bvc = prm2.get('baseline_center') * np.sin(prm2.get('alpha_center') * np.pi / 180)
Bxc = prm2.get('B_offset_center')

dBh = (-3 * Bh0 + 4 * Bhc - Bhf) / tspan
dBv = (-3 * Bv0 + 4 * Bvc - Bvf) / tspan
ddBh = (2 * Bh0 - 4 * Bhc + 2 * Bhf) / (tspan * tspan)
ddBv = (2 * Bv0 - 4 * Bvc + 2 * Bvf) / (tspan * tspan)

dBx = (-3 * Bx0 + 4 * Bxc - Bxf) / tspan
ddBx = (2 * Bx0 - 4 * Bxc + 2 * Bxf) / (tspan * tspan)
else:
dBh = (Bhf - Bh0) / tspan
dBv = (Bvf - Bv0) / tspan
dBx = (Bxf - Bx0) / tspan
ddBh = ddBv = ddBx = 0

# calculate height increment
dht = (-3 * ht0 + 4 * htc - htf) / tspan
ddht = (2 * ht0 - 4 * htc + 2 * htf) / (tspan * tspan)

# multiply xr.ones_like(topo) for correct broadcasting
near_range = xr.ones_like(block_topo)*(prm1.get('near_range') + \
block_topo.x * (1 + prm1.get('stretch_r')) * drange) + \
xr.ones_like(block_topo)*(block_topo.y * prm1.get('a_stretch_r') * drange)

# calculate the change in baseline and height along the frame if topoflag is on
time = block_topo.y * tspan / (ydim - 1)
Bh = Bh0 + dBh * time + ddBh * time**2
Bv = Bv0 + dBv * time + ddBv * time**2
Bx = Bx0 + dBx * time + ddBx * time**2
B = np.sqrt(Bh * Bh + Bv * Bv)
alpha = np.arctan2(Bv, Bh)
height = ht0 + dht * time + ddht * time**2

# calculate the combined earth curvature and topography correction
drho = calc_drho(near_range, block_topo, prm1.get('earth_radius'), height, B, alpha, Bx)

# make topographic and model phase corrections
# GMTSAR uses float32 complex operations with precision loss
#phase_shift = np.exp(1j * (cnst * drho).astype(np.float32))
phase_shift = np.exp(1j * (cnst * drho))
del block_topo, near_range, drho, height, B, alpha, Bx, Bv, Bh, time

return phase_shift.astype(np.complex64)

# # prepare lazy PRM
# # this is the "true way" but processing is ~40% slower due to additional Dask tasks
# def block_prms(date1, date2):
# prm1 = self.PRM(date1)
# prm2 = self.PRM(date2)
# prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
# prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
# return (prm1.df, prm2.df)
# # Define metadata explicitly to match PRM dataframe
# prm_meta = pd.DataFrame(columns=['name', 'value']).astype({'name': 'str', 'value': 'object'}).set_index('name')

# immediately prepare PRM
# here is some delay on the function call but the actual processing is faster
def prepare_prms(pair):
date1, date2 = pair
prm1 = self.PRM(date1)
prm2 = self.PRM(date2)
prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
return {(date1, date2): (prm1, prm2)}

#with self.tqdm_joblib(tqdm(desc=f'Pre-Processing PRM', total=len(pairs))) as progress_bar:
prms = joblib.Parallel(n_jobs=-1)(joblib.delayed(prepare_prms)(pair) for pair in pairs)
# convert the list of dicts to a single dict
prms = {k: v for d in prms for k, v in d.items()}

# define blocks
chunks = topo.chunks
ychunks, xchunks = chunks[0], chunks[1]
ychunks = np.concatenate([[0], np.cumsum(ychunks)])
xchunks = np.concatenate([[0], np.cumsum(xchunks)])
ylims = [(y1, y2) for y1, y2 in zip(ychunks, ychunks[1:])]
xlims = [(x1, x2) for x1, x2 in zip(xchunks, xchunks[1:])]
#print ('ylims', ylims)
#print ('xlims', xlims)

stack = []
for stack_idx, pair in enumerate(pairs):
date1, date2 = pair

# Create a Dask DataFrame with provided metadata for each Dask block
#prms = dask.delayed(block_prms)(date1, date2)
#prm1 = dask.dataframe.from_delayed(dask.delayed(prms[0]), meta=prm_meta)
#prm2 = dask.dataframe.from_delayed(dask.delayed(prms[1]), meta=prm_meta)
prm1, prm2 = prms[(date1, date2)]

blocks_total = []
for ylim in ylims:
blocks = []
for xlim in xlims:
delayed = dask.delayed(block_phase)(prm1, prm2, ylim, xlim)
block = dask.array.from_delayed(delayed,
shape=((ylim[1]-ylim[0]), (xlim[1]-xlim[0])),
dtype=np.complex64)
blocks.append(block)
del block, delayed
blocks_total.append(blocks)
del blocks
intf = xr.DataArray(dask.array.block(blocks_total), coords={'y': topo.y, 'x': topo.x})
del blocks_total

# add to stack
stack.append(intf)
# cleanup
del intf, prm1, prm2
del prms

coord_pair = [' '.join(pair) for pair in pairs]
coord_ref = xr.DataArray(pd.to_datetime(pairs[:,0]), coords={'pair': coord_pair})
coord_rep = xr.DataArray(pd.to_datetime(pairs[:,1]), coords={'pair': coord_pair})

return xr.concat(stack, dim='pair').assign_coords(ref=coord_ref, rep=coord_rep, pair=coord_pair).where(np.isfinite(topo)).rename('phase')

0 comments on commit 163f040

Please sign in to comment.