Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrote the subset selector for sgrid, added test based on example notebook #61

Merged
merged 9 commits into from
Sep 13, 2024
54 changes: 54 additions & 0 deletions tests/test_grids/test_sgrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import cf_xarray
import fsspec
import numpy as np
import xarray as xr

import xarray_subset_grid.accessor # noqa: F401
from xarray_subset_grid.utils import ray_tracing_numpy
# open dataset as zarr object using fsspec reference file system and xarray
def test_polygon_subset():
    """
    Basic integration test for subsetting a ROMS sgrid dataset with a polygon.

    NOTE: this test depends on a live remote zarr reference store on S3 and
    will fail when that server is unreachable.
    """
    anon = {"anon": True}
    fs = fsspec.filesystem(
        "reference",
        fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
        remote_protocol="s3",
        remote_options=anon,
        target_protocol="s3",
        target_options=anon,
    )

    ds = xr.open_dataset(
        fs.get_mapper(""),
        engine="zarr",
        backend_kwargs=dict(consolidated=False),
        chunks={},
    )

    vertices = [
        [-122.38488806417945, 34.98888604471138],
        [-122.02425311530737, 33.300351211467074],
        [-120.60402628930146, 32.723214427630836],
        [-116.63789131284673, 32.54346959375448],
        [-116.39346090873218, 33.8541384965596],
        [-118.83845767505964, 35.257586401855164],
        [-121.34541503969862, 35.50073821008141],
        [-122.38488806417945, 34.98888604471138],
    ]
    polygon = np.array(vertices)

    ds_subset = ds.xsg.subset_vars(['temp_sur']).xsg.subset_polygon(polygon)

    # The rho/u/v grids must be offset from the psi grid according to the
    # padding declared in the original grid topology.
    sizes = ds_subset.sizes
    assert sizes['eta_rho'] == sizes['eta_psi'] + 1
    assert sizes['eta_u'] == sizes['eta_psi'] + 1
    assert sizes['eta_v'] == sizes['eta_psi']
    assert sizes['xi_rho'] == sizes['xi_psi'] + 1
    assert sizes['xi_u'] == sizes['xi_psi']
    assert sizes['xi_v'] == sizes['xi_psi'] + 1

    # Sanity-check the rho/psi positional relationship: a psi point should
    # sit 'between' its neighbouring rho points in longitude.
    # NOTE(review): this needs to be better generalized; it is not trivial to
    # write a check that works in all potential cases.
    lon_rho = ds_subset['lon_rho']
    lon_psi = ds_subset['lon_psi']
    assert lon_rho[0, 0] < lon_psi[0, 0] and lon_rho[0, 1] > lon_psi[0, 0]
62 changes: 38 additions & 24 deletions xarray_subset_grid/grids/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from xarray_subset_grid.grid import Grid
from xarray_subset_grid.selector import Selector
from xarray_subset_grid.utils import compute_2d_subset_mask
from xarray_subset_grid.utils import compute_2d_subset_mask, parse_padding_string


class SGridSelector(Selector):
Expand Down Expand Up @@ -106,42 +106,56 @@ def compute_polygon_subset_selector(
grid_topology_key = ds.cf.cf_roles["grid_topology"][0]
grid_topology = ds[grid_topology_key]
dims = _get_sgrid_dim_coord_names(grid_topology)

subset_masks: list[tuple[list[str], xr.DataArray]] = []
for dim, coord in dims:
# Get the variables that have the dimensions
unique_dims = set(dim)
vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]

# If the dataset has already been subset and there are no variables with
# the dimensions, we can skip this dimension set
if len(vars) == 0:
continue

# Get the coordinates for the dimension

node_dims = grid_topology.attrs["node_dimensions"].split()
node_coords = grid_topology.attrs["node_coordinates"].split()

node_lon: xr.DataArray | None = None
node_lat: xr.DataArray | None = None
for c in node_coords:
Copy link
Collaborator

@ChrisBarker-NOAA ChrisBarker-NOAA Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm -- doesn't the spec say it should always be lon, lat order? -- but I suppose this is a bit safer anyway.

if 'lon' in c:
node_lon = ds[c]
elif 'lat' in c:
node_lat = ds[c]
if node_lon is None or node_lat is None:
raise ValueError(f"Could not find lon and lat for dimension {node_dims}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is specifically the node_coordinates -- yes? If so, the error message should be clear on that.


node_mask = compute_2d_subset_mask(lat=node_lat, lon=node_lon, polygon=polygon)
msk = np.where(node_mask)
subset_masks.append(([node_coords[0], node_coords[1]], node_mask))

index_bounding_box = [[msk[0].min(), msk[0].max()], [msk[1].min(), msk[1].max()]]
for s in ('face', 'edge1', 'edge2'):
dims = grid_topology.attrs.get(f"{s}_dimensions", None)
coords = grid_topology.attrs.get(f"{s}_coordinates", None).split()

lon: xr.DataArray | None = None
lat: xr.DataArray | None = None
for c in coord:
if "lon" in ds[c].attrs.get("standard_name", ""):
for c in coords:
if 'lon' in c:
lon = ds[c]
elif "lat" in ds[c].attrs.get("standard_name", ""):
elif 'lat' in c:
lat = ds[c]

if lon is None or lat is None:
raise ValueError(f"Could not find lon and lat for dimension {dim}")

subset_mask = compute_2d_subset_mask(lat=lat, lon=lon, polygon=polygon)

subset_masks.append((vars, subset_mask))

raise ValueError(f"Could not find lon and lat for dimension {dims}")
padding = parse_padding_string(dims)
arranged_padding = [padding[d] for d in lon.dims]
arranged_padding = [0 if p == 'none' or p == 'low' else 1 for p in arranged_padding]
mask = np.zeros(lon.shape, dtype=bool)
mask[index_bounding_box[0][0]:index_bounding_box[0][1] + arranged_padding[0] + 1,
index_bounding_box[1][0]:index_bounding_box[1][1] + arranged_padding[1] + 1] = True
xr_mask = xr.DataArray(mask, dims=lon.dims)
subset_masks.append(([coords[0], coords[1]], xr_mask))

return SGridSelector(
name=name or 'selector',
polygon=polygon,
grid_topology_key=grid_topology_key,
grid_topology=grid_topology,
subset_masks=subset_masks,
)


def _get_sgrid_dim_coord_names(
grid_topology: xr.DataArray,
Expand Down
26 changes: 26 additions & 0 deletions xarray_subset_grid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,29 @@ def compute_2d_subset_mask(
polygon_mask = np.where(polygon_mask > 1, True, False)

return xr.DataArray(polygon_mask, dims=mask_dims)

def parse_padding_string(dim_string):
    """
    Parse the padding out of a grid_topology dimension attribute string.

    Accepted forms (per the SGRID conventions):

    * face/edge dimensions, both padded::

        "xi_rho: xi_psi (padding: both) eta_rho: eta_psi (padding: both)"

    * face/edge dimensions, one padded (either side may carry the padding)::

        "xi_u: xi_psi (padding: both) eta_u: eta_psi"

    * node dimensions (never padded)::

        "xi_psi eta_psi"

    :param dim_string: value of a ``*_dimensions`` attribute on the
        grid_topology variable (the attribute value only, not the name).
    :return: dict of ``{dim0_name: padding, dim1_name: padding}`` where
        padding is one of ``'none'``, ``'low'``, ``'high'``, ``'both'``.
    :raises ValueError: if the string does not match any accepted form.
    """
    padding_values = {'none', 'low', 'high', 'both'}

    # Strip the "(padding: X)" decoration down to bare tokens, e.g.
    # "xi_rho: xi_psi (padding: both)" -> "xi_rho xi_psi both"
    cleaned = dim_string.replace('(padding: ', '').replace(')', '').replace(':', '')
    # split() (not split(' ')) so runs of whitespace in the attribute
    # value don't produce empty tokens and derail the token count below.
    tokens = cleaned.split()

    if len(tokens) == 6:
        # Both dimensions carry padding: dim0 sub0 pad0 dim1 sub1 pad1
        return {tokens[0]: tokens[2], tokens[3]: tokens[5]}
    if len(tokens) == 5:
        # Exactly one dimension carries padding; decide which one by
        # checking whether the final token is a padding keyword.
        if tokens[4] in padding_values:
            tokens.insert(2, 'none')  # first dim has no padding
        else:
            tokens.insert(5, 'none')  # second dim has no padding
        return {tokens[0]: tokens[2], tokens[3]: tokens[5]}
    if len(tokens) == 2:
        # node_dimensions value, e.g. "xi_psi eta_psi": never padded.
        return {tokens[0]: 'none', tokens[1]: 'none'}
    raise ValueError(f"Padding parsing failure: {dim_string}")