diff --git a/polaris/ocean/viz/transect/horiz.py b/polaris/ocean/viz/transect/horiz.py index 6e16e9198..305f3e4d7 100644 --- a/polaris/ocean/viz/transect/horiz.py +++ b/polaris/ocean/viz/transect/horiz.py @@ -28,7 +28,8 @@ def mesh_to_triangles(ds_mesh): to cell centers as well as the cell index that each triangle is in and cell indices and weights for interpolating data defined at cell centers to triangle nodes. ``ds_tris`` includes variables ``triCellIndices``, - the cell that each triangle is part of; ``nodeCellIndices`` and + the cell that each triangle is part of; ``triEdgeIndices``, the edge + that each triangle is adjacent to; ``nodeCellIndices`` and ``nodeCellWeights``, the indices and weights used to interpolate from MPAS cell centers to triangle nodes; Cartesian coordinates ``xNode``, ``yNode``, and ``zNode``; and ``lonNode``` and ``latNode`` in radians. @@ -40,6 +41,7 @@ def mesh_to_triangles(ds_mesh): n_vertices_on_cell = ds_mesh.nEdgesOnCell.values vertices_on_cell = ds_mesh.verticesOnCell.values - 1 cells_on_vertex = ds_mesh.cellsOnVertex.values - 1 + edges_on_cell = ds_mesh.edgesOnCell.values - 1 on_a_sphere = ds_mesh.attrs['on_a_sphere'].strip() == 'YES' is_periodic = False @@ -63,6 +65,7 @@ def mesh_to_triangles(ds_mesh): # find the third vertex for each triangle next_vertex = -1 * np.ones(vertices_on_cell.shape, int) + next_edge = -1 * np.ones(edges_on_cell.shape, int) for i_vertex in range(max_edges): valid = i_vertex < n_vertices_on_cell invalid = np.logical_not(valid) @@ -72,6 +75,8 @@ def mesh_to_triangles(ds_mesh): i_next = np.where(i_vertex < nv - 1, i_vertex + 1, 0) next_vertex[:, i_vertex][valid] = ( vertices_on_cell[cell_indices, i_next]) + next_edge[:, i_vertex][valid] = ( + edges_on_cell[cell_indices, i_next]) valid = vertices_on_cell >= 0 vertices_on_cell = vertices_on_cell[valid] @@ -83,6 +88,10 @@ def mesh_to_triangles(ds_mesh): indexing='ij') tri_cell_indices = tri_cell_indices[valid] + # find the edge index for each triangle. Since the nth edge lies between + # the nth and n+1st vertex, this is next_edge, rather than edges_on_cell + tri_edge_indices = next_edge[valid] + # find list of cells and weights for each triangle node node_cell_indices = -1 * np.ones((n_triangles, 3, 3), dtype=int) node_cell_weights = np.zeros((n_triangles, 3, 3)) @@ -105,6 +114,7 @@ def mesh_to_triangles(ds_mesh): ds_tris = xr.Dataset() ds_tris['triCellIndices'] = ('nTriangles', tri_cell_indices) + ds_tris['triEdgeIndices'] = ('nTriangles', tri_edge_indices) ds_tris['nodeCellIndices'] = (('nTriangles', 'nNodes', 'nInterp'), node_cell_indices) ds_tris['nodeCellWeights'] = (('nTriangles', 'nNodes', 'nInterp'), @@ -231,8 +241,9 @@ def find_spherical_transect_cells_and_weights( triangle for the edge associated with the intersection is given by ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. - The MPAS cell that a given node belongs to is given by - ``horizCellIndices``. Each node also has an associated set of 6 + The MPAS cell and edge that a given node belongs to are given by + ``horizCellIndices`` and ``horizEdgeIndices``, respectively. Each node + also has an associated set of 6 ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be used to interpolate from MPAS cell centers to nodes first with area-weighted averaging to MPAS vertices and then linear interpolation @@ -407,9 +418,14 @@ def find_spherical_transect_cells_and_weights( degrees) valid_segs = seg_tris >= 0 + valid_seg_tris = seg_tris[valid_segs] cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) cell_indices[valid_segs] = ( - ds_tris.triCellIndices.values[seg_tris[valid_segs]]) + ds_tris.triCellIndices.values[valid_seg_tris]) + + edge_indices = -1 * np.ones(seg_tris.shape, dtype=int) + edge_indices[valid_segs] = ( + ds_tris.triEdgeIndices.values[valid_seg_tris]) ds_out = xr.Dataset() ds_out['xCartNode'] = (('nNodes',), x_out) @@ -421,6 +437,7 @@ def find_spherical_transect_cells_and_weights( ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizEdgeIndices'] = ('nSegments', edge_indices) ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), seg_nodes) ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), @@ -505,8 +522,9 @@ def find_planar_transect_cells_and_weights( with the intersection is given by ``numpy.mod(horizTriangleNodeIndices + 1, 3)``. - The MPAS cell that a given node belongs to is given by - ``horizCellIndices``. Each node also has an associated set of 6 + The MPAS cell and edge that a given node belongs to are given by + ``horizCellIndices`` and ``horizEdgeIndices``, respectively. Each node + also has an associated set of 6 ``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be used to interpolate from MPAS cell centers to nodes first with area-weighted averaging to MPAS vertices and then linear interpolation @@ -691,9 +709,14 @@ def find_planar_transect_cells_and_weights( epsilon) valid_segs = seg_tris >= 0 + valid_seg_tris = seg_tris[valid_segs] cell_indices = -1 * np.ones(seg_tris.shape, dtype=int) cell_indices[valid_segs] = ( - ds_tris.triCellIndices.values[seg_tris[valid_segs]]) + ds_tris.triCellIndices.values[valid_seg_tris]) + + edge_indices = -1 * np.ones(seg_tris.shape, dtype=int) + edge_indices[valid_segs] = ( + ds_tris.triEdgeIndices.values[valid_seg_tris]) ds_out = xr.Dataset() ds_out['xNode'] = (('nNodes',), x_out) @@ -702,6 +725,7 @@ def find_planar_transect_cells_and_weights( ds_out['horizTriangleIndices'] = ('nSegments', seg_tris) ds_out['horizCellIndices'] = ('nSegments', cell_indices) + ds_out['horizEdgeIndices'] = ('nSegments', edge_indices) ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'), seg_nodes) ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'), diff --git a/polaris/ocean/viz/transect/plot.py b/polaris/ocean/viz/transect/plot.py index 3ebac1e2f..b226cbcef 100644 --- a/polaris/ocean/viz/transect/plot.py +++ b/polaris/ocean/viz/transect/plot.py @@ -3,6 +3,7 @@ import numpy as np from polaris.ocean.viz.transect.vert import ( + interp_mpas_edges_to_transect_cells, interp_mpas_to_transect_cells, interp_mpas_to_transect_nodes, ) @@ -26,7 +27,8 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, :py:func:`polaris.ocean.viz.compute_transect()` mpas_field : xarray.DataArray - The MPAS-Ocean 3D field to plot + The MPAS-Ocean 3D field (``nCells`` or ``nEdges`` by ``nVertLevels``) + to plot out_filename : str, optional The png file to write out to @@ -58,9 +60,9 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, method : {'flat', 'bilinear'}, optional The type of interpolation to use in plots. ``flat`` means constant - values over each MPAS cell. ``bilinear`` means smooth interpolation - between horizontally between cell centers and vertical between the - middle of layers. + values over each MPAS cell or edge. ``bilinear`` means smooth + interpolation horizontally between cell centers and vertical between + the middle of layers (available only for fields on MPAS cells). outline_color : str or None, optional The color to use to outline the transect or ``None`` for no outline @@ -114,16 +116,29 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None, x = 1e-3 * ds_transect.dNode.broadcast_like(z) if mpas_field is not None: - if method == 'flat': - transect_field = interp_mpas_to_transect_cells(ds_transect, - mpas_field) + if 'nCells' in mpas_field.dims: + if method == 'flat': + transect_field = interp_mpas_to_transect_cells(ds_transect, + mpas_field) + shading = 'flat' + elif method == 'bilinear': + transect_field = interp_mpas_to_transect_nodes(ds_transect, + mpas_field) + shading = 'gouraud' + else: + raise ValueError(f'Unsupported method for cell fields: ' + f'{method}') + elif 'nEdges' in mpas_field.dims: + if method != 'flat': + raise ValueError(f'Unsupported method for edge fields: ' + f'{method}') + + transect_field = interp_mpas_edges_to_transect_cells(ds_transect, + mpas_field) shading = 'flat' - elif method == 'bilinear': - transect_field = interp_mpas_to_transect_nodes(ds_transect, - mpas_field) - shading = 'gouraud' else: - raise ValueError(f'Unsupported method: {method}') + raise ValueError(f'Expected one of nCells or nEdges in ' + f'{mpas_field.dims}') pc = ax.pcolormesh(x.values, z.values, transect_field.values, shading=shading, cmap=cmap, vmin=vmin, vmax=vmax, diff --git a/polaris/ocean/viz/transect/vert.py b/polaris/ocean/viz/transect/vert.py index a146cb46c..e76ac0b85 100644 --- a/polaris/ocean/viz/transect/vert.py +++ b/polaris/ocean/viz/transect/vert.py @@ -220,7 +220,7 @@ def interp_mpas_to_transect_cells(ds_transect, da): ``find_transect_levels_and_weights()`` da : xarray.DataArray - An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` + An MPAS-Ocean field with dimensions ``nCells`` and ``nVertLevels`` (possibly among others) Returns @@ -240,6 +240,38 @@ def interp_mpas_to_transect_cells(ds_transect, da): return da_cells +def interp_mpas_edges_to_transect_cells(ds_transect, da): + """ + Interpolate an MPAS-Ocean DataArray with dimensions ``nEdges`` by + ``nVertLevels`` to transect cells + + Parameters + ---------- + ds_transect : xarray.Dataset + A dataset that defines an MPAS-Ocean transect, the results of calling + ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean field with dimensions ``nEdges`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + da_cells : xarray.DataArray + The data array interpolated to transect cells with dimensions + ``nSegments`` and ``nHalfLevels`` (in addition to whatever + dimensions were in ``da`` besides ``nEdges`` and ``nVertLevels``) + """ + + edge_indices = ds_transect.edgeIndices + level_indices = ds_transect.levelIndices + + da_cells = da.isel(nEdges=edge_indices, nVertLevels=level_indices) + da_cells = da_cells.where(ds_transect.validCells) + + return da_cells + + def interp_mpas_to_transect_nodes(ds_transect, da): """ Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by