From 056e62759f5bc43e75e1a127783751a052eb782f Mon Sep 17 00:00:00 2001 From: Ruxandra Valcu Date: Fri, 15 Dec 2023 11:38:48 +0200 Subject: [PATCH] Memory optimizations for signal & transient noise (#141) * Memory optimizations for signal & transient noise * Memory optimizations for signal & transient noise --- echopype/clean/signal_attenuation.py | 36 +++++++++++++------ echopype/clean/transient_noise.py | 6 ++-- .../tests/clean/test_signal_attenuation.py | 9 ++--- echopype/tests/conftest.py | 23 ++++++++++++ echopype/utils/mask_transformation_xr.py | 22 ++++++++++++ 5 files changed, 77 insertions(+), 19 deletions(-) diff --git a/echopype/clean/signal_attenuation.py b/echopype/clean/signal_attenuation.py index 91659855b..7556c8761 100644 --- a/echopype/clean/signal_attenuation.py +++ b/echopype/clean/signal_attenuation.py @@ -3,9 +3,24 @@ import numpy as np import xarray as xr -from ..utils.mask_transformation_xr import lin as _lin, line_to_square, log as _log - -DEFAULT_RYAN_PARAMS = {"r0": 180, "r1": 280, "n": 30, "thr": -6, "start": 0} +from ..utils.mask_transformation_xr import ( + lin as _lin, + line_to_square, + log as _log, + rolling_median_block, +) + +# import dask.array as da + + +DEFAULT_RYAN_PARAMS = { + "r0": 180, + "r1": 280, + "n": 30, + "thr": -6, + "start": 0, + "dask_chunking": {"ping_time": 100, "range_sample": 100}, +} DEFAULT_ARIZA_PARAMS = {"offset": 20, "thr": (-40, -35), "m": 20, "n": 50} @@ -45,15 +60,17 @@ def _ryan(source_Sv: xr.DataArray, desired_channel: str, parameters=DEFAULT_RYAN Returns: xr.DataArray: boolean array with AS mask, with ping_time and range_sample dims """ - parameter_names = ("r0", "r1", "n", "thr", "start") + parameter_names = ("r0", "r1", "n", "thr", "start", "dask_chunking") if not all(name in parameters.keys() for name in parameter_names): raise ValueError( - "Missing parameters - should be r0, r1, n, thr, start, are" + str(parameters.keys()) + "Missing parameters - should be r0, r1, n, thr, start, dask_chunking are" + + str(parameters.keys()) ) r0 = parameters["r0"] r1 = parameters["r1"] n = parameters["n"] thr = parameters["thr"] + dask_chunking = parameters["dask_chunking"] # start = parameters["start"] channel_Sv = source_Sv.sel(channel=desired_channel) @@ -85,13 +102,10 @@ def _ryan(source_Sv: xr.DataArray, desired_channel: str, parameters=DEFAULT_RYAN layer_mask = (Sv["range_sample"] >= up) & (Sv["range_sample"] <= lw) layer_Sv = Sv.where(layer_mask) - # Creating shifted arrays for block comparison - shifted_arrays = [layer_Sv.shift(ping_time=i) for i in range(-n, n + 1)] - block = xr.concat(shifted_arrays, dim="shifted_ping_time") + layer_Sv_chunked = layer_Sv.chunk(dask_chunking) - # Computing the median of the block and the pings - ping_median = layer_Sv.median(dim="range_sample", skipna=True) - block_median = block.median(dim=["range_sample", "shifted_ping_time"], skipna=True) + block_median = rolling_median_block(layer_Sv_chunked.data, window_half_size=n, axis=0) + ping_median = layer_Sv_chunked.median(dim="range_sample", skipna=True) # Creating the mask based on the threshold mask_condition = (ping_median - block_median) > thr diff --git a/echopype/clean/transient_noise.py b/echopype/clean/transient_noise.py index e2c56d7ca..874cb18b6 100644 --- a/echopype/clean/transient_noise.py +++ b/echopype/clean/transient_noise.py @@ -27,6 +27,7 @@ lin as _lin, line_to_square, log as _log, + rolling_median_block, ) RYAN_DEFAULT_PARAMS = { @@ -231,10 +232,7 @@ def _fielding( ping_median = Sv_range.median(dim="range_sample", skipna=True) ping_75q = Sv_range.reduce(np.nanpercentile, q=75, dim="range_sample") - - shifted_arrays = [Sv_range.shift(ping_time=i) for i in range(-n, n + 1)] - block = xr.concat(shifted_arrays, dim="shifted_ping_time") - block_median = block.median(dim=["range_sample", "shifted_ping_time"], skipna=True) + block_median = rolling_median_block(Sv_range.data, window_half_size=n, axis=0) # identify columns in which noise can be found noise_col = (ping_75q < maxts) & ((ping_median - block_median) < thr[0]) diff --git a/echopype/tests/clean/test_signal_attenuation.py b/echopype/tests/clean/test_signal_attenuation.py index 2cd4c9d2c..35d2b7a7b 100644 --- a/echopype/tests/clean/test_signal_attenuation.py +++ b/echopype/tests/clean/test_signal_attenuation.py @@ -2,12 +2,10 @@ import numpy as np import echopype.clean - -DEFAULT_RYAN_PARAMS = {"r0": 180, "r1": 280, "n": 30, "thr": -6, "start": 0} +from echopype.clean.signal_attenuation import DEFAULT_RYAN_PARAMS # commented ariza out since the current interpretation relies on a # preexisting seabed mask, which is not available in this PR -# DEFAULT_ARIZA_PARAMS = {"offset": 20, "thr": (-40, -35), "m": 20, "n": 50} @pytest.mark.parametrize( @@ -18,7 +16,10 @@ ], ) def test_get_signal_attenuation_mask( - sv_dataset_jr161, method, parameters, expected_true_false_counts + sv_dataset_jr161, + method, + parameters, + expected_true_false_counts, ): # source_Sv = get_sv_dataset(test_data_path) desired_channel = "GPT 38 kHz 009072033fa5 1 ES38" diff --git a/echopype/tests/conftest.py b/echopype/tests/conftest.py index 8de6faafa..9341eef50 100644 --- a/echopype/tests/conftest.py +++ b/echopype/tests/conftest.py @@ -84,3 +84,26 @@ def _get_sv_dataset(file_path): ed = ep.open_raw(file_path, sonar_model="ek60") Sv = ep.calibrate.compute_Sv(ed).compute() return Sv + + +@pytest.fixture(scope="session") +def sv_ek80(): + base_url = "noaa-wcsd-pds.s3.amazonaws.com/" + path = "data/raw/Sally_Ride/SR1611/EK80/" + file_name = "D20161109-T163350.raw" + + local_path = os.path.join(TEST_DATA_FOLDER, file_name) + if os.path.isfile(local_path): + ed = ep.open_raw( + local_path, + sonar_model="EK80", + ) + else: + raw_file_address = base_url + path + file_name + rf = raw_file_address # Path(raw_file_address) + ed = ep.open_raw( + f"https://{rf}", + sonar_model="EK80", + ) + Sv = ep.calibrate.compute_Sv(ed, waveform_mode="CW", encode_mode="complex").compute() + return Sv diff --git a/echopype/utils/mask_transformation_xr.py b/echopype/utils/mask_transformation_xr.py index 774e56262..7d9baea70 100644 --- a/echopype/utils/mask_transformation_xr.py +++ b/echopype/utils/mask_transformation_xr.py @@ -134,3 +134,25 @@ def dask_nanmean(array, axis=None): if not isinstance(array, da.Array): raise TypeError("Expected a Dask array, got {}.".format(type(array))) return da.nanmean(array, axis=axis) + + +def block_nanmedian(block, i, n, axis): + """ + Since dask nanmedian doesn't work except when applied on a specific axis, + this is a kludge to enable us to calculate block medians without assigning + arrays with an extra dimension + """ + start = max(0, i - n) + end = min(block.shape[axis], i + n + 1) + indices = da.arange(start, end, dtype=int) + use_block = da.take(block, indices, axis) + res = np.nanmedian(use_block.compute()) + return res + + +def rolling_median_block(block, window_half_size, axis): + """ + Applies a median block calculation as a rolling function across an axis + """ + res = [block_nanmedian(block, i, window_half_size, axis) for i in range(block.shape[axis])] + return np.array(res)