From 339c77bd3e4016447b6c3d37c931357b3a2437a3 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Tue, 30 Jul 2024 09:49:31 -0400 Subject: [PATCH] Play with this method of doing things --- xarray_subset_grid/grids/sgrid.py | 28 +++++++++-------- xarray_subset_grid/grids/ugrid.py | 51 ++++++++++++++++--------------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/xarray_subset_grid/grids/sgrid.py b/xarray_subset_grid/grids/sgrid.py index bec9f3b..5b71a5c 100644 --- a/xarray_subset_grid/grids/sgrid.py +++ b/xarray_subset_grid/grids/sgrid.py @@ -8,6 +8,10 @@ class SGrid(Grid): """Grid implementation for SGRID datasets""" + _grid_topology_key: str + _grid_topology: xr.DataArray + _dims: list[tuple[list[str], list[str]]] + @staticmethod def recognize(ds: xr.Dataset) -> bool: """Recognize if the dataset matches the given grid""" @@ -20,6 +24,12 @@ def recognize(ds: xr.Dataset) -> bool: # we assume it's a SGRID return len(_grid_topology_keys) > 0 and _grid_topology_keys[0] in ds + def __init__(self, ds: xr.Dataset): + self._grid_topology_key = ds.cf.cf_roles["grid_topology"][0] + self._grid_topology = ds[self._grid_topology_key] + self._dims = _get_sgrid_dim_coord_names(self._grid_topology) + + @property def name(self) -> str: """Name of the grid type""" @@ -31,10 +41,8 @@ def grid_vars(self, ds: xr.Dataset) -> set[str]: These variables are used to define the grid and thus should be kept when subsetting the dataset """ - grid_topology_key = ds.cf.cf_roles["grid_topology"][0] - grid_topology = ds[grid_topology_key] - grid_coords = [grid_topology_key] - for _dims, coords in _get_sgrid_dim_coord_names(grid_topology): + grid_coords = [self._grid_topology_key] + for _dims, coords in self._dims: grid_coords.extend(coords) return set(grid_coords) @@ -45,10 +53,8 @@ def data_vars(self, ds: xr.Dataset) -> set[str]: data analysis. These can be discarded when subsetting the dataset when they are not needed. 
""" - grid_topology_key = ds.cf.cf_roles["grid_topology"][0] - grid_topology = ds[grid_topology_key] dims = [] - for dims, _coords in _get_sgrid_dim_coord_names(grid_topology): - dims.extend(dims) + for grid_dims, _coords in self._dims: + dims.extend(grid_dims) dims = set(dims) @@ -62,13 +68,9 @@ def subset_polygon( :param polygon: The polygon to subset to :return: The subsetted dataset """ - grid_topology_key = ds.cf.cf_roles["grid_topology"][0] - grid_topology = ds[grid_topology_key] - dims = _get_sgrid_dim_coord_names(grid_topology) - ds_out = [] - for dim, coord in dims: + for dim, coord in self._dims: # Get the variables that have the dimensions unique_dims = set(dim) vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))] @@ -107,7 +109,7 @@ def subset_polygon( # Merge the subsetted datasets ds_out = xr.merge(ds_out) - ds_out = ds_out.assign({grid_topology_key: grid_topology}) + ds_out = ds_out.assign({self._grid_topology_key: self._grid_topology}) return ds_out diff --git a/xarray_subset_grid/grids/ugrid.py b/xarray_subset_grid/grids/ugrid.py index e8b6ff4..5284fe3 100644 --- a/xarray_subset_grid/grids/ugrid.py +++ b/xarray_subset_grid/grids/ugrid.py @@ -45,6 +45,9 @@ class UGrid(Grid): # TODO: Abstract away common subsetting methods to functions that can be cached for reuse """ + _mesh_topology_key: str + _mesh_topology: xr.DataArray + @staticmethod def recognize(ds: xr.Dataset) -> bool: """Recognize if the dataset matches the given grid""" @@ -56,6 +59,10 @@ def recognize(ds: xr.Dataset) -> bool: return mesh.attrs.get("face_node_connectivity") is not None + def __init__(self, ds: xr.Dataset): + self._mesh_topology_key = ds.cf.cf_roles["mesh_topology"][0] + self._mesh_topology = ds[self._mesh_topology_key] + @property def name(self) -> str: """Name of the grid type""" @@ -68,15 +75,14 @@ def grid_vars(self, ds: xr.Dataset) -> set[str]: These variables are used to define the grid and thus should be kept when subsetting the dataset """ - mesh = ds.cf["mesh_topology"] - 
vars = {mesh.name} + vars = {self._mesh_topology.name} for var_name in ALL_MESH_VARS: - if var_name in mesh.attrs: + if var_name in self._mesh_topology.attrs: if "coordinates" in var_name: - _node_coordinates = mesh.node_coordinates.split(" ") - vars.update(mesh.attrs[var_name].split(" ")) + _node_coordinates = self._mesh_topology.node_coordinates.split(" ") + vars.update(self._mesh_topology.attrs[var_name].split(" ")) else: - vars.add(mesh.attrs[var_name]) + vars.add(self._mesh_topology.attrs[var_name]) return vars def data_vars(self, ds: xr.Dataset) -> set[str]: @@ -89,15 +95,14 @@ def data_vars(self, ds: xr.Dataset) -> set[str]: Then all grid_vars are excluded as well. """ - mesh = ds.cf["mesh_topology"] dims = [] # Use the coordinates as the source of truth, the face and node # dimensions are the same as the coordinates and any data variables # that do not contain either face or node dimensions can be ignored - face_coord = mesh.face_coordinates.split(" ")[0] + face_coord = self._mesh_topology.face_coordinates.split(" ")[0] dims.extend(ds[face_coord].dims) - node_coord = mesh.node_coordinates.split(" ")[0] + node_coord = self._mesh_topology.node_coordinates.split(" ")[0] dims.extend(ds[node_coord].dims) dims = set(dims) @@ -118,30 +123,26 @@ def subset_polygon( # For this grid type, we find all nodes that are connected to elements that are inside # the polygon. To do this, we first find all nodes that are inside the polygon and then # find all elements that are connected to those nodes. 
- try: - mesh = ds.cf["mesh_topology"] - except KeyError as err: - raise ValueError("Dataset has no mesh topology variable") from err - has_face_face_connectivity = "face_face_connectivity" in mesh.attrs - x_var, y_var = mesh.node_coordinates.split(" ") + has_face_face_connectivity = "face_face_connectivity" in self._mesh_topology.attrs + x_var, y_var = self._mesh_topology.node_coordinates.split(" ") x, y = ds[x_var], ds[y_var] node_dimension = x.dims[0] - face_dimension = mesh.attrs.get("face_dimension", None) + face_dimension = self._mesh_topology.attrs.get("face_dimension", None) if not face_dimension: raise ValueError("face_dimension is required to subset UGRID datasets") face_node_indices_dimension = next( - d for d in ds[mesh.face_node_connectivity].dims if d != face_dimension + d for d in ds[self._mesh_topology.face_node_connectivity].dims if d != face_dimension ) # NOTE: When the first dimension is face_dimension, the face_node_connectivity # is indexed by element first, then vertex. 
When the first dimension - if ds[mesh.face_node_connectivity].dims[0] == face_dimension: + if ds[self._mesh_topology.face_node_connectivity].dims[0] == face_dimension: transpose_face_node_connectivity = False - face_node_connectivity = ds[mesh.face_node_connectivity] + face_node_connectivity = ds[self._mesh_topology.face_node_connectivity] else: transpose_face_node_connectivity = True - face_node_connectivity = ds[mesh.face_node_connectivity].T + face_node_connectivity = ds[self._mesh_topology.face_node_connectivity].T face_node_start_index = face_node_connectivity.attrs.get("start_index", None) if not face_node_start_index: warnings.warn("No start_index found in face_node_connectivity, assuming 0") @@ -189,12 +190,12 @@ def subset_polygon( face_node_new = face_node_new.T if has_face_face_connectivity: - if ds[mesh.face_node_connectivity].dims[0] == face_dimension: + if ds[self._mesh_topology.face_node_connectivity].dims[0] == face_dimension: transpose_face_face_connectivity = False - face_face_connectivity = ds[mesh.face_face_connectivity] + face_face_connectivity = ds[self._mesh_topology.face_face_connectivity] else: transpose_face_face_connectivity = True - face_face_connectivity = ds[mesh.face_face_connectivity].T + face_face_connectivity = ds[self._mesh_topology.face_face_connectivity].T face_face_new = np.searchsorted( selected_elements, face_face_connectivity[selected_elements] ) @@ -205,9 +206,9 @@ def subset_polygon( # Subset using xarrays select indexing, and overwrite the face_node_connectivity # and face_face_connectivity (if available) with the new indices ds_subset = ds.sel({node_dimension: selected_nodes, face_dimension: selected_elements}) - ds_subset[mesh.face_node_connectivity][:] = face_node_new + ds_subset[self._mesh_topology.face_node_connectivity][:] = face_node_new if has_face_face_connectivity: - ds_subset[mesh.face_face_connectivity][:] = face_face_new + ds_subset[self._mesh_topology.face_face_connectivity][:] = face_face_new return 
ds_subset