Commit
fix linter errors
SarahAlidoost committed Mar 22, 2024
1 parent 2002854 commit 9113b02
Showing 2 changed files with 38 additions and 9 deletions.
37 changes: 30 additions & 7 deletions stmtools/stm.py
@@ -12,7 +12,6 @@
from shapely.geometry import Point
from shapely.strtree import STRtree

from stmtools import utils
from stmtools.metadata import DataVarTypes, STMMetaData
from stmtools.utils import _has_property

@@ -39,6 +38,7 @@ def add_metadata(self, metadata):
-------
xarray.Dataset
STM with assigned attributes.
"""
self._obj = self._obj.assign_attrs(metadata)
return self._obj
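The assign_attrs call above merges a plain dict into the dataset's attributes without touching the data. A minimal sketch of that behaviour, with illustrative metadata keys that are not part of the commit:

import xarray as xr

ds = xr.Dataset()
ds = ds.assign_attrs({"project": "demo", "crs": "EPSG:4326"})  # illustrative keys
print(ds.attrs["project"])  # "demo"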
@@ -70,6 +70,7 @@ def regulate_dims(self, space_label=None, time_label=None):
-------
xarray.Dataset
Regulated STM.
"""
if (
(space_label is None)
@@ -129,6 +130,7 @@ def subset(self, method: str, **kwargs):
-------
xarray.Dataset
A subset of the original STM.
"""
# Check if both "space" and "time" dimension exists
for dim in ["space", "time"]:
@@ -204,6 +206,7 @@ def enrich_from_polygon(self, polygon, fields, xlabel="lon", ylabel="lat"):
-------
xarray.Dataset
Enriched STM.
"""
_ = _validate_coords(self._obj, xlabel, ylabel)

@@ -267,6 +270,7 @@ def _in_polygon(self, polygon, xlabel="lon", ylabel="lat"):
-------
Dask.array
A boolean Dask array. True where a space entry is inside the (multi-)polygon.
"""
# Check if coords exists
_ = _validate_coords(self._obj, xlabel, ylabel)
@@ -312,6 +316,7 @@ def register_metadata(self, dict_meta: STMMetaData):
-------
xarray.Dataset
STM with registered metadata.
"""
ds_updated = self._obj.assign_attrs(dict_meta)

@@ -331,6 +336,7 @@ def register_datatype(self, keys: str | Iterable, datatype: DataVarTypes):
-------
xarray.Dataset
STM with registered metadata.
"""
ds_updated = self._obj

@@ -364,6 +370,7 @@ def get_order(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):
Scaling multiplier to the x coordinates before truncating them to integer values.
yscale : float
Scaling multiplier to the y coordinates before truncating them to integer values.
"""
meta_arr = np.array((), dtype=np.int64)
order = da.apply_gufunc(
@@ -396,6 +403,7 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):
Scaling multiplier to the x coordinates before truncating them to integer values.
yscale : float
Scaling multiplier to the y coordinates before truncating them to integer values.
"""
self._obj = self.get_order(xlabel, ylabel, xscale, yscale)
self._obj = self._obj.sortby(self._obj.order)
@@ -422,10 +430,12 @@ def enrich_from_dataset(self,
method : str, optional
Method of interpolation, by default "nearest", see
https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like
Returns
-------
xarray.Dataset
Enriched STM.
"""
# Check if fields is a Iterable or a str
if isinstance(fields, str):
@@ -455,7 +465,7 @@ def enrich_from_dataset(self,
else:
raise ValueError(
"The input dataset is not a point or raster dataset."
"The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." # give help on renaming
"The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions."
"Consider renaming using "
"https://docs.xarray.dev/en/latest/generated/xarray.Dataset.rename.html#xarray-dataset-rename"
)
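The error message above points users to xarray's rename. A hedged example of the suggested fix, assuming an enrichment dataset whose dimensions happen to be called "latitude"/"longitude" (the file name and original dimension names are hypothetical):

import xarray as xr

meteo = xr.open_dataset("meteo.nc")  # hypothetical input file
meteo = meteo.rename({"latitude": "lat", "longitude": "lon"})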
@@ -494,6 +504,7 @@ def num_points(self):
-------
int
Number of space entry.
"""
return self._obj.dims["space"]

@@ -505,6 +516,7 @@ def num_epochs(self):
-------
int
Number of epochs.
"""
return self._obj.dims["time"]

@@ -558,6 +570,7 @@ def _ml_str_query(xx, yy, polygon, type_polygon):
An array with two columns. The first column is the positional index into the list of
polygons being used to query the tree. The second column is the positional index into
the list of space entries for which the tree was constructed.
"""
# Crop the polygon to the bounding box of the block
xmin, ymin, xmax, ymax = [
@@ -623,6 +636,7 @@ def _validate_coords(ds, xlabel, ylabel):
------
ValueError
If xlabel or ylabel neither exists in coordinates, raise ValueError
"""
for clabel in [xlabel, ylabel]:
if clabel not in ds.coords.keys():
@@ -655,6 +669,7 @@ def _compute_morton_code(xx, yy):
-------
array_like
An array with Morton codes per coordinate pair.
"""
code = [pm.interleave(int(xi), int(yi)) for xi, yi in zip(xx, yy, strict=True)]
return code
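_compute_morton_code above interleaves the bits of each integer coordinate pair with pymorton. A small sketch (illustrative values, not from the commit) of how the resulting codes give a single sortable key that keeps spatially nearby points adjacent:

import numpy as np
import pymorton as pm

xx = np.array([10, 11, 200])
yy = np.array([5, 6, 300])
codes = [pm.interleave(int(x), int(y)) for x, y in zip(xx, yy, strict=True)]
order = np.argsort(codes)  # positions sorted along the Z-order (Morton) curve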
@@ -670,8 +685,8 @@ def _enrich_from_raster_block(ds, dataraster, fields, method):
Parameters
----------
ds : xarray.Dataset
dataset : xarray.Dataset | xarray.DataArray
SpaceTimeMatrix to enrich
dataraster : xarray.Dataset | xarray.DataArray
Input data for enrichment
fields : str or list of str
Field name(s) in the dataset for enrichment
@@ -681,6 +696,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method):
Returns
-------
xarray.Dataset
"""
# interpolate the raster dataset to the coordinates of ds
interpolated = dataraster.interp(ds.coords, method=method)
@@ -699,7 +715,7 @@ def _enrich_from_points_block(ds, datapoints, fields):
Parameters
----------
ds : xarray.Dataset
SpaceTimeMatrix to enrich
datapoints : xarray.Dataset | xarray.DataArray
Input data for enrichment
fields : str or list of str
@@ -708,11 +724,16 @@ def _enrich_from_points_block(ds, datapoints, fields):
Returns
-------
xarray.Dataset
"""
# unstak the dimensions
for dim in datapoints.dims:
if dim not in datapoints.coords:
indexer = {dim: [coord for coord in datapoints.coords if dim in datapoints[coord].dims]}
indexer = {
dim: [
coord for coord in datapoints.coords if dim in datapoints[coord].dims
]
}
datapoints = datapoints.set_index(indexer)
datapoints = datapoints.unstack(dim)

@@ -722,6 +743,8 @@ def _enrich_from_points_block(ds, datapoints, fields):

# Assign these values to the corresponding points in ds
for field in fields:
ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords)
ds[field] = xr.DataArray(
selections[field].data.transpose(), dims=ds.dims, coords=ds.coords
)

return ds
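_enrich_from_points_block (and utils.crop below) rebuilds a MultiIndex from all coordinates that share a dimension and then unstacks it into separate dimensions. A self-contained sketch of that set_index/unstack pattern on a toy point dataset (all names and values are illustrative):

import numpy as np
import xarray as xr

points = xr.Dataset(
    {"temperature": ("points", np.arange(4.0))},
    coords={
        "lat": ("points", [0.0, 0.0, 1.0, 1.0]),
        "lon": ("points", [10.0, 11.0, 10.0, 11.0]),
    },
)
# Group the coordinates that share the "points" dimension into a MultiIndex,
# then unstack it into separate lat/lon dimensions.
indexer = {"points": [c for c in points.coords if "points" in points[c].dims]}
gridded = points.set_index(indexer).unstack("points")  # dims become lat x lon
print(float(gridded["temperature"].sel(lat=1.0, lon=10.0)))  # 2.0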
10 changes: 8 additions & 2 deletions stmtools/utils.py
@@ -1,6 +1,7 @@
import xarray as xr
from collections.abc import Iterable

import xarray as xr


def _has_property(ds, keys: str | Iterable):
if isinstance(keys, str):
@@ -27,6 +28,7 @@ def crop(ds, other, buffer):
-------
xarray.Dataset
Cropped dataset.
"""
if isinstance(ds, xr.DataArray):
ds = ds.to_dataset()
@@ -48,7 +50,11 @@ def crop(ds, other, buffer):
indexer = {}
for dim in other.dims:
if dim not in other.coords.keys():
indexer = {dim: [coord for coord in other.coords.keys() if dim in other.coords[coord].dims]}
indexer = {
dim: [
coord for coord in other.coords.keys() if dim in other.coords[coord].dims
]
}
other = other.set_index(indexer)
other = other.unstack(indexer)
