Skip to content

Commit

Permalink
replace KDTree with sel method of xarray, fix a bug in fields
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahAlidoost committed Mar 18, 2024
1 parent e133f5f commit 90ef61d
Showing 1 changed file with 26 additions and 43 deletions.
69 changes: 26 additions & 43 deletions stmtools/stm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import xarray as xr
from shapely.geometry import Point
from shapely.strtree import STRtree
from scipy.spatial import KDTree

from stmtools.metadata import DataVarTypes, STMMetaData
from stmtools.utils import _has_property
Expand Down Expand Up @@ -468,34 +467,34 @@ def enrich_from_dataset(self,

# TODO: check that both ds and dataset have the same coordinate system

for i, field in enumerate(fields):
for field in fields:

# check if dataset has the fields
if field not in dataset.data_vars.keys():
raise ValueError(f'Field "{field}" not found in the the input dataset')

# check whether the STM already has the field
if field in ds.data_vars.keys():
logger.warning(
f'"{field}" was found in the data variables of the STM. '
f'"We will proceed with the data variable from the input dataset as "{field}_other".'
)
fields[i] = f"{field}_other"
raise ValueError(f'Field "{field}" already exists in the STM.')

# if dataset is a dask collection, compute it first

if approch == "raster":
return xr.map_blocks(
_enrich_from_raster_block,
ds,
args=(fields, method),
kwargs={"dataset": dataset}, # TODO: block still not working, refactor
)
return _enrich_from_raster_block(ds, dataset, fields, method)
# return xr.map_blocks(
# _enrich_from_raster_block,
# ds,
# args=(fields, method),
# kwargs={"dataset": dataset}, # TODO: block still not working, refactor
# )
elif approch == "point":
return xr.map_blocks(
_enrich_from_points_block,
ds,
args=(fields),
kwargs={"dataset": dataset},
)
return _enrich_from_points_block(ds, dataset, fields)
# return xr.map_blocks(
# _enrich_from_points_block,
# ds,
# args=(fields),
# kwargs={"dataset": dataset},
# )

@property
def num_points(self):
Expand Down Expand Up @@ -706,9 +705,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method):
def _enrich_from_points_block(ds, datapoints, fields):
"""Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset.
scipy is required. It uses cKDTree to find the nearest points in space and
time using Euclidean distance.
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html#scipy-spatial-ckdtree
https://docs.xarray.dev/en/latest/generated/xarray.DataArray.sel.html#xarray.DataArray.sel
Parameters
----------
Expand All @@ -725,30 +722,16 @@ def _enrich_from_points_block(ds, datapoints, fields):
"""
_ds = ds.copy(deep=True)

# create tuple of spatial coordinates
spatial_coords = list(_ds.coords.keys())[:-1] # assuming the last coordinate is time
ds_coords = np.column_stack([_ds[coord].values.flatten() for coord in spatial_coords])

spatial_coords = list(datapoints.coords.keys())[:-1] # assuming the last coordinate is time
dataset_points_coords = np.column_stack([datapoints[coord].values.flatten() for coord in spatial_coords])

# Create a cKDTree object for the spatial coordinates of datapoints
# Find the indices of the nearest points in space in datapoints for each point in _ds
# it uses Euclidean distance
tree = KDTree(dataset_points_coords)
_, indices_space = tree.query(ds_coords)

# Create a cKDTree object for the temporal coordinates of datapoints
# Find the indices of the nearest points in time in datapoints for each point in _ds
datapoints_times = datapoints.time.values.reshape(-1, 1)
ds_times = _ds.time.values.reshape(-1, 1)
tree = KDTree(datapoints_times)
_, indices_time = tree.query(ds_times)
# add spatial coordinates to dims
datapoints_coords = list(datapoints.coords.keys())
datapoints = datapoints.set_index(space=datapoints_coords[:-1]) # assuming the last coordinate is time
datapoints = datapoints.unstack("space") # after this, the order of coordinates changes, so we use transpose later

selections = datapoints.isel(time=indices_time, space=indices_space)
indexers = {coord: _ds[coord] for coord in datapoints_coords}
selections = datapoints.sel(indexers, method="nearest")

# Assign these values to the corresponding points in _ds
for field in fields:
_ds[field] = xr.DataArray(selections[field].data, dims=ds.dims, coords=ds.coords)
_ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords)

return _ds

0 comments on commit 90ef61d

Please sign in to comment.