From f9b7407ebd5edf9243a9e2565a71fe2148122e68 Mon Sep 17 00:00:00 2001 From: Alexey Pechnikov Date: Sat, 21 Sep 2024 15:00:57 +0700 Subject: [PATCH] Code refactoring to maintain subswaths offsets --- pygmtsar/pygmtsar/IO.py | 94 +++--------------------------- pygmtsar/pygmtsar/Stack_align.py | 99 ++++++++++++++++++++++++++++++++ pygmtsar/pygmtsar/Stack_topo.py | 43 ++++++++++++-- 3 files changed, 146 insertions(+), 90 deletions(-) diff --git a/pygmtsar/pygmtsar/IO.py b/pygmtsar/pygmtsar/IO.py index 78ee1f8..07811b6 100644 --- a/pygmtsar/pygmtsar/IO.py +++ b/pygmtsar/pygmtsar/IO.py @@ -318,90 +318,11 @@ def open_data(self, dates=None, scale=2.5e-07, debug=False): shape = (slc.y.size, slc.x.size) del slc, prm else: - # calculate the offsets to merge subswaths - prms = [] - ylims = [] - xlims = [] - for subswath in subswaths: - #print (subswath) - prm = self.PRM(subswath=subswath) - prms.append(prm) - ylims.append(prm.get('num_valid_az')) - xlims.append(prm.get('num_rng_bins')) - - assert len(np.unique([prm.get('PRF') for prm in prms])), 'Image PRFs are not consistent' - assert len(np.unique([prm.get('rng_samp_rate') for prm in prms])), 'Image range sampling rates are not consistent' - - bottoms = [0] + [int(np.round(((prm.get('clock_start') - prms[0].get('clock_start')) * 86400 * prms[0].get('PRF')))) for prm in prms[1:]] - # head123: 0, 466, -408 - if debug: - print ('bottoms init', bottoms) - # minh: -408 - minh = min(bottoms) - if debug: - print ('minh', minh) - #head123: 408, 874, 0 - bottoms = np.asarray(bottoms) - minh - if debug: - print ('bottoms', bottoms) - - #ovl12,23: 2690, 2558 - ovls = [prm1.get('num_rng_bins') - \ - int(np.round(((prm2.get('near_range') - prm1.get('near_range')) / (constants.speed_of_light/ prm1.get('rng_samp_rate') / 2)))) \ - for (prm1, prm2) in zip(prms[:-1], prms[1:])] - if debug: - print ('ovls', ovls) - - #Writing the grid files..Size(69158x13075)... - #maxy: 13075 - # for SLC - maxy = max([prm.get('num_valid_az') + bottom for prm, bottom in zip(prms, bottoms)]) - if debug: - print ('maxy', maxy) - maxx = sum([prm.get('num_rng_bins') - ovl - 1 for prm, ovl in zip(prms, [-1] + ovls)]) - if debug: - print ('maxx', maxx) - - #Stitching location n1 = 1045 - #Stitching location n2 = 935 - ns = [np.ceil(-prm.get('rshift') + prm.get('first_sample') + 150.0).astype(int) for prm in prms[1:]] - ns = [10 if n < 10 else n for n in ns] - if debug: - print ('ns', ns) - - # left and right coordinates for every subswath valid area - lefts = [] - rights = [] - - # 1st - xlim = prms[0].get('num_rng_bins') - ovls[0] + ns[0] - lefts.append(0) - rights.append(xlim) - - # 2nd - if len(prms) == 2: - xlim = prms[1].get('num_rng_bins') - 1 - else: - # for 3 subswaths - xlim = prms[1].get('num_rng_bins') - ovls[1] + ns[1] - lefts.append(ns[0]) - rights.append(xlim) - - # 3rd - if len(prms) == 3: - xlim = prms[2].get('num_rng_bins') - 2 - lefts.append(ns[1]) - rights.append(xlim) - - # check and merge SLCs - sumx = sum([right-left for right, left in zip(rights, lefts)]) - if debug: - print ('assert maxx == sum(...)', maxx, sumx) - assert maxx == sumx, 'Incorrect output grid range dimension size' - - offsets = {'bottoms': bottoms, 'lefts': lefts, 'rights': rights, 'bottom': minh, 'extent': [maxy, maxx], 'ylims': ylims, 'xlims': xlims} - if debug: - print ('offsets', offsets) + + #offsets = {'bottoms': bottoms, 'lefts': lefts, 'rights': rights, 'bottom': minh, 'extent': [maxy, maxx], 'ylims': ylims, 'xlims': xlims} + offsets = self.subswaths_offsets(debug=debug) + maxy, maxx = offsets['extent'] + minh = offsets['bottom'] # merge subswaths stack = [] @@ -409,7 +330,8 @@ def open_data(self, dates=None, scale=2.5e-07, debug=False): slcs = [] prms = [] - for subswath, bottom, left, right, ylim, xlim in zip(subswaths, bottoms, lefts, rights, ylims, xlims): + for subswath, bottom, left, right, ylim, xlim in zip(subswaths, + offsets['bottoms'], offsets['lefts'], offsets['rights'], offsets['ylims'], offsets['xlims']): print (date, subswath) prm = self.PRM(date, subswath=int(subswath)) # disable scaling @@ -420,7 +342,7 @@ def open_data(self, dates=None, scale=2.5e-07, debug=False): # check and merge SLCs, use zero fill for np.int16 datatype slc = xr.concat(slcs, dim='x', fill_value=0).assign_coords(x=0.5 + np.arange(maxx)) - + if debug: print ('assert slc.y.size == maxy', slc.y.size, maxy) assert slc.y.size == maxy, 'Incorrect output grid azimuth dimension size' diff --git a/pygmtsar/pygmtsar/Stack_align.py b/pygmtsar/pygmtsar/Stack_align.py index cda4f61..24fe662 100644 --- a/pygmtsar/pygmtsar/Stack_align.py +++ b/pygmtsar/pygmtsar/Stack_align.py @@ -269,6 +269,105 @@ def _align_rep_subswath(self, subswath, date=None, degrees=12.0/3600, debug=Fals #if os.path.exists(filename): os.remove(filename) + def subswaths_offsets(self, debug=False): + import xarray as xr + import numpy as np + from scipy import constants + + subswaths = self.get_subswaths() + if not isinstance(subswaths, (str, int)): + subswaths = ''.join(map(str, subswaths)) + + if len(subswaths) == 1: + return + + # calculate the offsets to merge subswaths + prms = [] + ylims = [] + xlims = [] + for subswath in subswaths: + #print (subswath) + prm = self.PRM(subswath=subswath) + prms.append(prm) + ylims.append(prm.get('num_valid_az')) + xlims.append(prm.get('num_rng_bins')) + + assert len(np.unique([prm.get('PRF') for prm in prms])), 'Image PRFs are not consistent' + assert len(np.unique([prm.get('rng_samp_rate') for prm in prms])), 'Image range sampling rates are not consistent' + + bottoms = [0] + [int(np.round(((prm.get('clock_start') - prms[0].get('clock_start')) * 86400 * prms[0].get('PRF')))) for prm in prms[1:]] + # head123: 0, 466, -408 + if debug: + print ('bottoms init', bottoms) + # minh: -408 + minh = min(bottoms) + if debug: + print ('minh', minh) + #head123: 408, 874, 0 + bottoms = np.asarray(bottoms) - minh + if debug: + print ('bottoms', bottoms) + + #ovl12,23: 2690, 2558 + ovls = [prm1.get('num_rng_bins') - \ + int(np.round(((prm2.get('near_range') - prm1.get('near_range')) / (constants.speed_of_light/ prm1.get('rng_samp_rate') / 2)))) \ + for (prm1, prm2) in zip(prms[:-1], prms[1:])] + if debug: + print ('ovls', ovls) + + #Writing the grid files..Size(69158x13075)... + #maxy: 13075 + # for SLC + maxy = max([prm.get('num_valid_az') + bottom for prm, bottom in zip(prms, bottoms)]) + if debug: + print ('maxy', maxy) + maxx = sum([prm.get('num_rng_bins') - ovl - 1 for prm, ovl in zip(prms, [-1] + ovls)]) + if debug: + print ('maxx', maxx) + + #Stitching location n1 = 1045 + #Stitching location n2 = 935 + ns = [np.ceil(-prm.get('rshift') + prm.get('first_sample') + 150.0).astype(int) for prm in prms[1:]] + ns = [10 if n < 10 else n for n in ns] + if debug: + print ('ns', ns) + + # left and right coordinates for every subswath valid area + lefts = [] + rights = [] + + # 1st + xlim = prms[0].get('num_rng_bins') - ovls[0] + ns[0] + lefts.append(0) + rights.append(xlim) + + # 2nd + if len(prms) == 2: + xlim = prms[1].get('num_rng_bins') - 1 + else: + # for 3 subswaths + xlim = prms[1].get('num_rng_bins') - ovls[1] + ns[1] + lefts.append(ns[0]) + rights.append(xlim) + + # 3rd + if len(prms) == 3: + xlim = prms[2].get('num_rng_bins') - 2 + lefts.append(ns[1]) + rights.append(xlim) + + # check and merge SLCs + sumx = sum([right-left for right, left in zip(rights, lefts)]) + if debug: + print ('assert maxx == sum(...)', maxx, sumx) + assert maxx == sumx, 'Incorrect output grid range dimension size' + + offsets = {'bottoms': bottoms, 'lefts': lefts, 'rights': rights, 'bottom': minh, 'extent': [maxy, maxx], 'ylims': ylims, 'xlims': xlims} + if debug: + print ('offsets', offsets) + + return offsets + def baseline_table(self, n_jobs=-1, debug=False): """ Generates a baseline table for Sentinel-1 data, containing dates and baseline components. diff --git a/pygmtsar/pygmtsar/Stack_topo.py b/pygmtsar/pygmtsar/Stack_topo.py index 3a590d8..23324d0 100644 --- a/pygmtsar/pygmtsar/Stack_topo.py +++ b/pygmtsar/pygmtsar/Stack_topo.py @@ -193,15 +193,50 @@ def block_phase_dask(block_topo, y_chunk, x_chunk, prm1, prm2): # immediately prepare PRM # here is some delay on the function call but the actual processing is faster - def prepare_prms(pair): + offsets = self.subswaths_offsets(debug=debug) + # reference scene first subswath + prm0 = self.PRM() + if offsets is not None: + # multiple subswaths + maxy, maxx = offsets['extent'] + minh = offsets['bottom'] + dt1 = minh / prm0.get('PRF') / 86400 + dt2 = maxy / prm0.get('PRF') / 86400 + else: + # one subswath + #maxy, maxx = prm0.get('num_valid_az', 'num_rng_bins') + dt1 = dt2 = 0 + del prm0 + + # dt = minh / prm.get('PRF') / 86400 + # prm = prm.set(SLC_file=None, + # num_lines=maxy, nrows=maxy, num_valid_az=maxy, + # num_rng_bins=maxx, bytes_per_line=4*maxx, good_bytes=4*maxx, + # SC_clock_start=prm.get('SC_clock_start') + dt, + # clock_start=prm.get('clock_start') + dt, + # SC_clock_stop=prm.get('SC_clock_start') + maxy / prm.get('PRF') / 86400, + # clock_stop=prm.get('clock_start') + maxy / prm.get('PRF') / 86400)\ + # .to_file(prm_filename) + + def prepare_prms(pair, maxy, maxx, dt1, dt2): date1, date2 = pair prm1 = self.PRM(date1) prm2 = self.PRM(date2) - prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned() - prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned() + prm2.set( + SC_clock_start=prm2.get('SC_clock_start') + dt1, + clock_start=prm2.get('clock_start') + dt1, + SC_clock_stop=prm2.get('SC_clock_start') + dt2, + clock_stop=prm2.get('clock_start') + dt2)\ + .set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned() + prm1.set( + SC_clock_start=prm1.get('SC_clock_start') + dt1, + clock_start=prm1.get('clock_start') + dt1, + SC_clock_stop=prm1.get('SC_clock_start') + dt2, + clock_stop=prm1.get('clock_start') + dt2)\ + .set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned() return (prm1, prm2) - prms = joblib.Parallel(n_jobs=-1)(joblib.delayed(prepare_prms)(pair) for pair in pairs) + prms = joblib.Parallel(n_jobs=-1)(joblib.delayed(prepare_prms)(pair, maxy, maxx, dt1, dt2) for pair in pairs) # fill NaNs by 0 and expand to 3d topo2d = da.where(da.isnan(topo.data), 0, topo.data)