Skip to content

Commit

Permalink
Play with this method of doing things
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiannucci committed Jul 30, 2024
1 parent abb551a commit 339c77b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 38 deletions.
28 changes: 15 additions & 13 deletions xarray_subset_grid/grids/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
class SGrid(Grid):
"""Grid implementation for SGRID datasets"""

_grid_topology_key: str
_grid_topology: xr.DataArray
_dims: list[tuple[list[str], list[str]]]

@staticmethod
def recognize(ds: xr.Dataset) -> bool:
"""Recognize if the dataset matches the given grid"""
Expand All @@ -20,6 +24,12 @@ def recognize(ds: xr.Dataset) -> bool:
# we assume it's a SGRID
return len(_grid_topology_keys) > 0 and _grid_topology_keys[0] in ds

def __init__(self, ds: xr.Dataset):
    """Capture the SGRID topology of the dataset for reuse by other methods.

    :param ds: Dataset whose first ``grid_topology`` CF role names the
        grid topology variable
    """
    topology_key = ds.cf.cf_roles["grid_topology"][0]
    topology = ds[topology_key]

    self._grid_topology_key = topology_key
    self._grid_topology = topology
    # (dimension names, coordinate names) pairs derived from the topology variable
    self._dims = _get_sgrid_dim_coord_names(topology)


@property
def name(self) -> str:
"""Name of the grid type"""
Expand All @@ -31,10 +41,8 @@ def grid_vars(self, ds: xr.Dataset) -> set[str]:
These variables are used to define the grid and thus should be kept
when subsetting the dataset
"""
grid_topology_key = ds.cf.cf_roles["grid_topology"][0]
grid_topology = ds[grid_topology_key]
grid_coords = [grid_topology_key]
for _dims, coords in _get_sgrid_dim_coord_names(grid_topology):
grid_coords = [self._grid_topology_key]
for _dims, coords in self._dims:
grid_coords.extend(coords)
return set(grid_coords)

Expand All @@ -45,10 +53,8 @@ def data_vars(self, ds: xr.Dataset) -> set[str]:
data analysis. These can be discarded when subsetting the dataset
when they are not needed.
"""
grid_topology_key = ds.cf.cf_roles["grid_topology"][0]
grid_topology = ds[grid_topology_key]
dims = []
for dims, _coords in _get_sgrid_dim_coord_names(grid_topology):
for dims, _coords in self._dims:
dims.extend(dims)
dims = set(dims)

Expand All @@ -62,13 +68,9 @@ def subset_polygon(
:param polygon: The polygon to subset to
:return: The subsetted dataset
"""
grid_topology_key = ds.cf.cf_roles["grid_topology"][0]
grid_topology = ds[grid_topology_key]
dims = _get_sgrid_dim_coord_names(grid_topology)

ds_out = []

for dim, coord in dims:
for dim, coord in self._dims:
# Get the variables that have the dimensions
unique_dims = set(dim)
vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]
Expand Down Expand Up @@ -107,7 +109,7 @@ def subset_polygon(
# Merge the subsetted datasets
ds_out = xr.merge(ds_out)

ds_out = ds_out.assign({grid_topology_key: grid_topology})
ds_out = ds_out.assign({self._grid_topology_key: self._grid_topology})

return ds_out

Expand Down
51 changes: 26 additions & 25 deletions xarray_subset_grid/grids/ugrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class UGrid(Grid):
# TODO: Abstract away common subsetting methods to functions that can be cached for reuse
"""

_mesh_topology_key: str
_mesh_topology: xr.DataArray

@staticmethod
def recognize(ds: xr.Dataset) -> bool:
"""Recognize if the dataset matches the given grid"""
Expand All @@ -56,6 +59,10 @@ def recognize(ds: xr.Dataset) -> bool:

return mesh.attrs.get("face_node_connectivity") is not None

def __init__(self, ds: xr.Dataset):
    """Capture the UGRID mesh topology of the dataset for reuse by other methods.

    :param ds: Dataset whose first ``mesh_topology`` CF role names the
        mesh topology variable
    """
    topology_key = ds.cf.cf_roles["mesh_topology"][0]

    self._mesh_topology_key = topology_key
    self._mesh_topology = ds[topology_key]

@property
def name(self) -> str:
"""Name of the grid type"""
Expand All @@ -68,15 +75,14 @@ def grid_vars(self, ds: xr.Dataset) -> set[str]:
These variables are used to define the grid and thus should be kept
when subsetting the dataset
"""
mesh = ds.cf["mesh_topology"]
vars = {mesh.name}
vars = {self._mesh_topology.mesh.name}
for var_name in ALL_MESH_VARS:
if var_name in mesh.attrs:
if var_name in self._mesh_topology.attrs:
if "coordinates" in var_name:
_node_coordinates = mesh.node_coordinates.split(" ")
vars.update(mesh.attrs[var_name].split(" "))
_node_coordinates = self._mesh_topology.node_coordinates.split(" ")
vars.update(self._mesh_topology.attrs[var_name].split(" "))
else:
vars.add(mesh.attrs[var_name])
vars.add(self._mesh_topology.attrs[var_name])
return vars

def data_vars(self, ds: xr.Dataset) -> set[str]:
Expand All @@ -89,15 +95,14 @@ def data_vars(self, ds: xr.Dataset) -> set[str]:
Then all grid_vars are excluded as well.
"""
mesh = ds.cf["mesh_topology"]
dims = []

# Use the coordinates as the source of truth, the face and node
# dimensions are the same as the coordinates and any data variables
# that do not contain either face or node dimensions can be ignored
face_coord = mesh.face_coordinates.split(" ")[0]
face_coord = self._mesh_topology.face_coordinates.split(" ")[0]
dims.extend(ds[face_coord].dims)
node_coord = mesh.node_coordinates.split(" ")[0]
node_coord = self._mesh_topology.node_coordinates.split(" ")[0]
dims.extend(ds[node_coord].dims)

dims = set(dims)
Expand All @@ -118,30 +123,26 @@ def subset_polygon(
# For this grid type, we find all nodes that are connected to elements that are inside
# the polygon. To do this, we first find all nodes that are inside the polygon and then
# find all elements that are connected to those nodes.
try:
mesh = ds.cf["mesh_topology"]
except KeyError as err:
raise ValueError("Dataset has no mesh topology variable") from err
has_face_face_connectivity = "face_face_connectivity" in mesh.attrs
x_var, y_var = mesh.node_coordinates.split(" ")
has_face_face_connectivity = "face_face_connectivity" in self._mesh_topology.attrs
x_var, y_var = self._mesh_topology.node_coordinates.split(" ")
x, y = ds[x_var], ds[y_var]
node_dimension = x.dims[0]

face_dimension = mesh.attrs.get("face_dimension", None)
face_dimension = self._mesh_topology.attrs.get("face_dimension", None)
if not face_dimension:
raise ValueError("face_dimension is required to subset UGRID datasets")
face_node_indices_dimension = next(
d for d in ds[mesh.face_node_connectivity].dims if d != face_dimension
d for d in ds[self._mesh_topology.face_node_connectivity].dims if d != face_dimension
)

# NOTE: When the first dimension is face_dimension, the face_node_connectivity
# is indexed by element first, then vertex. When the first dimension is not
# face_dimension, the array is transposed so that it is indexed the same way.
if ds[mesh.face_node_connectivity].dims[0] == face_dimension:
if ds[self._mesh_topology.face_node_connectivity].dims[0] == face_dimension:
transpose_face_node_connectivity = False
face_node_connectivity = ds[mesh.face_node_connectivity]
face_node_connectivity = ds[self._mesh_topology.face_node_connectivity]
else:
transpose_face_node_connectivity = True
face_node_connectivity = ds[mesh.face_node_connectivity].T
face_node_connectivity = ds[self._mesh_topology.face_node_connectivity].T
face_node_start_index = face_node_connectivity.attrs.get("start_index", None)
if not face_node_start_index:
warnings.warn("No start_index found in face_node_connectivity, assuming 0")
Expand Down Expand Up @@ -189,12 +190,12 @@ def subset_polygon(
face_node_new = face_node_new.T

if has_face_face_connectivity:
if ds[mesh.face_node_connectivity].dims[0] == face_dimension:
if ds[self._mesh_topology.face_node_connectivity].dims[0] == face_dimension:
transpose_face_face_connectivity = False
face_face_connectivity = ds[mesh.face_face_connectivity]
face_face_connectivity = ds[self._mesh_topology.face_face_connectivity]
else:
transpose_face_face_connectivity = True
face_face_connectivity = ds[mesh.face_face_connectivity].T
face_face_connectivity = ds[self._mesh_topology.face_face_connectivity].T
face_face_new = np.searchsorted(
selected_elements, face_face_connectivity[selected_elements]
)
Expand All @@ -205,9 +206,9 @@ def subset_polygon(
# Subset using xarrays select indexing, and overwrite the face_node_connectivity
# and face_face_connectivity (if available) with the new indices
ds_subset = ds.sel({node_dimension: selected_nodes, face_dimension: selected_elements})
ds_subset[mesh.face_node_connectivity][:] = face_node_new
ds_subset[self._mesh_topology.face_node_connectivity][:] = face_node_new
if has_face_face_connectivity:
ds_subset[mesh.face_face_connectivity][:] = face_face_new
ds_subset[self._mesh_topology.face_face_connectivity][:] = face_face_new
return ds_subset


Expand Down

0 comments on commit 339c77b

Please sign in to comment.