Skip to content

Commit

Permalink
added grid_topo parsing function, fixed regression that lost vars
Browse files Browse the repository at this point in the history
  • Loading branch information
jay-hennen committed Sep 14, 2024
1 parent da491c0 commit daa5e91
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 86 deletions.
117 changes: 70 additions & 47 deletions tests/test_grids/test_sgrid.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,88 @@
import os

import fsspec
import numpy as np
import xarray as xr

import xarray_subset_grid.accessor # noqa: F401
from tests.test_utils import get_test_file_dir
from xarray_subset_grid.grids.sgrid import _get_location_info_from_topology

# open dataset as zarr object using fsspec reference file system and xarray


test_dir = get_test_file_dir()
sample_sgrid_file = os.path.join(test_dir, 'arakawa_c_test_grid.nc')

def test_polygon_subset():
'''
This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon.
'''
fs = fsspec.filesystem(
"reference",
fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
remote_protocol="s3",
remote_options={"anon": True},
target_protocol="s3",
target_options={"anon": True},
)
m = fs.get_mapper("")

ds = xr.open_dataset(
m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={}
)

polygon = np.array(
[
[-122.38488806417945, 34.98888604471138],
[-122.02425311530737, 33.300351211467074],
[-120.60402628930146, 32.723214427630836],
[-116.63789131284673, 32.54346959375448],
[-116.39346090873218, 33.8541384965596],
[-118.83845767505964, 35.257586401855164],
[-121.34541503969862, 35.50073821008141],
[-122.38488806417945, 34.98888604471138],
]
)
ds_temp = ds.xsg.subset_vars(['temp_sur'])
ds_subset = ds_temp.xsg.subset_polygon(polygon)
def test_grid_topology_location_parse():
ds = xr.open_dataset(sample_sgrid_file, decode_times=False)
node_info = _get_location_info_from_topology(ds['grid'], 'node')
edge1_info = _get_location_info_from_topology(ds['grid'], 'edge1')
edge2_info = _get_location_info_from_topology(ds['grid'], 'edge2')
face_info = _get_location_info_from_topology(ds['grid'], 'face')

#Check that the subset dataset has the correct dimensions given the original padding
assert ds_subset.sizes['eta_rho'] == ds_subset.sizes['eta_psi'] + 1
assert ds_subset.sizes['eta_u'] == ds_subset.sizes['eta_psi'] + 1
assert ds_subset.sizes['eta_v'] == ds_subset.sizes['eta_psi']
assert ds_subset.sizes['xi_rho'] == ds_subset.sizes['xi_psi'] + 1
assert ds_subset.sizes['xi_u'] == ds_subset.sizes['xi_psi']
assert ds_subset.sizes['xi_v'] == ds_subset.sizes['xi_psi'] + 1
assert node_info == {'dims': ['xi_psi', 'eta_psi'],
'coords': ['lon_psi', 'lat_psi'],
'padding': {'xi_psi': 'none', 'eta_psi': 'none'}}
assert edge1_info == {'dims': ['xi_u', 'eta_u'],
'coords': ['lon_u', 'lat_u'],
'padding': {'eta_u': 'both', 'xi_u': 'none'}}
assert edge2_info == {'dims': ['xi_v', 'eta_v'],
'coords': ['lon_v', 'lat_v'],
'padding': {'xi_v': 'both', 'eta_v': 'none'}}
assert face_info == {'dims': ['xi_rho', 'eta_rho'],
'coords': ['lon_rho', 'lat_rho'],
'padding': {'xi_rho': 'both', 'eta_rho': 'both'}}


# def test_polygon_subset():
# '''
# This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon.
# '''
# fs = fsspec.filesystem(
# "reference",
# fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
# remote_protocol="s3",
# remote_options={"anon": True},
# target_protocol="s3",
# target_options={"anon": True},
# )
# m = fs.get_mapper("")

# ds = xr.open_dataset(
# m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={}
# )

#Check that the subset rho/psi/u/v positional relationsip makes sense aka psi point is
#'between' it's neighbor rho points
#Note that this needs to be better generalized; it's not trivial to write a test that
#works in all potential cases.
assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0]
and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0])
# polygon = np.array(
# [
# [-122.38488806417945, 34.98888604471138],
# [-122.02425311530737, 33.300351211467074],
# [-120.60402628930146, 32.723214427630836],
# [-116.63789131284673, 32.54346959375448],
# [-116.39346090873218, 33.8541384965596],
# [-118.83845767505964, 35.257586401855164],
# [-121.34541503969862, 35.50073821008141],
# [-122.38488806417945, 34.98888604471138],
# ]
# )
# ds_temp = ds.xsg.subset_vars(['temp_sur'])
# ds_subset = ds_temp.xsg.subset_polygon(polygon)

#ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho")
# #Check that the subset dataset has the correct dimensions given the original padding
# assert ds_subset.sizes['eta_rho'] == ds_subset.sizes['eta_psi'] + 1
# assert ds_subset.sizes['eta_u'] == ds_subset.sizes['eta_psi'] + 1
# assert ds_subset.sizes['eta_v'] == ds_subset.sizes['eta_psi']
# assert ds_subset.sizes['xi_rho'] == ds_subset.sizes['xi_psi'] + 1
# assert ds_subset.sizes['xi_u'] == ds_subset.sizes['xi_psi']
# assert ds_subset.sizes['xi_v'] == ds_subset.sizes['xi_psi'] + 1

# #Check that the subset rho/psi/u/v positional relationsip makes sense aka psi point is
# #'between' it's neighbor rho points
# #Note that this needs to be better generalized; it's not trivial to write a test that
# #works in all potential cases.
# assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0]
# and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0])

# #ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho")

def test_polygon_subset_2():
ds = xr.open_dataset(sample_sgrid_file, decode_times=False)
Expand All @@ -84,3 +105,5 @@ def test_polygon_subset_2():

assert ds_subset.lon_psi.min() <= 6.5 and ds_subset.lon_psi.max() >= 9.5
assert ds_subset.lat_psi.min() <= 37.5 and ds_subset.lat_psi.max() >= 40.5

assert 'u' in ds_subset.variables.keys()
82 changes: 43 additions & 39 deletions xarray_subset_grid/grids/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ def compute_polygon_subset_selector(
dims = _get_sgrid_dim_coord_names(grid_topology)
subset_masks: list[tuple[list[str], xr.DataArray]] = []

node_dims = grid_topology.attrs["node_dimensions"].split()
node_coords = grid_topology.attrs["node_coordinates"].split()
node_info = _get_location_info_from_topology(grid_topology, 'node')
node_dims = node_info['dims']
node_coords = node_info['coords']

unique_dims = set(node_dims)
node_vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]

Expand All @@ -120,8 +122,6 @@ def compute_polygon_subset_selector(
node_lon = ds[c]
elif 'lat' in ds[c].standard_name.lower():
node_lat = ds[c]
if node_lon is None or node_lat is None:
raise ValueError(f"Could not find lon and lat for dimension {node_dims}")

node_mask = compute_2d_subset_mask(lat=node_lat, lon=node_lon, polygon=polygon)
msk = np.where(node_mask)
Expand All @@ -137,22 +137,19 @@ def compute_polygon_subset_selector(
index_bounding_box[1][0]:index_bounding_box[1][1]] = True

subset_masks.append((node_vars, node_mask))

for s in ('face', 'edge1', 'edge2'):
dims = grid_topology.attrs.get(f"{s}_dimensions", None)
coords = grid_topology.attrs.get(f"{s}_coordinates", None).split()
info = _get_location_info_from_topology(grid_topology, s)
dims = info['dims']
coords = info['coords']
unique_dims = set(dims)
vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]

lon: xr.DataArray | None = None
lat: xr.DataArray | None = None
for c in coords:
if 'lon' in ds[c].standard_name.lower():
lon = ds[c]
elif 'lat' in ds[c].standard_name.lower():
lat = ds[c]
if lon is None or lat is None:
raise ValueError(f"Could not find lon and lat for dimension {dims}")
padding = parse_padding_string(dims)
padding = info['padding']
arranged_padding = [padding[d] for d in lon.dims]
arranged_padding = [0 if p == 'none' or p == 'low' else 1 for p in arranged_padding]
mask = np.zeros(lon.shape, dtype=bool)
Expand All @@ -169,6 +166,40 @@ def compute_polygon_subset_selector(
subset_masks=subset_masks,
)

def _get_location_info_from_topology(grid_topology: xr.DataArray, location) -> dict[str, str]:
'''Get the dimensions and coordinates for a given location from the grid_topology'''
rdict = {}
dim_str = grid_topology.attrs.get(f"{location}_dimensions", None)
coord_str = grid_topology.attrs.get(f"{location}_coordinates", None)
if dim_str is None or coord_str is None:
raise ValueError(f"Could not find {location} dimensions or coordinates")
# Remove padding for now
dims_only = " ".join([v for v in dim_str.split(" ") if "(" not in v and ")" not in v])
if ":" in dims_only:
dims_only = [s.replace(":", "") for s in dims_only.split(" ") if ":" in s]
else:
dims_only = dims_only.split(" ")

padding = dim_str.replace(':', '').split(')')
pdict = {}
if len(padding) == 3: #two padding values
pdict[dims_only[0]] = padding[0].split(' ')[-1]
pdict[dims_only[1]] = padding[1].split(' ')[-1]
elif len(padding) == 2: #one padding value
if padding[-1] == '': #padding is on second dim
pdict[dims_only[1]] = padding[0].split(' ')[-1]
pdict[dims_only[0]] = 'none'
else:
pdict[dims_only[0]] = padding[0].split(' ')[-1]
pdict[dims_only[1]] = 'none'
else:
pdict[dims_only[0]] = 'none'
pdict[dims_only[1]] = 'none'

rdict['dims'] = dims_only
rdict['coords'] = coord_str.split(" ")
rdict['padding'] = pdict
return rdict

def _get_sgrid_dim_coord_names(
grid_topology: xr.DataArray,
Expand All @@ -193,30 +224,3 @@ def _get_sgrid_dim_coord_names(
coords.append(v.split(" "))

return list(zip(dims, coords))

def parse_padding_string(dim_string):
'''
Given a grid_topology dimension string, parse the padding for each dimension.
Returns a dict of {dim0name: padding,
dim1name: padding
}
valid values of padding are: 'none', 'low', 'high', 'both'
'''
parsed_string = dim_string.replace('(padding: ', '').replace(')', '').replace(':', '')
split_parsed_string = parsed_string.split(' ')
if len(split_parsed_string) == 6:
return {split_parsed_string[0]:split_parsed_string[2],
split_parsed_string[3]:split_parsed_string[5]}
elif len(split_parsed_string) == 5:
if split_parsed_string[4] in {'none', 'low', 'high', 'both'}:
#2nd dim has padding, and with len 5 that means first does not
split_parsed_string.insert(2, 'none')
else:
split_parsed_string.insert(5, 'none')
return {split_parsed_string[0]:split_parsed_string[2],
split_parsed_string[3]:split_parsed_string[5]}
elif len(split_parsed_string) == 2:
#node dimensions string could look like this: 'node_dimensions: xi_psi eta_psi'
return {split_parsed_string[0]: 'none', split_parsed_string[1]: 'none'}
else:
raise ValueError(f"Padding parsing failure: {dim_string}")

0 comments on commit daa5e91

Please sign in to comment.