From 48592934390fe1cd6ab3cfd7539762ab5af1bc1e Mon Sep 17 00:00:00 2001 From: Kristen Thyng Date: Wed, 14 Feb 2024 11:15:47 -0800 Subject: [PATCH] Allow for saving of interpolators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I made changes to structured.py to allow for saving the interpolators to save time on repeated simulations. The necessary attributes to use the save option need to be input via the reader, currently only available in the ROMS reader. This PR also has changes that clean up the name variable in the ROMS reader. Previously there was logic to deal with name for several cases but subsequently it was overwritten to be “roms native”. Now name is used if input and otherwise set to “roms native”. --- opendrift/readers/basereader/structured.py | 59 ++++++++++++++++------ opendrift/readers/reader_ROMS_native.py | 28 +++++----- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/opendrift/readers/basereader/structured.py b/opendrift/readers/basereader/structured.py index 5cc9d3776..887f707b2 100644 --- a/opendrift/readers/basereader/structured.py +++ b/opendrift/readers/basereader/structured.py @@ -1,7 +1,9 @@ +import pickle import numpy as np import pyproj from scipy.ndimage import map_coordinates from abc import abstractmethod +from pathlib import Path from opendrift.readers.interpolation.structured import ReaderBlock from .variables import Variables @@ -56,7 +58,7 @@ def __init__(self): self.proj4 = 'None' self.proj = fakeproj.fakeproj() self.projected = False - logger.info('Making interpolator for lon,lat to x,y conversion...') + self.xmin = self.ymin = 0. self.delta_x = self.delta_y = 1. 
self.xmax = self.lon.shape[1] - 1 @@ -65,22 +67,49 @@ def __init__(self): self.numy = self.ymax self.x = np.arange(0, self.xmax+1) self.y = np.arange(0, self.ymax+1) + + # Making interpolator (lon, lat) -> x + # save to speed up next time + if hasattr(self, "save_interpolator") and self.save_interpolator and hasattr(self, "interpolator_filename"): + interpolator_filename = Path(self.interpolator_filename).with_suffix('.pickle') + else: + interpolator_filename = f'{self.name}_interpolators.pickle' + + if hasattr(self, "save_interpolator") and self.save_interpolator and Path(interpolator_filename).is_file(): + logger.info('Loading interpolator for lon,lat to x,y conversion...') + with open(interpolator_filename, 'rb') as file_handle: + interp_dict = pickle.load(file_handle) + spl_x = interp_dict["spl_x"] + spl_y = interp_dict["spl_y"] + + else: + logger.info('Making interpolator for lon,lat to x,y conversion...') + + block_x, block_y = np.mgrid[self.xmin:self.xmax + 1, + self.ymin:self.ymax + 1] + block_x, block_y = block_x.T, block_y.T + + spl_x = LinearNDInterpolator( + (self.lon.ravel(), self.lat.ravel()), + block_x.ravel(), + fill_value=np.nan) + # Reusing x-interpolator (deepcopy) with data for y + spl_y = copy.deepcopy(spl_x) + spl_y.values[:, 0] = block_y.ravel() + # Call interpolator to avoid threading-problem: + # https://github.com/scipy/scipy/issues/8856 + spl_x((0, 0)), spl_y((0, 0)) + + if hasattr(self, "save_interpolator") and self.save_interpolator: + + interp_dict = {"spl_x": spl_x, "spl_y": spl_y} + with open(interpolator_filename, 'wb') as f: + pickle.dump(interp_dict, f) + + self.spl_x = spl_x + self.spl_y = spl_y - block_x, block_y = np.mgrid[self.xmin:self.xmax + 1, - self.ymin:self.ymax + 1] - block_x, block_y = block_x.T, block_y.T - # Making interpolator (lon, lat) -> x - self.spl_x = LinearNDInterpolator( - (self.lon.ravel(), self.lat.ravel()), - block_x.ravel(), - fill_value=np.nan) - # Reusing x-interpolator (deepcopy) with data for y - 
self.spl_y = copy.deepcopy(self.spl_x) - self.spl_y.values[:, 0] = block_y.ravel() - # Call interpolator to avoid threading-problem: - # https://github.com/scipy/scipy/issues/8856 - self.spl_x((0, 0)), self.spl_y((0, 0)) else: self.projected = True diff --git a/opendrift/readers/reader_ROMS_native.py b/opendrift/readers/reader_ROMS_native.py index 7766245fd..a4bac49c9 100644 --- a/opendrift/readers/reader_ROMS_native.py +++ b/opendrift/readers/reader_ROMS_native.py @@ -39,8 +39,12 @@ class Reader(BaseReader, StructuredReader): :param name: Name of reader :type name: string, optional - :param proj4: PROJ.4 string describing projection of data. - :type proj4: string, optional + :param save_interpolator: Whether or not to save the interpolator that goes from lon/lat to x/y (calculated in structured.py) + :type save_interpolator: bool + + :param interpolator_path: If save_interpolator is True, user can input this string to control where interpolator is saved. + :type interpolator_path: Path, str, optional + Example: @@ -72,7 +76,8 @@ class Reader(BaseReader, StructuredReader): r = Reader(ds) """ - def __init__(self, filename=None, name=None, gridfile=None, standard_name_mapping={}): + def __init__(self, filename=None, name=None, gridfile=None, standard_name_mapping={}, + save_interpolator=False, interpolator_path=None): if filename is None: raise ValueError('Need filename as argument to constructor') @@ -121,21 +126,14 @@ def __init__(self, filename=None, name=None, gridfile=None, standard_name_mappin -6500, -7000, -7500, -8000]) gls_param = ['gls_cmu0', 'gls_p', 'gls_m', 'gls_n'] + + self.name = name or 'roms native' if isinstance(filename, xr.Dataset): self.Dataset = filename - if name is not None: - self.name = name - else: - import re - self.name = re.sub('[^0-9a-zA-Z]+', '_', filename.attrs['title']) else: filestr = str(filename) - if name is None: - self.name = filestr - else: - self.name = name try: # Open file, check that everything is ok @@ -161,6 +159,10 @@ 
def drop_non_essential_vars_pop(ds): except Exception as e: raise ValueError(e) + # this is an opportunity to save interpolators to pickle to save sim time + self.save_interpolator = save_interpolator + self.interpolator_path = interpolator_path or f'{self.name}_interpolators' + if gridfile is not None: # Merging gridfile dataset with main dataset gf = xr.open_dataset(gridfile) self.Dataset = xr.merge([self.Dataset, gf]) @@ -259,8 +261,6 @@ def drop_non_essential_vars_pop(ds): else: self.time_step = None - self.name = 'roms native' - self.precalculate_s2z_coefficients = True # Find all variables having standard_name