diff --git a/ocean/transects/python/README.md b/ocean/transects/python/README.md new file mode 100644 index 00000000..0016751f --- /dev/null +++ b/ocean/transects/python/README.md @@ -0,0 +1,57 @@ +# Python Transect Tools + +## compute_transects.py + +Computes transport through sections. + +Example call: +``` + ./compute_transects.py + -k transect_masks.nc + -m MPAS_mesh.nc + -t 'RUN_PATH/analysis_members/timeSeriesStatsMonthly.*.nc' + -n 'all' +``` +To create the `transect_masks.nc` file, load e3sm-unified and: +``` + MpasMaskCreator.x MPAS_mesh.nc transect_masks.nc -f transect_definitions.geojson +``` +where the `transect_definitions.geojson` file includes a sequence of lat/lon +points for each transect. + +On LANL IC, example file is at +``` +/usr/projects/climate/mpeterse/analysis_input_files/geojson_files/SingleRegionAtlanticWTransportTransects.geojson +``` + +## create_transect_masks.py + +Requires a conda environment with: +* `python` +* `geometric_features` +* `matplotlib` +* `mpas_tools` +* `netcdf4` +* `numpy` +* `scipy` +* `shapely` +* `xarray` + +The tools creates cell and edge masks, distance along the transect of cells +and edges in the mask, and the edge sign on edges. It also includes +information (distance, cell and edge indices, interpolation weights, etc.) +along the transect itself to aid plotting. + +The required inputs are an MPAS mesh file and a geojson file or the name of an +ocean transect from `geometric_features`. The required output is a filename +with the masks and other information about the transect. + +## cut_closed_transect.py + +If a transect is a closed loop, the path of edges and edge signs don't work +correctly (the shortest path between the beginning and end of the transect is +trivial and involves a single edge). To avoid this, we provide a tool for +cutting a square (in lat/lon space) out of the transect to sever the loop. +The user provides a latitude and longitude (used to locate the closest point) +on the transect and the size of the square to cut out. + diff --git a/ocean/transects/python/create_transect_masks.py b/ocean/transects/python/create_transect_masks.py new file mode 100755 index 00000000..e652ad80 --- /dev/null +++ b/ocean/transects/python/create_transect_masks.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python + +import argparse + +import numpy as np +import xarray as xr +from geometric_features import read_feature_collection +from geometric_features import GeometricFeatures +from mpas_tools.cime.constants import constants +from mpas_tools.logging import LoggingContext +from mpas_tools.io import write_netcdf +from mpas_tools.mesh.mask import compute_mpas_transect_masks +from mpas_tools.parallel import create_pool + +from transect.vert import compute_transect + + +def combine_transect_datasets(ds_mesh, fc_transect, out_filename, pool, + logger): + """ + Combine transects masks on cells and edges with a dataset for plotting + that describes how the transect slices through cell and edge geometry. + Add fields on edges and cells that define the (mean) distance along the + transect for each cell or edge in the transect + """ + + earth_radius = constants['SHR_CONST_REARTH'] + + ds_mask = compute_mpas_transect_masks(dsMesh=ds_mesh, fcMask=fc_transect, + earthRadius=earth_radius, + maskTypes=('cell', 'edge',), + logger=logger, + pool=pool, addEdgeSign=True) + + feature = fc_transect.features[0] + geom_type = feature['geometry']['type'] + if geom_type == 'LineString': + coordinates = [feature['geometry']['coordinates']] + elif geom_type == 'MultiLineString': + coordinates = feature['geometry']['coordinates'] + else: + raise ValueError( + f'Unexpected geometry type for the transect {geom_type}') + + lon = [] + lat = [] + for coords in coordinates: + lon_local, lat_local = zip(*coords) + lon.extend(lon_local) + lat.extend(lat_local) + lon = xr.DataArray(data=lon, dims='nTransectPoints') + lat = xr.DataArray(data=lat, dims='nTransectPoints') + + layer_thickness = ds_mesh.layerThickness + bottom_depth = ds_mesh.bottomDepth + min_level_cell = ds_mesh.minLevelCell + max_level_cell = ds_mesh.maxLevelCell + + ds_transect = compute_transect(lon, lat, ds_mesh, layer_thickness, + bottom_depth, min_level_cell, + max_level_cell, spherical=True) + + ds = ds_mask + for var in ds_transect.data_vars: + ds[var] = ds_transect[var] + + add_distance_field(ds, logger) + + write_netcdf(ds, out_filename) + + +def add_distance_field(ds, logger): + """ + Add fields on edges and cells that define the (mean) distance along the + transect for each cell or edge in the transect + """ + + dist_cell = np.zeros(ds.sizes['nCells']) + count_cell = np.zeros(ds.sizes['nCells'], dtype=int) + dist_edge = np.zeros(ds.sizes['nEdges']) + count_edge = np.zeros(ds.sizes['nEdges'], dtype=int) + + logger.info('Adding transect distance fields on cells and edges...') + + for segment in range(ds.sizes['nSegments']): + icell = ds.horizCellIndices.isel(nSegments=segment).values + iedge = ds.horizEdgeIndices.isel(nSegments=segment).values + # the distance for the midpoint of the segment is the mean + # of the distances of the end points + dist = 0.5 * (ds.dNode.isel(nHorizNodes=segment) + + ds.dNode.isel(nHorizNodes=segment + 1)) + dist_cell[icell] += dist + count_cell[icell] += 1 + dist_edge[iedge] += dist + count_edge[iedge] += 1 + + mask = count_cell > 0 + dist_cell[mask] /= count_cell[mask] + dist_cell[np.logical_not(mask)] = np.nan + + mask = count_edge > 0 + dist_edge[mask] /= count_edge[mask] + dist_edge[np.logical_not(mask)] = np.nan + + ds['transectDistanceCell'] = ('nCells', dist_cell) + ds['transectDistanceEdge'] = ('nEdges', dist_edge) + logger.info('Done.') + + +def main(): + + parser = argparse.ArgumentParser(description=''' + creates transect edge and cell masks along with edge sign and distance + along the transect''') + parser.add_argument('-m', dest='mesh_filename', + help='MPAS-Ocean horizontal and vertical filename', + required=True) + parser.add_argument('-g', dest='geojson_filename', + help='Geojson filename with transect', required=False) + parser.add_argument('-f', dest='feature_name', + help='Name of an ocean transect from ' + 'geometric_features', + required=False) + parser.add_argument('-o', dest='out_filename', + help='Edge transect filename', required=True) + parser.add_argument( + "--process_count", required=False, dest="process_count", type=int, + help="The number of processes to use to compute masks. The " + "default is to use all available cores") + parser.add_argument( + "--multiprocessing_method", dest="multiprocessing_method", + default='forkserver', + help="The multiprocessing method use for python mask creation " + "('fork', 'spawn' or 'forkserver')") + args = parser.parse_args() + + if args.geojson_filename is None and args.feature_name is None: + raise ValueError('Must supply either a geojson file or a transect ' + 'name') + + if args.geojson_filename is not None: + fc_transect = read_feature_collection(args.geojson_filename) + else: + gf = GeometricFeatures() + fc_transect = gf.read(componentName='ocean', objectType='transect', + featureNames=[args.feature_name]) + + ds_mesh = xr.open_dataset(args.mesh_filename) + if 'Time' in ds_mesh.dims: + ds_mesh = ds_mesh.isel(Time=0) + + pool = create_pool(process_count=args.process_count, + method=args.multiprocessing_method) + + with LoggingContext('create_transect_masks') as logger: + + combine_transect_datasets(ds_mesh, fc_transect, args.out_filename, + pool, logger) + + +if __name__ == '__main__': + main() diff --git a/ocean/transects/python/cut_closed_transect.py b/ocean/transects/python/cut_closed_transect.py new file mode 100755 index 00000000..b6b7326e --- /dev/null +++ b/ocean/transects/python/cut_closed_transect.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python + +import argparse + +import numpy as np +from geometric_features import ( + GeometricFeatures, + read_feature_collection +) +from shapely.geometry import ( + mapping, + Polygon, + shape +) + + +def cut_transect(fc_transect, lat, lon, size, out_filename): + """ + Cut a square out of the given closed-loop transect to break the loop. + """ + + # find the closest point in the transect to the specificed lat/lon + + feature = fc_transect.features[0] + coordinates = feature['geometry']['coordinates'] + feature_type = feature['geometry']['type'] + if feature_type == 'LineString': + coordinates = [coordinates] + elif feature_type != 'MultiLineString': + raise ValueError( + f'Unexpected geometry type for transect {feature_type}') + + min_dist = None + center_lon = None + center_lan = None + for coords in coordinates: + lon_local, lat_local = zip(*coords) + dist = np.sqrt((np.array(lon_local) - lon)**2 + + (np.array(lat_local) - lat)**2) + index_min = np.argmin(dist) + if min_dist is None or dist[index_min] < min_dist: + center_lon = lon_local[index_min] + center_lan = lat_local[index_min] + min_dist = dist[index_min] + + square = Polygon([(center_lon - 0.5 * size, center_lan - 0.5 * size), + (center_lon - 0.5 * size, center_lan + 0.5 * size), + (center_lon + 0.5 * size, center_lan + 0.5 * size), + (center_lon + 0.5 * size, center_lan - 0.5 * size), + (center_lon - 0.5 * size, center_lan - 0.5 * size)]) + + feature = fc_transect.features[0] + transect_shape = shape(feature['geometry']) + transect_shape = transect_shape.difference(square) + + # now sort the coordinates so the start and end of the transect are at the + # dividing point + + feature['geometry'] = mapping(transect_shape) + + feature_type = feature['geometry']['type'] + if feature_type == 'MultiLineString': + coordinates = feature['geometry']['coordinates'] + + # reorder the LineStrings so the first one starts right after the cut + + closest = None + min_dist = None + for index, coords in enumerate(coordinates): + lon_first, lat_first = coords[0] + dist = np.sqrt((lon_first - lon)**2 + (lat_first - lat)**2) + if min_dist is None or dist < min_dist: + closest = index + min_dist = dist + new_coords = list(coordinates[closest:]) + new_coords.extend(list(coordinates[:closest])) + feature['geometry']['coordinates'] = tuple(new_coords) + + fc_transect.to_geojson(out_filename) + + +def main(): + + parser = argparse.ArgumentParser(description=''' + cut the given transect loop as close as possible to the given + latitude and longitude''') + parser.add_argument('-g', dest='geojson_filename', + help='Geojson filename with transect', required=False) + parser.add_argument('-f', dest='feature_name', + help='Name of an ocean transect from ' + 'geometric_features', + required=False) + parser.add_argument('--lat', dest='lat', type=float, + help='The approx. latitude at which to cut the loop', + required=True) + parser.add_argument('--lon', dest='lon', type=float, + help='The approx. longitude at which to cut the loop', + required=True) + parser.add_argument('--size', dest='size', type=float, + help='The size in degrees of the square used to cut ' + 'the loop', + required=True) + parser.add_argument('-o', dest='out_filename', + help='The geojson file with the cut transect to write ' + 'out', + required=True) + args = parser.parse_args() + + if args.geojson_filename is None and args.feature_name is None: + raise ValueError('Must supply either a geojson file or a transect ' + 'name') + + if args.geojson_filename is not None: + fc_transect = read_feature_collection(args.geojson_filename) + else: + gf = GeometricFeatures() + fc_transect = gf.read(componentName='ocean', objectType='transect', + featureNames=[args.feature_name]) + + cut_transect(fc_transect, args.lat, args.lon, args.size, args.out_filename) + + +if __name__ == '__main__': + main() diff --git a/ocean/transects/python/transect/__init__.py b/ocean/transects/python/transect/__init__.py new file mode 100644 index 00000000..cea375db --- /dev/null +++ b/ocean/transects/python/transect/__init__.py @@ -0,0 +1,3 @@ +""" +Copied from Polaris +""" diff --git a/ocean/transects/python/transect/horiz.py b/ocean/transects/python/transect/horiz.py new file mode 100644 index 00000000..bd561a92 --- /dev/null +++ b/ocean/transects/python/transect/horiz.py @@ -0,0 +1,916 @@ +""" +Copied from Polaris +""" +import numpy as np +import xarray as xr +from mpas_tools.transects import ( + cartesian_to_lon_lat, + lon_lat_to_cartesian, + subdivide_great_circle, + subdivide_planar, +) +from mpas_tools.vector import Vector +from scipy.spatial import cKDTree +from shapely.geometry import LineString, Point + + +def mesh_to_triangles(ds_mesh): + """ + Construct a dataset in which each MPAS cell is divided into the triangles + connecting pairs of adjacent vertices to cell centers. + + Parameters + ---------- + ds_mesh : xarray.Dataset + An MPAS mesh + + Returns + ------- + ds_tris : xarray.Dataset + A dataset that defines triangles connecting pairs of adjacent vertices + to cell centers as well as the cell index that each triangle is in and + cell indices and weights for interpolating data defined at cell centers + to triangle nodes. ``ds_tris`` includes variables ``triCellIndices``, + the cell that each triangle is part of; ``triEdgeIndices``, the edge + that each triangle is adjacent to; ``nodeCellIndices`` and + ``nodeCellWeights``, the indices and weights used to interpolate from + MPAS cell centers to triangle nodes; Cartesian coordinates ``xNode``, + ``yNode``, and ``zNode``; and ``lonNode``` and ``latNode`` in radians. + ``lonNode`` is guaranteed to be within 180 degrees of the cell center + corresponding to ``triCellIndices``. Nodes always have a + counterclockwise winding. + + """ + n_vertices_on_cell = ds_mesh.nEdgesOnCell.values + vertices_on_cell = ds_mesh.verticesOnCell.values - 1 + cells_on_vertex = ds_mesh.cellsOnVertex.values - 1 + edges_on_cell = ds_mesh.edgesOnCell.values - 1 + + on_a_sphere = ds_mesh.attrs['on_a_sphere'].strip() == 'YES' + is_periodic = False + x_period = None + y_period = None + if not on_a_sphere: + is_periodic = ds_mesh.attrs['is_periodic'].strip() == 'YES' + if is_periodic: + x_period = ds_mesh.attrs['x_period'] + y_period = ds_mesh.attrs['y_period'] + + kite_areas_on_vertex = ds_mesh.kiteAreasOnVertex.values + + n_triangles = np.sum(n_vertices_on_cell) + + max_edges = ds_mesh.sizes['maxEdges'] + n_cells = ds_mesh.sizes['nCells'] + if ds_mesh.sizes['vertexDegree'] != 3: + raise ValueError('mesh_to_triangles only supports meshes with ' + 'vertexDegree = 3') + + # find the third vertex for each triangle + next_vertex = -1 * np.ones(vertices_on_cell.shape, int) + next_edge = -1 * np.ones(edges_on_cell.shape, int) + for i_vertex in range(max_edges): + valid = i_vertex < n_vertices_on_cell + invalid = np.logical_not(valid) + vertices_on_cell[invalid, i_vertex] = -1 + nv = n_vertices_on_cell[valid] + cell_indices = np.arange(0, n_cells)[valid] + i_next = np.where(i_vertex < nv - 1, i_vertex + 1, 0) + next_vertex[:, i_vertex][valid] = ( + vertices_on_cell[cell_indices, i_next]) + next_edge[:, i_vertex][valid] = ( + edges_on_cell[cell_indices, i_next]) + + valid = vertices_on_cell >= 0 + vertices_on_cell = vertices_on_cell[valid] + next_vertex = next_vertex[valid] + + # find the cell index for each triangle + tri_cell_indices, _ = np.meshgrid(np.arange(0, n_cells), + np.arange(0, max_edges), + indexing='ij') + tri_cell_indices = tri_cell_indices[valid] + + # find the edge index for each triangle. Since the nth edge lies between + # the nth and n+1st vertex, this is next_edge, rather than edges_on_cell + tri_edge_indices = next_edge[valid] + + # find list of cells and weights for each triangle node + node_cell_indices = -1 * np.ones((n_triangles, 3, 3), dtype=int) + node_cell_weights = np.zeros((n_triangles, 3, 3)) + + # the first node is at the cell center, so the value is just the one from + # that cell + node_cell_indices[:, 0, 0] = tri_cell_indices + node_cell_weights[:, 0, 0] = 1. + + # the other 2 nodes are associated with vertices + node_cell_indices[:, 1, :] = cells_on_vertex[vertices_on_cell, :] + node_cell_weights[:, 1, :] = kite_areas_on_vertex[vertices_on_cell, :] + node_cell_indices[:, 2, :] = cells_on_vertex[next_vertex, :] + node_cell_weights[:, 2, :] = kite_areas_on_vertex[next_vertex, :] + + weight_sum = np.sum(node_cell_weights, axis=2) + for i_node in range(3): + node_cell_weights[:, :, i_node] = ( + node_cell_weights[:, :, i_node] / weight_sum) + + ds_tris = xr.Dataset() + ds_tris['triCellIndices'] = ('nTriangles', tri_cell_indices) + ds_tris['triEdgeIndices'] = ('nTriangles', tri_edge_indices) + ds_tris['nodeCellIndices'] = (('nTriangles', 'nNodes', 'nInterp'), + node_cell_indices) + ds_tris['nodeCellWeights'] = (('nTriangles', 'nNodes', 'nInterp'), + node_cell_weights) + + # get Cartesian and lon/lat coordinates of each node + for prefix in ['x', 'y', 'z', 'lat', 'lon']: + out_var = f'{prefix}Node' + cell_var = f'{prefix}Cell' + vertex_var = f'{prefix}Vertex' + coord = np.zeros((n_triangles, 3)) + coord[:, 0] = ds_mesh[cell_var].values[tri_cell_indices] + coord[:, 1] = ds_mesh[vertex_var].values[vertices_on_cell] + coord[:, 2] = ds_mesh[vertex_var].values[next_vertex] + ds_tris[out_var] = (('nTriangles', 'nNodes'), coord) + + # nothing obvious we can do about triangles containing the poles + + if on_a_sphere: + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='lonNode', + period=2 * np.pi) + elif is_periodic: + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='xNode', + period=x_period) + ds_tris = _fix_periodic_tris(ds_tris, periodic_var='yNode', + period=y_period) + + return ds_tris + + +def make_triangle_tree(ds_tris): + """ + Make a KD-Tree for finding triangle edges that are near enough to transect + segments that they might intersect + + Parameters + ---------- + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` + + Returns + ------- + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh + """ + + n_triangles = ds_tris.sizes['nTriangles'] + n_nodes = ds_tris.sizes['nNodes'] + node_coords = np.zeros((n_triangles * n_nodes, 3)) + node_coords[:, 0] = ds_tris.xNode.values.ravel() + node_coords[:, 1] = ds_tris.yNode.values.ravel() + node_coords[:, 2] = ds_tris.zNode.values.ravel() + + next_tri, next_node = np.meshgrid( + np.arange(n_triangles), np.mod(np.arange(n_nodes) + 1, 3), + indexing='ij') + nextIndices = n_nodes * next_tri.ravel() + next_node.ravel() + + # edge centers are half way between adjacent nodes (ignoring great-circle + # distance) + edgeCoords = 0.5 * (node_coords + node_coords[nextIndices, :]) + + tree = cKDTree(data=edgeCoords, copy_data=True) + return tree + + +def find_spherical_transect_cells_and_weights( + lon_transect, lat_transect, ds_tris, ds_mesh, tree, degrees=True, + earth_radius=None, subdivision_res=10e3): + """ + Find "nodes" where the transect intersects the edges of the triangles + that make up MPAS cells. + + Parameters + ---------- + lon_transect : xarray.DataArray + The longitude of segments making up the transect + + lat_transect : xarray.DataArray + The latitude of segments making up the transect + + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` + + ds_mesh : xarray.Dataset + A data set with the full MPAS mesh. + + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh, the + return value from + :py:func:`polaris.ocean.viz.transect.horiz.make_triangle_tree()` + + degrees : bool, optional + Whether ``lon_transect`` and ``lat_transect`` are in degrees (as + opposed to radians). + + subdivision_res : float, optional + Resolution in m to use to subdivide the transect when looking for + intersection candidates. Should be small enough that curvature is + small. + + earth_radius : float, optional + The radius of the Earth in meters, taken from the `sphere_radius` + global attribute if not provided + + Returns + ------- + ds_out : xarray.Dataset + A dataset that contains "nodes" where the transect intersects the + edges of the triangles in ``ds_tris``. The nodes also includes the two + end points of the transect, which typically lie within triangles. Each + internal node (that is, not including the end points) is purposefully + repeated twice, once for each triangle that node touches. This allows + for discontinuous fields between triangles (e.g. if one wishes to plot + constant values on each MPAS cell). The Cartesian and lon/lat + coordinates of these nodes are ``xCartNode``, ``yCartNode``, + ``zCartNode``, ``lonNode`` and ``latNode``. The distance along the + transect of each intersection is ``dNode``. The index of the triangle + and the first triangle node in ``ds_tris`` associated with each + intersection node are given by ``horizTriangleIndices`` and + ``horizTriangleNodeIndices``, respectively. The second node on the + triangle for the edge associated with the intersection is given by + ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. + + The MPAS cell and edge that a given node belongs to are given by + ``horizCellIndices`` and ``horizEdgeIndices``, respectively. Each node + also has an associated set of 6 + ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be + used to interpolate from MPAS cell centers to nodes first with + area-weighted averaging to MPAS vertices and then linear interpolation + along triangle edges. Some of the weights may be zero, in which case + the associated ``interpHorizCellIndices`` will be -1. + + Finally, ``lonTransect`` and ``latTransect`` are included in the + dataset, along with Cartesian coordinates ``xCartTransect``, + ``yCartTransect``, `zCartTransect``, and ``dTransect``, the + great-circle distance along the transect of each original transect + point. In order to interpolate values (e.g. observations) from the + original transect points to the intersection nodes, linear + interpolation indices ``transectIndicesOnHorizNode`` and weights + ``transectWeightsOnHorizNode`` are provided. The values at nodes are + found by:: + + nodeValues = ((transectValues[transectIndicesOnHorizNode] * + transectWeightsOnHorizNode) + + (transectValues[transectIndicesOnHorizNode+1] * + (1.0 - transectWeightsOnHorizNode)) + """ + if earth_radius is None: + earth_radius = ds_mesh.attrs['sphere_radius'] + buffer = np.maximum(np.amax(ds_mesh.dvEdge.values), + np.amax(ds_mesh.dcEdge.values)) + + x, y, z = lon_lat_to_cartesian(lon_transect, lat_transect, earth_radius, + degrees) + + n_nodes = ds_tris.sizes['nNodes'] + node_cell_weights = ds_tris.nodeCellWeights.values + node_cell_indices = ds_tris.nodeCellIndices.values + + x_node = ds_tris.xNode.values.ravel() + y_node = ds_tris.yNode.values.ravel() + z_node = ds_tris.zNode.values.ravel() + + d_transect = np.zeros(lon_transect.shape) + + d_node = None + x_out = None + y_out = None + z_out = None + tris = None + nodes = None + interp_cells = None + cell_weights = None + + n_horiz_weights = 6 + + first = True + + d_start = 0. + for seg_index in range(len(x) - 1): + transectv0 = Vector(x[seg_index].values, + y[seg_index].values, + z[seg_index].values) + transectv1 = Vector(x[seg_index + 1].values, + y[seg_index + 1].values, + z[seg_index + 1].values) + + sub_slice = slice(seg_index, seg_index + 2) + x_sub, y_sub, z_sub, _, _ = subdivide_great_circle( + x[sub_slice].values, y[sub_slice].values, z[sub_slice].values, + subdivision_res, earth_radius) + + coords = np.zeros((len(x_sub), 3)) + coords[:, 0] = x_sub + coords[:, 1] = y_sub + coords[:, 2] = z_sub + radius = buffer + subdivision_res + + index_list = tree.query_ball_point(x=coords, r=radius) + + unique_indices = set() + for indices in index_list: + unique_indices.update(indices) + + n0_indices_cand = np.array(list(unique_indices)) + + if len(n0_indices_cand) == 0: + continue + + tris_cand = n0_indices_cand // n_nodes + next_node_index = np.mod(n0_indices_cand + 1, n_nodes) + n1_indices_cand = n_nodes * tris_cand + next_node_index + + n0_cand = Vector(x_node[n0_indices_cand], + y_node[n0_indices_cand], + z_node[n0_indices_cand]) + n1_cand = Vector(x_node[n1_indices_cand], + y_node[n1_indices_cand], + z_node[n1_indices_cand]) + + intersect = Vector.intersects(n0_cand, n1_cand, transectv0, + transectv1) + + n0_inter = Vector(n0_cand.x[intersect], + n0_cand.y[intersect], + n0_cand.z[intersect]) + n1_inter = Vector(n1_cand.x[intersect], + n1_cand.y[intersect], + n1_cand.z[intersect]) + + tris_inter = tris_cand[intersect] + n0_indices_inter = n0_indices_cand[intersect] + n1_indices_inter = n1_indices_cand[intersect] + + intersections = Vector.intersection(n0_inter, n1_inter, transectv0, + transectv1) + intersections = Vector(earth_radius * intersections.x, + earth_radius * intersections.y, + earth_radius * intersections.z) + + angular_distance = transectv0.angular_distance(intersections) + + d_node_local = d_start + earth_radius * angular_distance + + d_start += earth_radius * transectv0.angular_distance(transectv1) + + node0_inter = np.mod(n0_indices_inter, n_nodes) + node1_inter = np.mod(n1_indices_inter, n_nodes) + + node_weights = (intersections.angular_distance(n1_inter) / + n0_inter.angular_distance(n1_inter)) + + weights = np.zeros((len(tris_inter), n_horiz_weights)) + cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int) + for index in range(3): + weights[:, index] = ( + node_weights * + node_cell_weights[tris_inter, node0_inter, index]) + cell_indices[:, index] = ( + node_cell_indices[tris_inter, node0_inter, index]) + weights[:, index + 3] = ( + (1.0 - node_weights) * + node_cell_weights[tris_inter, node1_inter, index]) + cell_indices[:, index + 3] = ( + node_cell_indices[tris_inter, node1_inter, index]) + + if first: + x_out = intersections.x + y_out = intersections.y + z_out = intersections.z + d_node = d_node_local + + tris = tris_inter + nodes = node0_inter + interp_cells = cell_indices + cell_weights = weights + first = False + else: + x_out = np.append(x_out, intersections.x) + y_out = np.append(y_out, intersections.y) + z_out = np.append(z_out, intersections.z) + d_node = np.append(d_node, d_node_local) + + tris = np.concatenate((tris, tris_inter)) + nodes = np.concatenate((nodes, node0_inter)) + interp_cells = np.concatenate((interp_cells, cell_indices), axis=0) + cell_weights = np.concatenate((cell_weights, weights), axis=0) + + d_transect[seg_index + 1] = d_start + + epsilon = 1e-6 * subdivision_res + (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) = _sort_intersections( + d_node, tris, nodes, x_out, y_out, z_out, interp_cells, cell_weights, + epsilon) + + lon_out, lat_out = cartesian_to_lon_lat(x_out, y_out, z_out, earth_radius, + degrees) + + valid_segs = seg_tris >= 0 + valid_seg_tris = seg_tris[valid_segs] + cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) + cell_indices[valid_segs] = ( + ds_tris.triCellIndices.values[valid_seg_tris]) + + edge_indices = -1 * np.ones(seg_tris.shape, dtype=int) + edge_indices[valid_segs] = ( + ds_tris.triEdgeIndices.values[valid_seg_tris]) + + ds_out = xr.Dataset() + ds_out['xCartNode'] = (('nNodes',), x_out) + ds_out['yCartNode'] = (('nNodes',), y_out) + ds_out['zCartNode'] = (('nNodes',), z_out) + ds_out['dNode'] = (('nNodes',), d_node) + ds_out['lonNode'] = (('nNodes',), lon_out) + ds_out['latNode'] = (('nNodes',), lat_out) + + ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) + ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizEdgeIndices'] = ('nSegments', edge_indices) + ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), + seg_nodes) + ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), + interp_cells) + ds_out['interpHorizCellWeights'] = (('nNodes', 'nHorizWeights'), + cell_weights) + ds_out['validNodes'] = (('nNodes',), valid_nodes) + + transect_indices_on_horiz_node = np.zeros(d_node.shape, dtype=int) + transect_weights_on_horiz_node = np.zeros(d_node.shape) + for trans_index in range(len(d_transect) - 1): + d0 = d_transect[trans_index] + d1 = d_transect[trans_index + 1] + mask = np.logical_and(d_node >= d0, d_node < d1) + transect_indices_on_horiz_node[mask] = trans_index + transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0) + # last index will get missed by the mask and needs to be handled as a + # special case + transect_indices_on_horiz_node[-1] = len(d_transect) - 2 + transect_weights_on_horiz_node[-1] = 0.0 + + ds_out['lonTransect'] = lon_transect + ds_out['latTransect'] = lat_transect + ds_out['xCartTransect'] = x + ds_out['yCartTransect'] = y + ds_out['zCartTransect'] = z + ds_out['dTransect'] = (lon_transect.dims, d_transect) + ds_out['transectIndicesOnHorizNode'] = (('nNodes',), + transect_indices_on_horiz_node) + ds_out['transectWeightsOnHorizNode'] = (('nNodes',), + transect_weights_on_horiz_node) + + return ds_out + + +def find_planar_transect_cells_and_weights( + x_transect, y_transect, ds_tris, ds_mesh, tree, subdivision_res=10e3): + """ + Find "nodes" where the transect intersects the edges of the triangles + that make up MPAS cells. + + Parameters + ---------- + x_transect : xarray.DataArray + The x points defining segments making up the transect + + y_transect : xarray.DataArray + The y points defining segments making up the transect + + ds_tris : xarray.Dataset + A dataset that defines triangles, the results of calling + :py:func:`polaris.ocean.viz.transect.horiz.mesh_to_triangles()` + + ds_mesh : xarray.Dataset + A data set with the full MPAS mesh. + + tree : scipy.spatial.cKDTree + A tree of edge centers from triangles making up an MPAS mesh, the + return value from + :py:func:`polaris.ocean.viz.transect.horiz.make_triangle_tree()` + + subdivision_res : float, optional + Resolution in m to use to subdivide the transect when looking for + intersection candidates. Should be small enough that curvature is + small. + + Returns + ------- + ds_out : xarray.Dataset + A dataset that contains "nodes" where the transect intersects the + edges of the triangles in ``ds_tris``. The nodes also include the two + end points of the transect, which typically lie within triangles. Each + internal node (that is, not including the end points) is purposefully + repeated twice, once for each triangle that node touches. This allows + for discontinuous fields between triangles (e.g. if one wishes to plot + constant values on each MPAS cell). The planar coordinates of these + nodes are ``xNode`` and ``yNode``. The distance along the transect of + each intersection is ``dNode``. The index of the triangle and the first + triangle node in ``ds_tris`` associated with each intersection node are + given by ``horizTriangleIndices`` and ``horizTriangleNodeIndices``, + respectively. The second node on the triangle for the edge associated + with the intersection is given by + ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. + + The MPAS cell and edge that a given node belongs to are given by + ``horizCellIndices`` and ``horizEdgeIndices``, respectively. Each node + also has an associated set of 6 + ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be + used to interpolate from MPAS cell centers to nodes first with + area-weighted averaging to MPAS vertices and then linear interpolation + along triangle edges. Some of the weights may be zero, in which case + the associated ``interpHorizCellIndices`` will be -1. + + Finally, ``xTransect`` and ``yTransect`` are included in the + dataset, along with ``dTransect``, the distance along the transect of + each original transect point. In order to interpolate values (e.g. + observations) from the original transect points to the intersection + nodes, linear interpolation indices ``transectIndicesOnHorizNode`` and + weights ``transectWeightsOnHorizNode`` are provided. The values at + nodes are found by:: + + nodeValues = ((transectValues[transectIndicesOnHorizNode] * + transectWeightsOnHorizNode) + + (transectValues[transectIndicesOnHorizNode+1] * + (1.0 - transectWeightsOnHorizNode)) + """ + buffer = np.maximum(np.amax(ds_mesh.dvEdge.values), + np.amax(ds_mesh.dcEdge.values)) + + n_nodes = ds_tris.sizes['nNodes'] + node_cell_weights = ds_tris.nodeCellWeights.values + node_cell_indices = ds_tris.nodeCellIndices.values + + x = x_transect + y = y_transect + + x_node = ds_tris.xNode.values.ravel() + y_node = ds_tris.yNode.values.ravel() + + coordNode = np.zeros((len(x_node), 2)) + coordNode[:, 0] = x_node + coordNode[:, 1] = y_node + + d_transect = np.zeros(x_transect.shape) + + d_node = None + x_out = np.array([]) + y_out = np.array([]) + tris = None + nodes = None + interp_cells = None + cell_weights = None + + n_horiz_weights = 6 + + first = True + + d_start = 0. + for seg_index in range(len(x) - 1): + + sub_slice = slice(seg_index, seg_index + 2) + x_sub, y_sub, _, _ = subdivide_planar( + x[sub_slice].values, y[sub_slice].values, subdivision_res) + + start_point = Point(x_transect[seg_index].values, + y_transect[seg_index].values) + end_point = Point(x_transect[seg_index + 1].values, + y_transect[seg_index + 1].values) + + segment = LineString([start_point, end_point]) + + coords = np.zeros((len(x_sub), 3)) + coords[:, 0] = x_sub + coords[:, 1] = y_sub + radius = buffer + subdivision_res + + index_list = tree.query_ball_point(x=coords, r=radius) + + unique_indices = set() + for indices in index_list: + unique_indices.update(indices) + + start_indices = np.array(list(unique_indices)) + + if len(start_indices) == 0: + continue + + tris_cand = start_indices // n_nodes + next_node_index = np.mod(start_indices + 1, n_nodes) + end_indices = n_nodes * tris_cand + next_node_index + + intersecting_nodes = list() + tris_inter_list = list() + x_intersection_list = list() + y_intersection_list = list() + node_weights_list = list() + node0_inter_list = list() + node1_inter_list = list() + distances_list = list() + + for index in range(len(start_indices)): + start = start_indices[index] + end = end_indices[index] + + node0 = Point(coordNode[start, 0], coordNode[start, 1]) + node1 = Point(coordNode[end, 0], coordNode[end, 1]) + + edge = LineString([node0, node1]) + if segment.intersects(edge): + point = segment.intersection(edge) + intersecting_nodes.append((node0, node1, start, end, edge)) + + if isinstance(point, LineString): + raise ValueError('A triangle edge exactly coincides with ' + 'a transect segment and I can\'t handle ' + 'that case. Try moving the transect a ' + 'tiny bit.') + elif not isinstance(point, Point): + raise ValueError(f'Unexpected intersection type {point}') + + x_intersection_list.append(point.x) + y_intersection_list.append(point.y) + + start_to_intersection = LineString([start_point, point]) + + weight = (LineString([point, node1]).length / + LineString([node0, node1]).length) + + node_weights_list.append(weight) + node0_inter_list.append(np.mod(start, n_nodes)) + node1_inter_list.append(np.mod(end, n_nodes)) + distances_list.append(start_to_intersection.length) + tris_inter_list.append(tris_cand[index]) + + distances = np.array(distances_list) + x_intersection = np.array(x_intersection_list) + y_intersection = np.array(y_intersection_list) + node_weights = np.array(node_weights_list) + node0_inter = np.array(node0_inter_list, dtype=int) + node1_inter = np.array(node1_inter_list, dtype=int) + tris_inter = np.array(tris_inter_list, dtype=int) + + d_node_local = d_start + distances + + d_start += segment.length + + weights = np.zeros((len(tris_inter), n_horiz_weights)) + cell_indices = np.zeros((len(tris_inter), n_horiz_weights), int) + for index in range(3): + weights[:, index] = ( + node_weights * + node_cell_weights[tris_inter, node0_inter, index]) + cell_indices[:, index] = ( + node_cell_indices[tris_inter, node0_inter, index]) + weights[:, index + 3] = ( + (1.0 - node_weights) * + node_cell_weights[tris_inter, node1_inter, index]) + cell_indices[:, index + 3] = ( + node_cell_indices[tris_inter, node1_inter, index]) + + if first: + x_out = x_intersection + y_out = y_intersection + d_node = d_node_local + + tris = tris_inter + nodes = node0_inter + interp_cells = cell_indices + cell_weights = weights + first = False + else: + x_out = np.append(x_out, x_intersection) + y_out = np.append(y_out, y_intersection) + d_node = np.append(d_node, d_node_local) + + tris = np.concatenate((tris, tris_inter)) + nodes = np.concatenate((nodes, node0_inter)) + interp_cells = np.concatenate((interp_cells, cell_indices), axis=0) + cell_weights = np.concatenate((cell_weights, weights), axis=0) + + d_transect[seg_index + 1] = d_start + + z_out = np.zeros(x_out.shape) + + epsilon = 1e-6 * subdivision_res + (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) = _sort_intersections( + d_node, tris, nodes, x_out, y_out, z_out, interp_cells, cell_weights, + epsilon) + + valid_segs = seg_tris >= 0 + valid_seg_tris = seg_tris[valid_segs] + cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) + cell_indices[valid_segs] = ( + ds_tris.triCellIndices.values[valid_seg_tris]) + + edge_indices = -1 * np.ones(seg_tris.shape, dtype=int) + edge_indices[valid_segs] = ( + ds_tris.triEdgeIndices.values[valid_seg_tris]) + + ds_out = xr.Dataset() + ds_out['xNode'] = (('nNodes',), x_out) + ds_out['yNode'] = (('nNodes',), y_out) + ds_out['dNode'] = (('nNodes',), d_node) + + ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) + ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizEdgeIndices'] = ('nSegments', edge_indices) + ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), + seg_nodes) + ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), + interp_cells) + ds_out['interpHorizCellWeights'] = (('nNodes', 'nHorizWeights'), + cell_weights) + ds_out['validNodes'] = (('nNodes',), valid_nodes) + + transect_indices_on_horiz_node = np.zeros(d_node.shape, int) + transect_weights_on_horiz_node = np.zeros(d_node.shape) + for trans_index in range(len(d_transect) - 1): + d0 = d_transect[trans_index] + d1 = d_transect[trans_index + 1] + mask = np.logical_and(d_node >= d0, d_node < d1) + transect_indices_on_horiz_node[mask] = trans_index + transect_weights_on_horiz_node[mask] = (d1 - d_node[mask]) / (d1 - d0) + # last index will get missed by the mask and needs to be handled as a + # special case + transect_indices_on_horiz_node[-1] = len(d_transect) - 2 + transect_weights_on_horiz_node[-1] = 0.0 + + ds_out['xTransect'] = x + ds_out['yTransect'] = y + ds_out['dTransect'] = (x_transect.dims, d_transect) + ds_out['transectIndicesOnHorizNode'] = (('nNodes',), + transect_indices_on_horiz_node) + ds_out['transectWeightsOnHorizNode'] = (('nNodes',), + transect_weights_on_horiz_node) + + return ds_out + + +def interp_mpas_horiz_to_transect_nodes(ds_transect, da): + """ + Interpolate a 2D (``nCells``) MPAS DataArray to transect nodes, linearly + interpolating fields between the closest neighboring cells + + Parameters + ---------- + ds_transect : xr.Dataset + A dataset that defines an MPAS transect, the results of calling + ``find_spherical_transect_cells_and_weights()`` or + ``find_planar_transect_cells_and_weights()`` + + da : xr.DataArray + An MPAS 2D field with dimensions `nCells`` (possibly among others) + + Returns + ------- + da_nodes : xr.DataArray + The data array interpolated to transect nodes with dimensions + ``nNodes`` (in addition to whatever dimensions were in ``da`` besides + ``nCells``) + """ + interp_cell_indices = ds_transect.interpHorizCellIndices + interp_cell_weights = ds_transect.interpHorizCellWeights + da = da.isel(nCells=interp_cell_indices) + da_nodes = (da * interp_cell_weights).sum(dim='nHorizWeights') + + da_nodes = da_nodes.where(ds_transect.validNodes) + + return da_nodes + + +def _sort_intersections(d_node, tris, nodes, x_out, y_out, z_out, interp_cells, + cell_weights, epsilon): + """ sort nodes by distance and define segment between them """ + + sort_indices = np.argsort(d_node) + d_sorted = d_node[sort_indices] + + # make a list of indices for each unique value of d + d = d_sorted[0] + unique_d_indices = [sort_indices[0]] + unique_d_all_indices = [[sort_indices[0]]] + for index, next_d, in zip(sort_indices[1:], d_sorted[1:]): + if next_d - d < epsilon: + # this d value is effectively the same as the last, so we'll treat + # it as the same + unique_d_all_indices[-1].append(index) + else: + # this is a new d, so we'll add to a new list + d = next_d + unique_d_indices.append(index) + unique_d_all_indices.append([index]) + + # there is a segment between each unique d, though some are invalid (do + # not correspond to a triangle) + seg_tris_list = list() + seg_nodes_list = list() + + index0 = unique_d_indices[0] + indices0 = unique_d_all_indices[0] + d0 = d_node[index0] + + indices = [index0] + ds = [d0] + for seg_index in range(len(unique_d_all_indices) - 1): + indices1 = unique_d_all_indices[seg_index + 1] + index1 = unique_d_indices[seg_index + 1] + d1 = d_node[index1] + + # are there any triangles in common between this d value and the next? + tris0 = tris[indices0] + tris1 = tris[indices1] + both = set(tris0).intersection(set(tris1)) + + if len(both) > 0: + tri = both.pop() + seg_tris_list.append(tri) + indices.append(index1) + ds.append(d1) + + # the triangle nodes are the 2 corresponding to the same triangle + # in the original list + index0 = indices0[np.where(tris0 == tri)[0][0]] + index1 = indices1[np.where(tris1 == tri)[0][0]] + seg_nodes_list.append([nodes[index0], nodes[index1]]) + else: + # this is an invalid segment so we need to insert and extra invalid + # node to allow for proper masking + seg_tris_list.extend([-1, -1]) + seg_nodes_list.extend([[-1, -1], [-1, -1]]) + indices.extend([index0, index1]) + ds.extend([0.5 * (d0 + d1), d1]) + + index0 = index1 + indices0 = indices1 + d0 = d1 + + indices = np.array(indices, dtype=int) + d_node = np.array(ds, dtype=float) + seg_tris = np.array(seg_tris_list, dtype=int) + seg_nodes = np.array(seg_nodes_list, dtype=int) + + valid_nodes = np.ones(len(indices), dtype=bool) + valid_nodes[1:-1] = np.logical_or(seg_tris[0:-1] >= 0, + seg_tris[1:] > 0) + + x_out = x_out[indices] + y_out = y_out[indices] + z_out = z_out[indices] + + interp_cells = interp_cells[indices, :] + cell_weights = cell_weights[indices, :] + + return (d_node, x_out, y_out, z_out, seg_tris, seg_nodes, interp_cells, + cell_weights, valid_nodes) + + +def _fix_periodic_tris(ds_tris, periodic_var, period): + """ + make sure the given node coordinate on tris is within one period of the + cell center + """ + coord_node = ds_tris[periodic_var].values + coord_cell = coord_node[:, 0] + n_triangles = ds_tris.sizes['nTriangles'] + copy_pos = np.zeros(coord_cell.shape, dtype=bool) + copy_neg = np.zeros(coord_cell.shape, dtype=bool) + for i_node in [1, 2]: + mask = coord_node[:, i_node] - coord_cell > 0.5 * period + copy_pos = np.logical_or(copy_pos, mask) + coord_node[:, i_node][mask] = coord_node[:, i_node][mask] - period + mask = coord_node[:, i_node] - coord_cell < -0.5 * period + copy_neg = np.logical_or(copy_neg, mask) + coord_node[:, i_node][mask] = coord_node[:, i_node][mask] + period + + pos_indices = np.nonzero(copy_pos)[0] + neg_indices = np.nonzero(copy_neg)[0] + tri_indices = np.append(np.append(np.arange(0, n_triangles), + pos_indices), neg_indices) + + ds_new = xr.Dataset(ds_tris) + ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node) + ds_new = ds_new.isel(nTriangles=tri_indices) + coord_node = ds_new[periodic_var].values + + pos_slice = slice(n_triangles, n_triangles + len(pos_indices)) + coord_node[pos_slice, :] = coord_node[pos_slice, :] + period + neg_slice = slice(n_triangles + len(pos_indices), + n_triangles + len(pos_indices) + len(neg_indices)) + coord_node[neg_slice, :] = coord_node[neg_slice, :] - period + ds_new[periodic_var] = (('nTriangles', 'nNodes'), coord_node) + return ds_new diff --git a/ocean/transects/python/transect/plot.py b/ocean/transects/python/transect/plot.py new file mode 100644 index 00000000..d8cc763d --- /dev/null +++ b/ocean/transects/python/transect/plot.py @@ -0,0 +1,224 @@ +""" +Copied from Polaris +""" +import matplotlib.pyplot as plt +import numpy as np + +from transect.vert import ( + interp_mpas_edges_to_transect_cells, + interp_mpas_to_transect_cells, + interp_mpas_to_transect_nodes, +) + + +def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, + title=None, vmin=None, vmax=None, colorbar_label=None, + cmap=None, figsize=(12, 6), dpi=200, method='flat', + outline_color='black', ssh_color=None, seafloor_color=None, + interface_color=None, cell_boundary_color=None, + linewidth=1.0, color_start_and_end=False, + start_color='red', end_color='green'): + """ + plot a transect showing the field on the MPAS-Ocean mesh and save to a file + + Parameters + ---------- + ds_transect : xarray.Dataset + A transect dataset from + :py:func:`polaris.ocean.viz.compute_transect()` + + mpas_field : xarray.DataArray + The MPAS-Ocean 3D field (``nCells`` or ``nEdges`` by ``nVertLevels``) + to plot + + out_filename : str, optional + The png file to write out to + + ax : matplotlib.axes.Axes + Axes to plot to if making a multi-panel figure + + title : str + The title of the plot + + vmin : float, optional + The minimum values for the colorbar + + vmax : float, optional + The maximum values for the colorbar + + colorbar_label : str, optional + The colorbar label, or ``None`` if no colorbar is to be included. + Use an empty string to display a colorbar without a label. + + cmap : str, optional + The name of a colormap to use + + figsize : tuple, optional + The size of the figure in inches + + dpi : int, optional + The dots per inch of the image + + method : {'flat', 'bilinear'}, optional + The type of interpolation to use in plots. ``flat`` means constant + values over each MPAS cell or edge. ``bilinear`` means smooth + interpolation horizontally between cell centers and vertical between + the middle of layers (available only for fields on MPAS cells). + + outline_color : str or None, optional + The color to use to outline the transect or ``None`` for no outline + + ssh_color : str or None, optional + The color to use to plot the SSH (sea surface height) or ``None`` if + not plotting the SSH (except perhaps as part of the outline) + + seafloor_color : str or None, optional + The color to use to plot the seafloor depth or ``None`` if not plotting + the seafloor depth (except perhaps as part of the outline) + + interface_color : str or None, optional + The color to use to plot interfaces between layers or ``None`` if + not plotting the layer interfaces + + cell_boundary_color : str or None, optional + The color to use to plot vertical boundaries between cells or ``None`` + if not plotting cell boundaries. Typically, ``cell_boundary_color`` + will be used along with ``interface_color`` to outline cells both + horizontally and vertically. + + linewidth : float, optional + The width of outlines, interfaces and cell boundaries + + color_start_and_end : bool, optional + Whether to color the left and right axes of the transect, which is + useful if the transect is also being plotted in an inset or on top of + a horizontal field + + start_color : str, optional + The color of left axis marking the start of the transect if + ``plot_start_end == True`` + + end_color : str, optional + The color of right axis marking the end of the transect if + ``plot_start_end == True`` + """ + + if ax is None and out_filename is None: + raise ValueError('One of ax or out_filename must be supplied') + + create_fig = ax is None + if create_fig: + plt.figure(figsize=figsize) + ax = plt.subplot(111) + + z = ds_transect.zTransectNode + x = 1e-3 * ds_transect.dNode.broadcast_like(z) + + if mpas_field is not None: + if 'nCells' in mpas_field.dims: + if method == 'flat': + transect_field = interp_mpas_to_transect_cells(ds_transect, + mpas_field) + shading = 'flat' + elif method == 'bilinear': + transect_field = interp_mpas_to_transect_nodes(ds_transect, + mpas_field) + shading = 'gouraud' + else: + raise ValueError(f'Unsupported method for cell fields: ' + f'{method}') + elif 'nEdges' in mpas_field.dims: + if method != 'flat': + raise ValueError(f'Unsupported method for edge fields: ' + f'{method}') + + transect_field = interp_mpas_edges_to_transect_cells(ds_transect, + mpas_field) + shading = 'flat' + else: + raise ValueError(f'Expected one of nCells or nEdges in ' + f'{mpas_field.dims}') + + pc = ax.pcolormesh(x.values, z.values, transect_field.values, + shading=shading, cmap=cmap, vmin=vmin, vmax=vmax, + zorder=0) + ax.autoscale(tight=True) + if colorbar_label is not None: + plt.colorbar(pc, extend='both', shrink=0.7, ax=ax, + label=colorbar_label) + + _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, + ssh_color, seafloor_color, color_start_and_end, + start_color, end_color, linewidth) + + _plot_outline(x, z, ds_transect.validNodes, ax, outline_color, + linewidth) + + ax.set_xlabel('transect distance (km)') + ax.set_ylabel('z (m)') + + if create_fig: + if title is not None: + plt.title(title) + plt.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.2) + plt.close() + + +def _plot_interfaces(ds_transect, ax, interface_color, cell_boundary_color, + ssh_color, seafloor_color, color_start_and_end, + start_color, end_color, linewidth): + if cell_boundary_color is not None: + x_bnd = 1e-3 * ds_transect.dCellBoundary.values.T + z_bnd = ds_transect.zCellBoundary.values.T + ax.plot(x_bnd, z_bnd, color=cell_boundary_color, linewidth=linewidth, + zorder=1) + + if interface_color is not None: + x_int = 1e-3 * ds_transect.dInterfaceSegment.values.T + z_int = ds_transect.zInterfaceSegment.values.T + ax.plot(x_int, z_int, color=interface_color, linewidth=linewidth, + zorder=2) + + if ssh_color is not None: + valid = ds_transect.validNodes.any(dim='nVertNodes') + x_ssh = 1e-3 * ds_transect.dNode.values + z_ssh = ds_transect.ssh.where(valid).values + ax.plot(x_ssh, z_ssh, color=ssh_color, linewidth=linewidth, zorder=4) + + if seafloor_color is not None: + valid = ds_transect.validNodes.any(dim='nVertNodes') + x_floor = 1e-3 * ds_transect.dNode.values + z_floor = ds_transect.zSeafloor.where(valid).values + ax.plot(x_floor, z_floor, color=seafloor_color, linewidth=linewidth, + zorder=5) + + if color_start_and_end: + ax.spines['left'].set_color(start_color) + ax.spines['left'].set_linewidth(4 * linewidth) + ax.spines['right'].set_color(end_color) + ax.spines['right'].set_linewidth(4 * linewidth) + + +def _plot_outline(x, z, valid_nodes, ax, outline_color, linewidth, + epsilon=1e-6): + if outline_color is not None: + # add a buffer of invalid values around the edge of the domain + valid = np.zeros((x.shape[0] + 2, x.shape[1] + 2), dtype=float) + z_buf = np.zeros(valid.shape, dtype=float) + x_buf = np.zeros(valid.shape, dtype=float) + + valid[1:-1, 1:-1] = valid_nodes.astype(float) + z_buf[1:-1, 1:-1] = z.values + z_buf[0, 1:-1] = z_buf[1, 1:-1] + z_buf[-1, 1:-1] = z_buf[-2, 1:-1] + z_buf[:, 0] = z_buf[:, 1] + z_buf[:, -1] = z_buf[:, -2] + + x_buf[1:-1, 1:-1] = x.values + x_buf[0, 1:-1] = x_buf[1, 1:-1] + x_buf[-1, 1:-1] = x_buf[-2, 1:-1] + x_buf[:, 0] = x_buf[:, 1] + x_buf[:, -1] = x_buf[:, -2] + + ax.contour(x_buf, z_buf, valid, levels=[1. - epsilon], + colors=outline_color, linewidths=linewidth, zorder=3) diff --git a/ocean/transects/python/transect/vert.py b/ocean/transects/python/transect/vert.py new file mode 100644 index 00000000..4174e577 --- /dev/null +++ b/ocean/transects/python/transect/vert.py @@ -0,0 +1,521 @@ +""" +Copied from Polaris +""" +import numpy as np +import xarray as xr + +from transect.horiz import ( + find_planar_transect_cells_and_weights, + find_spherical_transect_cells_and_weights, + make_triangle_tree, + mesh_to_triangles, +) + + +def compute_transect(x, y, ds_horiz_mesh, layer_thickness, bottom_depth, + min_level_cell, max_level_cell, spherical=False): + """ + build a sequence of quads showing the transect intersecting mpas cells. + This can be used to plot transects of fields with dimensions ``nCells`` and + ``nVertLevels`` using :py:func:`polaris.ocean.viz.plot_transect()` + + Parameters + ---------- + x : xarray.DataArray + The x or longitude coordinate of the transect + + y : xarray.DataArray + The y or latitude coordinate of the transect + + ds_horiz_mesh : xarray.Dataset + The horizontal MPAS mesh to use for plotting + + layer_thickness : xarray.DataArray + The layer thickness at a particular instant in time. + `layerThickness.isel(Time=tidx)` to select a particular time index + `tidx` if the original data array contains `Time`. + + bottom_depth : xarray.DataArray + the (positive down) depth of the seafloor on the MPAS mesh + + min_level_cell : xarray.DataArray + the vertical zero-based index of the sea surface on the MPAS mesh + + max_level_cell : xarray.DataArray + the vertical zero-based index of the bathymetry on the MPAS mesh + + spherical : bool, optional + Whether the x and y coordinates are latitude and longitude in degrees + + Returns + ------- + ds_transect : xarray.Dataset + The transect dataset, see + :py:func:`polaris.ocean.viz.transect.vert.find_transect_levels_and_weights()` + for details + """ # noqa: E501 + + ds_tris = mesh_to_triangles(ds_horiz_mesh) + + triangle_tree = make_triangle_tree(ds_tris) + + if spherical: + ds_horiz_transect = find_spherical_transect_cells_and_weights( + x, y, ds_tris, ds_horiz_mesh, triangle_tree, degrees=True) + else: + ds_horiz_transect = find_planar_transect_cells_and_weights( + x, y, ds_tris, ds_horiz_mesh, triangle_tree) + + # mask horizontal transect to valid cells (max_level_cell >= 0) + cell_indices = ds_horiz_transect.horizCellIndices + seg_mask = max_level_cell.isel(nCells=cell_indices).values >= 0 + node_mask = np.zeros(ds_horiz_transect.sizes['nNodes'], dtype=bool) + node_mask[0:-1] = seg_mask + node_mask[1:] = np.logical_or(node_mask[1:], seg_mask) + + ds_horiz_transect = ds_horiz_transect.isel(nSegments=seg_mask, + nNodes=node_mask) + + ds_transect = find_transect_levels_and_weights( + ds_horiz_transect=ds_horiz_transect, layer_thickness=layer_thickness, + bottom_depth=bottom_depth, min_level_cell=min_level_cell, + max_level_cell=max_level_cell) + + ds_transect.compute() + + return ds_transect + + +def find_transect_levels_and_weights(ds_horiz_transect, layer_thickness, + bottom_depth, min_level_cell, + max_level_cell): + """ + Construct a vertical coordinate for a transect produced by + :py:func:`polaris.ocean.viz.transect.horiz.find_spherical_transect_cells_and_weights()` + or :py:func:`polaris.ocean.viz.transect.horiz.find_planar_transect_cells_and_weights()`. + Also, compute interpolation weights such that observations at points on the + original transect and with vertical coordinate ``transectZ`` can be + bilinearly interpolated to the nodes of the transect. + + Parameters + ---------- + ds_horiz_transect : xarray.Dataset + A dataset that defines nodes of the transect + + layer_thickness : xarray.DataArray + layer thicknesses on the MPAS mesh + + bottom_depth : xarray.DataArray + the (positive down) depth of the seafloor on the MPAS mesh + + min_level_cell : xarray.DataArray + the vertical zero-based index of the sea surface on the MPAS mesh + + max_level_cell : xarray.DataArray + the vertical zero-based index of the bathymetry on the MPAS mesh + + Returns + ------- + ds_transect : xarray.Dataset + A dataset that contains nodes and cells that make up a 2D transect. + + There are ``nSegments`` horizontal and ``nHalfLevels`` vertical + transect cells (quadrilaterals), bounded by ``nHorizNodes`` horizontal + and ``nVertNodes`` vertical nodes (corners). + + In addition to the variables and coordinates in the input + ``ds_transect``, the output dataset contains: + + - ``validCells``, ``validNodes``: which transect cells and nodes + are valid (above the bathymetry and below the sea surface) + + - zTransectNode: the vertical height of each triangle node + - ssh, zSeaFloor: the sea-surface height and sea-floor height at + each node of each transect segment + + - ``cellIndices``: the MPAS-Ocean cell of a given transect segment + - ``levelIndices``: the MPAS-Ocean vertical level of a given + transect level + + - ``interpCellIndices``, ``interpLevelIndices``: the MPAS-Ocean + cells and levels from which the value at a given transect cell is + interpolated. This can involve up to + ``nHorizWeights * nVertWeights = 12`` different cells and levels. + - interpCellWeights: the weight to multiply each field value by + to perform interpolation to a transect cell. + + - ``dInterfaceSegment``, ``zInterfaceSegment`` - segments that can + be used to plot the interfaces between MPAS-Ocean layers + + - ``dCellBoundary``, ``zCellBoundary`` - segments that can + be used to plot the vertical boundaries between MPAS-Ocean cells + + Interpolation of a DataArray from MPAS cells and levels to transect + cells can be performed with + :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_cells()`. + Similarly, interpolation to transect nodes can be performed with + :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_nodes()`. + """ # noqa: E501 + if 'Time' in layer_thickness.dims: + raise ValueError('Please select a single time level in layer ' + 'thickness.') + + ds_transect_cells = ds_horiz_transect.rename({'nNodes': 'nHorizNodes'}) + + (z_half_interface, ssh, z_seafloor, interp_cell_indices, + interp_cell_weights, valid_transect_cells, + level_indices) = _get_vertical_coordinate( + ds_transect_cells, layer_thickness, bottom_depth, min_level_cell, + max_level_cell) + + ds_transect_cells['zTransectNode'] = z_half_interface + + ds_transect_cells['ssh'] = ssh + ds_transect_cells['zSeafloor'] = z_seafloor + + ds_transect_cells['cellIndices'] = ds_transect_cells.horizCellIndices + ds_transect_cells['levelIndices'] = level_indices + ds_transect_cells['validCells'] = valid_transect_cells + + d_interface_seg, z_interface_seg = _get_interface_segments( + z_half_interface, ds_transect_cells.dNode, valid_transect_cells) + + ds_transect_cells['dInterfaceSegment'] = d_interface_seg + ds_transect_cells['zInterfaceSegment'] = z_interface_seg + + d_cell_boundary, z_cell_boundary = _get_cell_boundary_segments( + ssh, z_seafloor, ds_transect_cells.dNode, + ds_transect_cells.horizCellIndices) + + ds_transect_cells['dCellBoundary'] = d_cell_boundary + ds_transect_cells['zCellBoundary'] = z_cell_boundary + + interp_level_indices, interp_cell_weights, valid_nodes = \ + _get_interp_indices_and_weights(layer_thickness, interp_cell_indices, + interp_cell_weights, level_indices, + valid_transect_cells) + + ds_transect_cells['interpCellIndices'] = interp_cell_indices + ds_transect_cells['interpLevelIndices'] = interp_level_indices + ds_transect_cells['interpCellWeights'] = interp_cell_weights + ds_transect_cells['validNodes'] = valid_nodes + + dims = ['nSegments', 'nHalfLevels', 'nHorizNodes', 'nVertNodes', + 'nInterfaceSegments', 'nCellBoundaries', 'nHorizBounds', + 'nVertBounds', 'nHorizWeights', 'nVertWeights'] + for dim in ds_transect_cells.dims: + if dim not in dims: + dims.insert(0, dim) + ds_transect_cells = ds_transect_cells.transpose(*dims) + + return ds_transect_cells + + +def interp_mpas_to_transect_cells(ds_transect, da): + """ + Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by + ``nVertLevels`` to transect cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean field with dimensions ``nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_cells : xarray.DataArray + The data array interpolated to transect cells with dimensions + ``nSegments`` and ``nHalfLevels`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + cell_indices = ds_transect.cellIndices + level_indices = ds_transect.levelIndices + + da_cells = da.isel(nCells=cell_indices, nVertLevels=level_indices) + da_cells = da_cells.where(ds_transect.validCells) + + return da_cells + + +def interp_mpas_edges_to_transect_cells(ds_transect, da): + """ + Interpolate an MPAS-Ocean DataArray with dimensions ``nEdges`` by + ``nVertLevels`` to transect cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean field with dimensions ``nEdges`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_cells : xarray.DataArray + The data array interpolated to transect cells with dimensions + ``nSegments`` and ``nHalfLevels`` (in addition to whatever + dimensions were in ``da`` besides ``nEdges`` and ``nVertLevels``) + """ + + edge_indices = ds_transect.edgeIndices + level_indices = ds_transect.levelIndices + + da_cells = da.isel(nEdges=edge_indices, nVertLevels=level_indices) + da_cells = da_cells.where(ds_transect.validCells) + + return da_cells + + +def interp_mpas_to_transect_nodes(ds_transect, da): + """ + Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by + ``nVertLevels`` to transect nodes, linearly interpolating fields between + the closest neighboring cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_nodes : xarray.DataArray + The data array interpolated to transect nodes with dimensions + ``nHorizNodes`` and ``nVertNodes`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + interp_cell_indices = ds_transect.interpCellIndices + interp_level_indices = ds_transect.interpLevelIndices + interp_cell_weights = ds_transect.interpCellWeights + + da = da.isel(nCells=interp_cell_indices, nVertLevels=interp_level_indices) + + da_nodes = (da * interp_cell_weights).sum( + dim=('nHorizWeights', 'nVertWeights')) + + da_nodes = da_nodes.where(ds_transect.validNodes) + + return da_nodes + + +def _get_vertical_coordinate(ds_transect, layer_thickness, bottom_depth, + min_level_cell, max_level_cell): + n_horiz_nodes = ds_transect.sizes['nHorizNodes'] + n_segments = ds_transect.sizes['nSegments'] + n_vert_levels = layer_thickness.sizes['nVertLevels'] + + # we assume below that there is a segment (whether valid or invalid) + # connecting each pair of adjacent nodes + assert n_horiz_nodes == n_segments + 1 + + interp_horiz_cell_indices = ds_transect.interpHorizCellIndices + interp_horiz_cell_weights = ds_transect.interpHorizCellWeights + + bottom_depth_interp = bottom_depth.isel(nCells=interp_horiz_cell_indices) + layer_thickness_interp = layer_thickness.isel( + nCells=interp_horiz_cell_indices) + + cell_mask_interp = _get_cell_mask(interp_horiz_cell_indices, + min_level_cell, max_level_cell, + n_vert_levels) + layer_thickness_interp = layer_thickness_interp.where(cell_mask_interp, 0.) + + ssh_interp = (-bottom_depth_interp + + layer_thickness_interp.sum(dim='nVertLevels')) + + interp_mask = np.logical_and(interp_horiz_cell_indices > 0, + cell_mask_interp) + + interp_cell_weights = interp_mask * interp_horiz_cell_weights + weight_sum = interp_cell_weights.sum(dim='nHorizWeights') + + cell_indices = ds_transect.horizCellIndices + + valid_cells = _get_cell_mask(cell_indices, min_level_cell, max_level_cell, + n_vert_levels) + + valid_cells = valid_cells.transpose('nSegments', 'nVertLevels').values + + valid_nodes = np.zeros((n_horiz_nodes, n_vert_levels), dtype=bool) + valid_nodes[0:-1, :] = valid_cells + valid_nodes[1:, :] = np.logical_or(valid_nodes[1:, :], valid_cells) + + valid_nodes = xr.DataArray(dims=('nHorizNodes', 'nVertLevels'), + data=valid_nodes) + + valid_weights = valid_nodes.broadcast_like(interp_cell_weights) + interp_cell_weights = \ + (interp_cell_weights / weight_sum).where(valid_weights) + + layer_thickness_transect = (layer_thickness_interp * + interp_cell_weights).sum(dim='nHorizWeights') + + interp_mask = max_level_cell.isel(nCells=interp_horiz_cell_indices) >= 0 + interp_horiz_cell_weights = interp_mask * interp_horiz_cell_weights + weight_sum = interp_horiz_cell_weights.sum(dim='nHorizWeights') + interp_horiz_cell_weights = \ + (interp_horiz_cell_weights / weight_sum).where(interp_mask) + + ssh_transect = (ssh_interp * + interp_horiz_cell_weights).sum(dim='nHorizWeights') + + z_bot = ssh_transect - layer_thickness_transect.cumsum(dim='nVertLevels') + z_mid = z_bot + 0.5 * layer_thickness_transect + + z_half_interfaces = [ssh_transect] + for z_index in range(n_vert_levels): + z_half_interfaces.extend([z_mid.isel(nVertLevels=z_index), + z_bot.isel(nVertLevels=z_index)]) + + z_half_interface = xr.concat(z_half_interfaces, dim='nVertNodes') + z_half_interface = z_half_interface.transpose('nHorizNodes', 'nVertNodes') + + z_seafloor = ssh_transect - layer_thickness_transect.sum( + dim='nVertLevels') + + valid_transect_cells = np.zeros((n_segments, 2 * n_vert_levels), + dtype=bool) + valid_transect_cells[:, 0::2] = valid_cells + valid_transect_cells[:, 1::2] = valid_cells + valid_transect_cells = xr.DataArray(dims=('nSegments', 'nHalfLevels'), + data=valid_transect_cells) + + level_indices = np.zeros(2 * n_vert_levels, dtype=int) + level_indices[0::2] = np.arange(n_vert_levels) + level_indices[1::2] = np.arange(n_vert_levels) + level_indices = xr.DataArray(dims=('nHalfLevels',), data=level_indices) + + return (z_half_interface, ssh_transect, z_seafloor, + interp_horiz_cell_indices, interp_cell_weights, + valid_transect_cells, level_indices) + + +def _get_cell_mask(cell_indices, min_level_cell, max_level_cell, + n_vert_levels): + level_indices = xr.DataArray(data=np.arange(n_vert_levels), + dims='nVertLevels') + min_level_cell = min_level_cell.isel(nCells=cell_indices) + max_level_cell = max_level_cell.isel(nCells=cell_indices) + + cell_mask = np.logical_and( + level_indices >= min_level_cell, + level_indices <= max_level_cell) + + cell_mask = np.logical_and(cell_mask, cell_indices >= 0) + + return cell_mask + + +def _get_interface_segments(z_half_interface, d_node, valid_transect_cells): + + d = d_node.broadcast_like(z_half_interface) + z_interface = z_half_interface.values[:, 0::2] + d = d.values[:, 0::2] + + n_segments = valid_transect_cells.sizes['nSegments'] + n_half_levels = valid_transect_cells.sizes['nHalfLevels'] + n_vert_levels = n_half_levels // 2 + + valid_segs = np.zeros((n_segments, n_vert_levels + 1), dtype=bool) + valid_segs[:, 0:-1] = valid_transect_cells.values[:, 1::2] + valid_segs[:, 1:] = np.logical_or(valid_segs[:, 1:], + valid_transect_cells.values[:, 0::2]) + + n_interface_segs = np.count_nonzero(valid_segs) + + d_seg = np.zeros((n_interface_segs, 2)) + z_seg = np.zeros((n_interface_segs, 2)) + d_seg[:, 0] = d[0:-1, :][valid_segs] + d_seg[:, 1] = d[1:, :][valid_segs] + z_seg[:, 0] = z_interface[0:-1, :][valid_segs] + z_seg[:, 1] = z_interface[1:, :][valid_segs] + + d_seg = xr.DataArray(dims=('nInterfaceSegments', 'nHorizBounds'), + data=d_seg) + + z_seg = xr.DataArray(dims=('nInterfaceSegments', 'nHorizBounds'), + data=z_seg) + + return d_seg, z_seg + + +def _get_cell_boundary_segments(ssh, z_seafloor, d_node, cell_indices): + + n_horiz_nodes = d_node.sizes['nHorizNodes'] + + cell_boundary = np.ones(n_horiz_nodes, dtype=bool) + cell_boundary[1:-1] = cell_indices.values[0:-1] != cell_indices.values[1:] + + n_cell_boundaries = np.count_nonzero(cell_boundary) + + d_seg = np.zeros((n_cell_boundaries, 2)) + z_seg = np.zeros((n_cell_boundaries, 2)) + d_seg[:, 0] = d_node.values[cell_boundary] + d_seg[:, 1] = d_seg[:, 0] + z_seg[:, 0] = ssh[cell_boundary] + z_seg[:, 1] = z_seafloor[cell_boundary] + + d_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=d_seg) + + z_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=z_seg) + + return d_seg, z_seg + + +def _get_interp_indices_and_weights(layer_thickness, interp_cell_indices, + interp_cell_weights, level_indices, + valid_transect_cells): + n_horiz_nodes = interp_cell_indices.sizes['nHorizNodes'] + n_vert_levels = layer_thickness.sizes['nVertLevels'] + n_vert_nodes = 2 * n_vert_levels + 1 + n_vert_weights = 2 + + interp_level_indices = -1 * np.ones((n_vert_nodes, n_vert_weights), + dtype=int) + interp_level_indices[1:, 0] = level_indices.values + interp_level_indices[0:-1, 1] = level_indices.values + + interp_level_indices = xr.DataArray(dims=('nVertNodes', 'nVertWeights'), + data=interp_level_indices) + + half_level_thickness = 0.5 * layer_thickness.isel( + nCells=interp_cell_indices, nVertLevels=interp_level_indices) + half_level_thickness = half_level_thickness.where( + interp_level_indices >= 0, other=0.) + + # vertical weights are proportional to the half-level thickness + interp_cell_weights = half_level_thickness * interp_cell_weights.isel( + nVertLevels=interp_level_indices) + + valid_nodes = np.zeros((n_horiz_nodes, n_vert_nodes), dtype=bool) + valid_nodes[0:-1, 0:-1] = valid_transect_cells + valid_nodes[1:, 0:-1] = np.logical_or(valid_nodes[1:, 0:-1], + valid_transect_cells) + valid_nodes[0:-1, 1:] = np.logical_or(valid_nodes[0:-1, 1:], + valid_transect_cells) + valid_nodes[1:, 1:] = np.logical_or(valid_nodes[1:, 1:], + valid_transect_cells) + + valid_nodes = xr.DataArray(dims=('nHorizNodes', 'nVertNodes'), + data=valid_nodes) + + weight_sum = interp_cell_weights.sum(dim=('nHorizWeights', 'nVertWeights')) + out_mask = (weight_sum > 0.).broadcast_like(interp_cell_weights) + interp_cell_weights = (interp_cell_weights / weight_sum).where(out_mask) + + return interp_level_indices, interp_cell_weights, valid_nodes