diff --git a/tests/test_grids/test_sgrid.py b/tests/test_grids/test_sgrid.py index 50bbf82..a76657f 100644 --- a/tests/test_grids/test_sgrid.py +++ b/tests/test_grids/test_sgrid.py @@ -1,11 +1,11 @@ import os -import fsspec import numpy as np import xarray as xr import xarray_subset_grid.accessor # noqa: F401 from tests.test_utils import get_test_file_dir +from xarray_subset_grid.grids.sgrid import _get_location_info_from_topology # open dataset as zarr object using fsspec reference file system and xarray @@ -13,55 +13,76 @@ test_dir = get_test_file_dir() sample_sgrid_file = os.path.join(test_dir, 'arakawa_c_test_grid.nc') -def test_polygon_subset(): - ''' - This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon. - ''' - fs = fsspec.filesystem( - "reference", - fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr", - remote_protocol="s3", - remote_options={"anon": True}, - target_protocol="s3", - target_options={"anon": True}, - ) - m = fs.get_mapper("") - - ds = xr.open_dataset( - m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={} - ) - - polygon = np.array( - [ - [-122.38488806417945, 34.98888604471138], - [-122.02425311530737, 33.300351211467074], - [-120.60402628930146, 32.723214427630836], - [-116.63789131284673, 32.54346959375448], - [-116.39346090873218, 33.8541384965596], - [-118.83845767505964, 35.257586401855164], - [-121.34541503969862, 35.50073821008141], - [-122.38488806417945, 34.98888604471138], - ] - ) - ds_temp = ds.xsg.subset_vars(['temp_sur']) - ds_subset = ds_temp.xsg.subset_polygon(polygon) +def test_grid_topology_location_parse(): + ds = xr.open_dataset(sample_sgrid_file, decode_times=False) + node_info = _get_location_info_from_topology(ds['grid'], 'node') + edge1_info = _get_location_info_from_topology(ds['grid'], 'edge1') + edge2_info = _get_location_info_from_topology(ds['grid'], 'edge2') + face_info = _get_location_info_from_topology(ds['grid'], 'face') - #Check that the subset dataset has the correct dimensions given the original padding - assert ds_subset.sizes['eta_rho'] == ds_subset.sizes['eta_psi'] + 1 - assert ds_subset.sizes['eta_u'] == ds_subset.sizes['eta_psi'] + 1 - assert ds_subset.sizes['eta_v'] == ds_subset.sizes['eta_psi'] - assert ds_subset.sizes['xi_rho'] == ds_subset.sizes['xi_psi'] + 1 - assert ds_subset.sizes['xi_u'] == ds_subset.sizes['xi_psi'] - assert ds_subset.sizes['xi_v'] == ds_subset.sizes['xi_psi'] + 1 + assert node_info == {'dims': ['xi_psi', 'eta_psi'], + 'coords': ['lon_psi', 'lat_psi'], + 'padding': {'xi_psi': 'none', 'eta_psi': 'none'}} + assert edge1_info == {'dims': ['xi_u', 'eta_u'], + 'coords': ['lon_u', 'lat_u'], + 'padding': {'eta_u': 'both', 'xi_u': 'none'}} + assert edge2_info == {'dims': ['xi_v', 'eta_v'], + 'coords': ['lon_v', 'lat_v'], + 'padding': {'xi_v': 'both', 'eta_v': 'none'}} + assert face_info == {'dims': ['xi_rho', 'eta_rho'], + 'coords': ['lon_rho', 'lat_rho'], + 'padding': {'xi_rho': 'both', 'eta_rho': 'both'}} + + +# def test_polygon_subset(): +# ''' +# This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon. +# ''' +# fs = fsspec.filesystem( +# "reference", +# fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr", +# remote_protocol="s3", +# remote_options={"anon": True}, +# target_protocol="s3", +# target_options={"anon": True}, +# ) +# m = fs.get_mapper("") + +# ds = xr.open_dataset( +# m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={} +# ) - #Check that the subset rho/psi/u/v positional relationsip makes sense aka psi point is - #'between' it's neighbor rho points - #Note that this needs to be better generalized; it's not trivial to write a test that - #works in all potential cases. - assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0] - and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0]) +# polygon = np.array( +# [ +# [-122.38488806417945, 34.98888604471138], +# [-122.02425311530737, 33.300351211467074], +# [-120.60402628930146, 32.723214427630836], +# [-116.63789131284673, 32.54346959375448], +# [-116.39346090873218, 33.8541384965596], +# [-118.83845767505964, 35.257586401855164], +# [-121.34541503969862, 35.50073821008141], +# [-122.38488806417945, 34.98888604471138], +# ] +# ) +# ds_temp = ds.xsg.subset_vars(['temp_sur']) +# ds_subset = ds_temp.xsg.subset_polygon(polygon) - #ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho") +# #Check that the subset dataset has the correct dimensions given the original padding +# assert ds_subset.sizes['eta_rho'] == ds_subset.sizes['eta_psi'] + 1 +# assert ds_subset.sizes['eta_u'] == ds_subset.sizes['eta_psi'] + 1 +# assert ds_subset.sizes['eta_v'] == ds_subset.sizes['eta_psi'] +# assert ds_subset.sizes['xi_rho'] == ds_subset.sizes['xi_psi'] + 1 +# assert ds_subset.sizes['xi_u'] == ds_subset.sizes['xi_psi'] +# assert ds_subset.sizes['xi_v'] == ds_subset.sizes['xi_psi'] + 1 + +# #Check that the subset rho/psi/u/v positional relationsip makes sense aka psi point is +# #'between' it's neighbor rho points +# #Note that this needs to be better generalized; it's not trivial to write a test that +# #works in all potential cases. +# assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0] +# and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0]) + +# #ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho") def test_polygon_subset_2(): ds = xr.open_dataset(sample_sgrid_file, decode_times=False) @@ -84,3 +105,5 @@ def test_polygon_subset_2(): assert ds_subset.lon_psi.min() <= 6.5 and ds_subset.lon_psi.max() >= 9.5 assert ds_subset.lat_psi.min() <= 37.5 and ds_subset.lat_psi.max() >= 40.5 + + assert 'u' in ds_subset.variables.keys() diff --git a/xarray_subset_grid/grids/sgrid.py b/xarray_subset_grid/grids/sgrid.py index d0ca38d..3ea9abb 100644 --- a/xarray_subset_grid/grids/sgrid.py +++ b/xarray_subset_grid/grids/sgrid.py @@ -108,8 +108,10 @@ def compute_polygon_subset_selector( dims = _get_sgrid_dim_coord_names(grid_topology) subset_masks: list[tuple[list[str], xr.DataArray]] = [] - node_dims = grid_topology.attrs["node_dimensions"].split() - node_coords = grid_topology.attrs["node_coordinates"].split() + node_info = _get_location_info_from_topology(grid_topology, 'node') + node_dims = node_info['dims'] + node_coords = node_info['coords'] + unique_dims = set(node_dims) node_vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))] @@ -120,8 +122,6 @@ def compute_polygon_subset_selector( node_lon = ds[c] elif 'lat' in ds[c].standard_name.lower(): node_lat = ds[c] - if node_lon is None or node_lat is None: - raise ValueError(f"Could not find lon and lat for dimension {node_dims}") node_mask = compute_2d_subset_mask(lat=node_lat, lon=node_lon, polygon=polygon) msk = np.where(node_mask) @@ -137,22 +137,19 @@ def compute_polygon_subset_selector( index_bounding_box[1][0]:index_bounding_box[1][1]] = True subset_masks.append((node_vars, node_mask)) + for s in ('face', 'edge1', 'edge2'): - dims = grid_topology.attrs.get(f"{s}_dimensions", None) - coords = grid_topology.attrs.get(f"{s}_coordinates", None).split() + info = _get_location_info_from_topology(grid_topology, s) + dims = info['dims'] + coords = info['coords'] unique_dims = set(dims) vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))] lon: xr.DataArray | None = None - lat: xr.DataArray | None = None for c in coords: if 'lon' in ds[c].standard_name.lower(): lon = ds[c] - elif 'lat' in ds[c].standard_name.lower(): - lat = ds[c] - if lon is None or lat is None: - raise ValueError(f"Could not find lon and lat for dimension {dims}") - padding = parse_padding_string(dims) + padding = info['padding'] arranged_padding = [padding[d] for d in lon.dims] arranged_padding = [0 if p == 'none' or p == 'low' else 1 for p in arranged_padding] mask = np.zeros(lon.shape, dtype=bool) @@ -169,6 +166,40 @@ def compute_polygon_subset_selector( subset_masks=subset_masks, ) +def _get_location_info_from_topology(grid_topology: xr.DataArray, location) -> dict[str, str]: + '''Get the dimensions and coordinates for a given location from the grid_topology''' + rdict = {} + dim_str = grid_topology.attrs.get(f"{location}_dimensions", None) + coord_str = grid_topology.attrs.get(f"{location}_coordinates", None) + if dim_str is None or coord_str is None: + raise ValueError(f"Could not find {location} dimensions or coordinates") + # Remove padding for now + dims_only = " ".join([v for v in dim_str.split(" ") if "(" not in v and ")" not in v]) + if ":" in dims_only: + dims_only = [s.replace(":", "") for s in dims_only.split(" ") if ":" in s] + else: + dims_only = dims_only.split(" ") + + padding = dim_str.replace(':', '').split(')') + pdict = {} + if len(padding) == 3: #two padding values + pdict[dims_only[0]] = padding[0].split(' ')[-1] + pdict[dims_only[1]] = padding[1].split(' ')[-1] + elif len(padding) == 2: #one padding value + if padding[-1] == '': #padding is on second dim + pdict[dims_only[1]] = padding[0].split(' ')[-1] + pdict[dims_only[0]] = 'none' + else: + pdict[dims_only[0]] = padding[0].split(' ')[-1] + pdict[dims_only[1]] = 'none' + else: + pdict[dims_only[0]] = 'none' + pdict[dims_only[1]] = 'none' + + rdict['dims'] = dims_only + rdict['coords'] = coord_str.split(" ") + rdict['padding'] = pdict + return rdict def _get_sgrid_dim_coord_names( grid_topology: xr.DataArray, @@ -193,30 +224,3 @@ def _get_sgrid_dim_coord_names( coords.append(v.split(" ")) return list(zip(dims, coords)) - -def parse_padding_string(dim_string): - ''' - Given a grid_topology dimension string, parse the padding for each dimension. - Returns a dict of {dim0name: padding, - dim1name: padding - } - valid values of padding are: 'none', 'low', 'high', 'both' - ''' - parsed_string = dim_string.replace('(padding: ', '').replace(')', '').replace(':', '') - split_parsed_string = parsed_string.split(' ') - if len(split_parsed_string) == 6: - return {split_parsed_string[0]:split_parsed_string[2], - split_parsed_string[3]:split_parsed_string[5]} - elif len(split_parsed_string) == 5: - if split_parsed_string[4] in {'none', 'low', 'high', 'both'}: - #2nd dim has padding, and with len 5 that means first does not - split_parsed_string.insert(2, 'none') - else: - split_parsed_string.insert(5, 'none') - return {split_parsed_string[0]:split_parsed_string[2], - split_parsed_string[3]:split_parsed_string[5]} - elif len(split_parsed_string) == 2: - #node dimensions string could look like this: 'node_dimensions: xi_psi eta_psi' - return {split_parsed_string[0]: 'none', split_parsed_string[1]: 'none'} - else: - raise ValueError(f"Padding parsing failure: {dim_string}")