Skip to content

Commit

Permalink
Merge pull request #1228 from kthyng/save_interpolators
Browse files Browse the repository at this point in the history
Allow for saving of interpolators
  • Loading branch information
knutfrode authored Feb 14, 2024
2 parents 4ac4e00 + f43dccd commit 567998f
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 27 deletions.
59 changes: 44 additions & 15 deletions opendrift/readers/basereader/structured.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pickle
import numpy as np
import pyproj
from scipy.ndimage import map_coordinates
from abc import abstractmethod
from pathlib import Path

from opendrift.readers.interpolation.structured import ReaderBlock
from .variables import Variables
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(self):
self.proj4 = 'None'
self.proj = fakeproj.fakeproj()
self.projected = False
logger.info('Making interpolator for lon,lat to x,y conversion...')

self.xmin = self.ymin = 0.
self.delta_x = self.delta_y = 1.
self.xmax = self.lon.shape[1] - 1
Expand All @@ -65,22 +67,49 @@ def __init__(self):
self.numy = self.ymax
self.x = np.arange(0, self.xmax+1)
self.y = np.arange(0, self.ymax+1)

# Making interpolator (lon, lat) -> x
# save to speed up next time
if hasattr(self, "save_interpolator") and self.save_interpolator and hasattr(self, "interpolator_filename"):
interpolator_filename = Path(self.interpolator_filename).with_suffix('.pickle')
else:
interpolator_filename = f'{self.name}_interpolators.pickle'

if hasattr(self, "save_interpolator") and self.save_interpolator and Path(interpolator_filename).is_file():
logger.info('Loading interpolator for lon,lat to x,y conversion...')
with open(interpolator_filename, 'rb') as file_handle:
interp_dict = pickle.load(file_handle)
spl_x = interp_dict["spl_x"]
spl_y = interp_dict["spl_y"]

else:
logger.info('Making interpolator for lon,lat to x,y conversion...')

block_x, block_y = np.mgrid[self.xmin:self.xmax + 1,
self.ymin:self.ymax + 1]
block_x, block_y = block_x.T, block_y.T

spl_x = LinearNDInterpolator(
(self.lon.ravel(), self.lat.ravel()),
block_x.ravel(),
fill_value=np.nan)
# Reusing x-interpolator (deepcopy) with data for y
spl_y = copy.deepcopy(spl_x)
spl_y.values[:, 0] = block_y.ravel()
# Call interpolator to avoid threading-problem:
# https://github.com/scipy/scipy/issues/8856
spl_x((0, 0)), spl_y((0, 0))

if hasattr(self, "save_interpolator") and self.save_interpolator:

interp_dict = {"spl_x": spl_x, "spl_y": spl_y}
with open(interpolator_filename, 'wb') as f:
pickle.dump(interp_dict, f)

self.spl_x = spl_x
self.spl_y = spl_y

block_x, block_y = np.mgrid[self.xmin:self.xmax + 1,
self.ymin:self.ymax + 1]
block_x, block_y = block_x.T, block_y.T

# Making interpolator (lon, lat) -> x
self.spl_x = LinearNDInterpolator(
(self.lon.ravel(), self.lat.ravel()),
block_x.ravel(),
fill_value=np.nan)
# Reusing x-interpolator (deepcopy) with data for y
self.spl_y = copy.deepcopy(self.spl_x)
self.spl_y.values[:, 0] = block_y.ravel()
# Call interpolator to avoid threading-problem:
# https://github.com/scipy/scipy/issues/8856
self.spl_x((0, 0)), self.spl_y((0, 0))
else:
self.projected = True

Expand Down
26 changes: 14 additions & 12 deletions opendrift/readers/reader_ROMS_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class Reader(BaseReader, StructuredReader):
:param name: Name of reader
:type name: string, optional
:param save_interpolator: Whether or not to save the interpolator that goes from lon/lat to x/y (calculated in structured.py)
:type save_interpolator: bool
:param interpolator_path: If save_interpolator is True, user can input this string to control where interpolator is saved.
:type interpolator_path: Path, str, optional
Example:
.. code::
Expand Down Expand Up @@ -69,7 +75,8 @@ class Reader(BaseReader, StructuredReader):
r = Reader(ds)
"""

def __init__(self, filename=None, name=None, gridfile=None, standard_name_mapping={}):
def __init__(self, filename=None, name=None, gridfile=None, standard_name_mapping={},
save_interpolator=False, interpolator_path=None):

if filename is None:
raise ValueError('Need filename as argument to constructor')
Expand Down Expand Up @@ -118,21 +125,14 @@ def __init__(self, filename=None, name=None, gridfile=None, standard_name_mappin
-6500, -7000, -7500, -8000])

gls_param = ['gls_cmu0', 'gls_p', 'gls_m', 'gls_n']

self.name = name or 'roms native'

if isinstance(filename, xr.Dataset):
self.Dataset = filename
if name is not None:
self.name = name
else:
import re
self.name = re.sub('[^0-9a-zA-Z]+', '_', filename.attrs['title'])
else:

filestr = str(filename)
if name is None:
self.name = filestr
else:
self.name = name

try:
# Open file, check that everything is ok
Expand All @@ -158,6 +158,10 @@ def drop_non_essential_vars_pop(ds):
except Exception as e:
raise ValueError(e)

# this is an opporunity to save interpolators to pickle to save sim time
self.save_interpolator = save_interpolator
self.interpolator_path = interpolator_path or f'{self.name}_interpolators'

if gridfile is not None: # Merging gridfile dataset with main dataset
gf = xr.open_dataset(gridfile)
self.Dataset = xr.merge([self.Dataset, gf])
Expand Down Expand Up @@ -256,8 +260,6 @@ def drop_non_essential_vars_pop(ds):
else:
self.time_step = None

self.name = 'roms native'

self.precalculate_s2z_coefficients = True

# Find all variables having standard_name
Expand Down

0 comments on commit 567998f

Please sign in to comment.