diff --git a/stmtools/stm.py b/stmtools/stm.py index 531936a..19cacae 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -11,7 +11,6 @@ import xarray as xr from shapely.geometry import Point from shapely.strtree import STRtree -from scipy.spatial import KDTree from stmtools.metadata import DataVarTypes, STMMetaData from stmtools.utils import _has_property @@ -468,7 +467,7 @@ def enrich_from_dataset(self, # TODO: check if both ds and dataset has same coordinate system - for i, field in enumerate(fields): + for field in fields: # check if dataset has the fields if field not in dataset.data_vars.keys(): @@ -476,26 +475,26 @@ def enrich_from_dataset(self, # check STM has the filed already if field in ds.data_vars.keys(): - logger.warning( - f'"{field}" was found in the data variables of the STM. ' - f'"We will proceed with the data variable from the input dataset as "{field}_other".' - ) - fields[i] = f"{field}_other" + raise ValueError(f'Field "{field}" already exists in the STM.') + + # if dataset is a dask collection, compute it first if approch == "raster": - return xr.map_blocks( - _enrich_from_raster_block, - ds, - args=(fields, method), - kwargs={"dataset": dataset}, #TODD: block still not working, refactor - ) + return _enrich_from_raster_block(ds, dataset, fields, method) + # return xr.map_blocks( + # _enrich_from_raster_block, + # ds, + # args=(fields, method), + # kwargs={"dataset": dataset}, #TODD: block still not working, refactor + # ) elif approch == "point": - return xr.map_blocks( - _enrich_from_points_block, - ds, - args=(fields), - kwargs={"dataset": dataset}, - ) + return _enrich_from_points_block(ds, dataset, fields) + # return xr.map_blocks( + # _enrich_from_points_block, + # ds, + # args=(fields), + # kwargs={"dataset": dataset}, + # ) @property def num_points(self): @@ -706,9 +705,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): def _enrich_from_points_block(ds, datapoints, fields): """Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset. - scipy is required. It uses cKDTree to find the nearest points in space and - time using Euclidean distance. - https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html#scipy-spatial-ckdtree + https://docs.xarray.dev/en/latest/generated/xarray.DataArray.sel.html#xarray.DataArray.sel Parameters ---------- @@ -725,30 +722,16 @@ def _enrich_from_points_block(ds, datapoints, fields): """ _ds = ds.copy(deep=True) - # create tuple of spatial coordinates - spatial_coords = list(_ds.coords.keys())[:-1] # assuming the last coordinate is time - ds_coords = np.column_stack([_ds[coord].values.flatten() for coord in spatial_coords]) - - spatial_coords = list(datapoints.coords.keys())[:-1] # assuming the last coordinate is time - dataset_points_coords = np.column_stack([datapoints[coord].values.flatten() for coord in spatial_coords]) - - # Create a cKDTree object for the spatial coordinates of datapoints - # Find the indices of the nearest points in space in datapoints for each point in _ds - # it uses Euclidean distance - tree = KDTree(dataset_points_coords) - _, indices_space = tree.query(ds_coords) - - # Create a cKDTree object for the temporal coordinates of datapoints - # Find the indices of the nearest points in time in datapoints for each point in _ds - datapoints_times = datapoints.time.values.reshape(-1, 1) - ds_times = _ds.time.values.reshape(-1, 1) - tree = KDTree(datapoints_times) - _, indices_time = tree.query(ds_times) + # add spatial coordinates to dims + datapoints_coords = list(datapoints.coords.keys()) + datapoints = datapoints.set_index(space=datapoints_coords[:-1]) # assuming the last coordinate is time + datapoints = datapoints.unstack("space") # after this, the order of coordinates changes, so we use transpose later - selections = datapoints.isel(time=indices_time, space=indices_space) + indexers = {coord: _ds[coord] for coord in datapoints_coords} + selections = datapoints.sel(indexers, method="nearest") # Assign these values to the corresponding points in _ds for field in fields: - _ds[field] = xr.DataArray(selections[field].data, dims=ds.dims, coords=ds.coords) + _ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords) return _ds