-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rewrote the subset selector for sgrid, added test based on example notebook #61
Changes from 1 commit
b52b12e
36df925
1e6d620
06b3843
02fc924
e8692f8
d17517a
338a1f7
83ad4af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import cf_xarray | ||
import fsspec | ||
import numpy as np | ||
import xarray as xr | ||
|
||
import xarray_subset_grid.accessor # noqa: F401 | ||
from xarray_subset_grid.utils import ray_tracing_numpy | ||
# open dataset as zarr object using fsspec reference file system and xarray | ||
def test_polygon_subset():
    '''
    Basic integration test for subsetting a ROMS sgrid dataset with a polygon.

    NOTE(review): depends on a live, anonymously readable S3 zarr reference;
    it will fail if that endpoint is unreachable.
    '''
    # Open the kerchunk'd WCOFS dataset through an fsspec reference filesystem.
    fs = fsspec.filesystem(
        "reference",
        fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
        remote_protocol="s3",
        remote_options={"anon": True},
        target_protocol="s3",
        target_options={"anon": True},
    )
    mapper = fs.get_mapper("")

    ds = xr.open_dataset(
        mapper, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={}
    )

    # Closed polygon (first point repeated last) off the California coast.
    polygon = np.array(
        [
            [-122.38488806417945, 34.98888604471138],
            [-122.02425311530737, 33.300351211467074],
            [-120.60402628930146, 32.723214427630836],
            [-116.63789131284673, 32.54346959375448],
            [-116.39346090873218, 33.8541384965596],
            [-118.83845767505964, 35.257586401855164],
            [-121.34541503969862, 35.50073821008141],
            [-122.38488806417945, 34.98888604471138],
        ]
    )
    ds_subset = ds.xsg.subset_vars(['temp_sur']).xsg.subset_polygon(polygon)

    # The subset must keep the eta/xi size relationships implied by the
    # original grid padding.
    sizes = ds_subset.sizes
    assert sizes['eta_rho'] == sizes['eta_psi'] + 1
    assert sizes['eta_u'] == sizes['eta_psi'] + 1
    assert sizes['eta_v'] == sizes['eta_psi']
    assert sizes['xi_rho'] == sizes['xi_psi'] + 1
    assert sizes['xi_u'] == sizes['xi_psi']
    assert sizes['xi_v'] == sizes['xi_psi'] + 1

    # Sanity-check the rho/psi positional relationship: a psi point should sit
    # 'between' its neighboring rho points.
    # Note that this needs to be better generalized; it's not trivial to write
    # a test that works in all potential cases.
    assert ds_subset['lon_rho'][0, 0] < ds_subset['lon_psi'][0, 0]
    assert ds_subset['lon_rho'][0, 1] > ds_subset['lon_psi'][0, 0]

    # Handy for eyeballing the result during development:
    # ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho")
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
|
||
from xarray_subset_grid.grid import Grid | ||
from xarray_subset_grid.selector import Selector | ||
from xarray_subset_grid.utils import compute_2d_subset_mask | ||
from xarray_subset_grid.utils import compute_2d_subset_mask, parse_padding_string | ||
|
||
|
||
class SGridSelector(Selector): | ||
|
@@ -106,42 +106,56 @@ def compute_polygon_subset_selector( | |
grid_topology_key = ds.cf.cf_roles["grid_topology"][0] | ||
grid_topology = ds[grid_topology_key] | ||
dims = _get_sgrid_dim_coord_names(grid_topology) | ||
|
||
subset_masks: list[tuple[list[str], xr.DataArray]] = [] | ||
for dim, coord in dims: | ||
# Get the variables that have the dimensions | ||
unique_dims = set(dim) | ||
vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))] | ||
|
||
# If the dataset has already been subset and there are no variables with | ||
# the dimensions, we can skip this dimension set | ||
if len(vars) == 0: | ||
continue | ||
|
||
# Get the coordinates for the dimension | ||
|
||
node_dims = grid_topology.attrs["node_dimensions"].split() | ||
node_coords = grid_topology.attrs["node_coordinates"].split() | ||
|
||
node_lon: xr.DataArray | None = None | ||
node_lat: xr.DataArray | None = None | ||
for c in node_coords: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm -- doesn't the spec say it should always be lon, lat order? -- but I suppose this is a bit safer anyway. |
||
if 'lon' in c: | ||
node_lon = ds[c] | ||
elif 'lat' in c: | ||
node_lat = ds[c] | ||
if node_lon is None or node_lat is None: | ||
raise ValueError(f"Could not find lon and lat for dimension {node_dims}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is specifically the node_coordinates -- yes? If so, the error message should be clear on that. |
||
|
||
node_mask = compute_2d_subset_mask(lat=node_lat, lon=node_lon, polygon=polygon) | ||
msk = np.where(node_mask) | ||
subset_masks.append(([node_coords[0], node_coords[1]], node_mask)) | ||
|
||
index_bounding_box = [[msk[0].min(), msk[0].max()], [msk[1].min(), msk[1].max()]] | ||
for s in ('face', 'edge1', 'edge2'): | ||
dims = grid_topology.attrs.get(f"{s}_dimensions", None) | ||
coords = grid_topology.attrs.get(f"{s}_coordinates", None).split() | ||
|
||
lon: xr.DataArray | None = None | ||
lat: xr.DataArray | None = None | ||
for c in coord: | ||
if "lon" in ds[c].attrs.get("standard_name", ""): | ||
for c in coords: | ||
if 'lon' in c: | ||
lon = ds[c] | ||
elif "lat" in ds[c].attrs.get("standard_name", ""): | ||
elif 'lat' in c: | ||
lat = ds[c] | ||
|
||
if lon is None or lat is None: | ||
raise ValueError(f"Could not find lon and lat for dimension {dim}") | ||
|
||
subset_mask = compute_2d_subset_mask(lat=lat, lon=lon, polygon=polygon) | ||
|
||
subset_masks.append((vars, subset_mask)) | ||
|
||
raise ValueError(f"Could not find lon and lat for dimension {dims}") | ||
padding = parse_padding_string(dims) | ||
arranged_padding = [padding[d] for d in lon.dims] | ||
arranged_padding = [0 if p == 'none' or p == 'low' else 1 for p in arranged_padding] | ||
mask = np.zeros(lon.shape, dtype=bool) | ||
mask[index_bounding_box[0][0]:index_bounding_box[0][1] + arranged_padding[0] + 1, | ||
index_bounding_box[1][0]:index_bounding_box[1][1] + arranged_padding[1] + 1] = True | ||
xr_mask = xr.DataArray(mask, dims=lon.dims) | ||
subset_masks.append(([coords[0], coords[1]], xr_mask)) | ||
|
||
return SGridSelector( | ||
name=name or 'selector', | ||
polygon=polygon, | ||
grid_topology_key=grid_topology_key, | ||
grid_topology=grid_topology, | ||
subset_masks=subset_masks, | ||
) | ||
|
||
|
||
def _get_sgrid_dim_coord_names( | ||
grid_topology: xr.DataArray, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,3 +151,29 @@ def compute_2d_subset_mask( | |
polygon_mask = np.where(polygon_mask > 1, True, False) | ||
|
||
return xr.DataArray(polygon_mask, dims=mask_dims) | ||
|
||
def parse_padding_string(dim_string):
    '''
    Given a grid_topology dimension string, parse the padding for each dimension.

    Accepted forms (per the SGRID convention attributes):
      "xi_rho: xi_psi (padding: both) eta_rho: eta_psi (padding: both)"  # both padded
      "xi_u: xi_psi eta_u: eta_psi (padding: both)"                      # one padded
      "xi_psi eta_psi"                                                   # node dims, no padding

    Returns a dict of {dim0name: padding,
                       dim1name: padding
                       }
    valid values of padding are: 'none', 'low', 'high', 'both'

    Raises ValueError if the string does not match any of the known layouts.
    '''
    padding_values = {'none', 'low', 'high', 'both'}
    # Strip the "(padding: X)" wrappers and the colons so only bare tokens remain.
    parsed_string = dim_string.replace('(padding: ', '').replace(')', '').replace(':', '')
    # .split() with no argument tolerates runs of whitespace; .split(' ') would
    # produce empty tokens on double spaces/tabs and mis-count the fields.
    tokens = parsed_string.split()
    if len(tokens) == 6:
        # Both dims padded: dim0 subdim0 pad0 dim1 subdim1 pad1
        return {tokens[0]: tokens[2], tokens[3]: tokens[5]}
    elif len(tokens) == 5:
        # Exactly one dim padded; the last token tells us which one.
        if tokens[4] in padding_values:
            # 2nd dim has padding, and with len 5 that means first does not
            tokens.insert(2, 'none')
        else:
            tokens.insert(5, 'none')
        return {tokens[0]: tokens[2], tokens[3]: tokens[5]}
    elif len(tokens) == 2:
        # node dimensions string could look like this: 'xi_psi eta_psi' -- never padded
        return {tokens[0]: 'none', tokens[1]: 'none'}
    else:
        raise ValueError(f"Padding parsing failure: {dim_string}")
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need a test (or more) that doesn't rely on that server being alive ....
See PR comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added the arakawa_c_test_grid from gridded for future tests. Will add another offline test soon
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!