Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for fields on MPAS edges to ocean transects #193

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions polaris/ocean/viz/transect/horiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def mesh_to_triangles(ds_mesh):
to cell centers as well as the cell index that each triangle is in and
cell indices and weights for interpolating data defined at cell centers
to triangle nodes. ``ds_tris`` includes variables ``triCellIndices``,
the cell that each triangle is part of; ``nodeCellIndices`` and
the cell that each triangle is part of; ``triEdgeIndices``, the edge
that each triangle is adjacent to; ``nodeCellIndices`` and
``nodeCellWeights``, the indices and weights used to interpolate from
MPAS cell centers to triangle nodes; Cartesian coordinates ``xNode``,
``yNode``, and ``zNode``; and ``lonNode``` and ``latNode`` in radians.
Expand All @@ -40,6 +41,7 @@ def mesh_to_triangles(ds_mesh):
n_vertices_on_cell = ds_mesh.nEdgesOnCell.values
vertices_on_cell = ds_mesh.verticesOnCell.values - 1
cells_on_vertex = ds_mesh.cellsOnVertex.values - 1
edges_on_cell = ds_mesh.edgesOnCell.values - 1

on_a_sphere = ds_mesh.attrs['on_a_sphere'].strip() == 'YES'
is_periodic = False
Expand All @@ -63,6 +65,7 @@ def mesh_to_triangles(ds_mesh):

# find the third vertex for each triangle
next_vertex = -1 * np.ones(vertices_on_cell.shape, int)
next_edge = -1 * np.ones(edges_on_cell.shape, int)
for i_vertex in range(max_edges):
valid = i_vertex < n_vertices_on_cell
invalid = np.logical_not(valid)
Expand All @@ -72,6 +75,8 @@ def mesh_to_triangles(ds_mesh):
i_next = np.where(i_vertex < nv - 1, i_vertex + 1, 0)
next_vertex[:, i_vertex][valid] = (
vertices_on_cell[cell_indices, i_next])
next_edge[:, i_vertex][valid] = (
edges_on_cell[cell_indices, i_next])

valid = vertices_on_cell >= 0
vertices_on_cell = vertices_on_cell[valid]
Expand All @@ -83,6 +88,10 @@ def mesh_to_triangles(ds_mesh):
indexing='ij')
tri_cell_indices = tri_cell_indices[valid]

# find the edge index for each triangle. Since the nth edge lies between
# the nth and n+1st vertex, this is next_edge, rather than edges_on_cell
tri_edge_indices = next_edge[valid]

# find list of cells and weights for each triangle node
node_cell_indices = -1 * np.ones((n_triangles, 3, 3), dtype=int)
node_cell_weights = np.zeros((n_triangles, 3, 3))
Expand All @@ -105,6 +114,7 @@ def mesh_to_triangles(ds_mesh):

ds_tris = xr.Dataset()
ds_tris['triCellIndices'] = ('nTriangles', tri_cell_indices)
ds_tris['triEdgeIndices'] = ('nTriangles', tri_edge_indices)
ds_tris['nodeCellIndices'] = (('nTriangles', 'nNodes', 'nInterp'),
node_cell_indices)
ds_tris['nodeCellWeights'] = (('nTriangles', 'nNodes', 'nInterp'),
Expand Down Expand Up @@ -231,8 +241,9 @@ def find_spherical_transect_cells_and_weights(
triangle for the edge associated with the intersection is given by
``numpy.mod(horizTriangleNodeIndices + 1, 3)``.

The MPAS cell that a given node belongs to is given by
``horizCellIndices``. Each node also has an associated set of 6
The MPAS cell and edge that a given node belongs to are given by
``horizCellIndices`` and ``horizEdgeIndices``, respectively. Each node
also has an associated set of 6
``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be
used to interpolate from MPAS cell centers to nodes first with
area-weighted averaging to MPAS vertices and then linear interpolation
Expand Down Expand Up @@ -407,9 +418,14 @@ def find_spherical_transect_cells_and_weights(
degrees)

valid_segs = seg_tris >= 0
valid_seg_tris = seg_tris[valid_segs]
cell_indices = -1 * np.ones(seg_tris.shape, dtype=int)
cell_indices[valid_segs] = (
ds_tris.triCellIndices.values[seg_tris[valid_segs]])
ds_tris.triCellIndices.values[valid_seg_tris])

edge_indices = -1 * np.ones(seg_tris.shape, dtype=int)
edge_indices[valid_segs] = (
ds_tris.triEdgeIndices.values[valid_seg_tris])

ds_out = xr.Dataset()
ds_out['xCartNode'] = (('nNodes',), x_out)
Expand All @@ -421,6 +437,7 @@ def find_spherical_transect_cells_and_weights(

ds_out['horizTriangleIndices'] = ('nSegments', seg_tris)
ds_out['horizCellIndices'] = ('nSegments', cell_indices)
ds_out['horizEdgeIndices'] = ('nSegments', edge_indices)
ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'),
seg_nodes)
ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'),
Expand Down Expand Up @@ -505,8 +522,9 @@ def find_planar_transect_cells_and_weights(
with the intersection is given by
``numpy.mod(horizTriangleNodeIndices + 1, 3)``.

The MPAS cell that a given node belongs to is given by
``horizCellIndices``. Each node also has an associated set of 6
The MPAS cell and edge that a given node belongs to are given by
``horizCellIndices`` and ``horizEdgeIndices``, respectively. Each node
also has an associated set of 6
``interpHorizCellIndices`` and ``interpHorizCellWeights`` that can be
used to interpolate from MPAS cell centers to nodes first with
area-weighted averaging to MPAS vertices and then linear interpolation
Expand Down Expand Up @@ -691,9 +709,14 @@ def find_planar_transect_cells_and_weights(
epsilon)

valid_segs = seg_tris >= 0
valid_seg_tris = seg_tris[valid_segs]
cell_indices = -1 * np.ones(seg_tris.shape, dtype=int)
cell_indices[valid_segs] = (
ds_tris.triCellIndices.values[seg_tris[valid_segs]])
ds_tris.triCellIndices.values[valid_seg_tris])

edge_indices = -1 * np.ones(seg_tris.shape, dtype=int)
edge_indices[valid_segs] = (
ds_tris.triEdgeIndices.values[valid_seg_tris])

ds_out = xr.Dataset()
ds_out['xNode'] = (('nNodes',), x_out)
Expand All @@ -702,6 +725,7 @@ def find_planar_transect_cells_and_weights(

ds_out['horizTriangleIndices'] = ('nSegments', seg_tris)
ds_out['horizCellIndices'] = ('nSegments', cell_indices)
ds_out['horizEdgeIndices'] = ('nSegments', edge_indices)
ds_out['horizTriangleNodeIndices'] = (('nSegments', 'nHorizBounds'),
seg_nodes)
ds_out['interpHorizCellIndices'] = (('nNodes', 'nHorizWeights'),
Expand Down
39 changes: 27 additions & 12 deletions polaris/ocean/viz/transect/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from polaris.ocean.viz.transect.vert import (
interp_mpas_edges_to_transect_cells,
interp_mpas_to_transect_cells,
interp_mpas_to_transect_nodes,
)
Expand All @@ -26,7 +27,8 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None,
:py:func:`polaris.ocean.viz.compute_transect()`

mpas_field : xarray.DataArray
The MPAS-Ocean 3D field to plot
The MPAS-Ocean 3D field (``nCells`` or ``nEdges`` by ``nVertLevels``)
to plot

out_filename : str, optional
The png file to write out to
Expand Down Expand Up @@ -58,9 +60,9 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None,

method : {'flat', 'bilinear'}, optional
The type of interpolation to use in plots. ``flat`` means constant
values over each MPAS cell. ``bilinear`` means smooth interpolation
between horizontally between cell centers and vertical between the
middle of layers.
values over each MPAS cell or edge. ``bilinear`` means smooth
interpolation horizontally between cell centers and vertical between
the middle of layers (available only for fields on MPAS cells).

outline_color : str or None, optional
The color to use to outline the transect or ``None`` for no outline
Expand Down Expand Up @@ -114,16 +116,29 @@ def plot_transect(ds_transect, mpas_field=None, out_filename=None, ax=None,
x = 1e-3 * ds_transect.dNode.broadcast_like(z)

if mpas_field is not None:
if method == 'flat':
transect_field = interp_mpas_to_transect_cells(ds_transect,
mpas_field)
if 'nCells' in mpas_field.dims:
if method == 'flat':
transect_field = interp_mpas_to_transect_cells(ds_transect,
mpas_field)
shading = 'flat'
elif method == 'bilinear':
transect_field = interp_mpas_to_transect_nodes(ds_transect,
mpas_field)
shading = 'gouraud'
else:
raise ValueError(f'Unsupported method for cell fields: '
f'{method}')
elif 'nEdges' in mpas_field.dims:
if method != 'flat':
raise ValueError(f'Unsupported method for edge fields: '
f'{method}')

transect_field = interp_mpas_edges_to_transect_cells(ds_transect,
mpas_field)
shading = 'flat'
elif method == 'bilinear':
transect_field = interp_mpas_to_transect_nodes(ds_transect,
mpas_field)
shading = 'gouraud'
else:
raise ValueError(f'Unsupported method: {method}')
raise ValueError(f'Expected one of nCells or nEdges in '
f'{mpas_field.dims}')

pc = ax.pcolormesh(x.values, z.values, transect_field.values,
shading=shading, cmap=cmap, vmin=vmin, vmax=vmax,
Expand Down
34 changes: 33 additions & 1 deletion polaris/ocean/viz/transect/vert.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def interp_mpas_to_transect_cells(ds_transect, da):
``find_transect_levels_and_weights()``

da : xarray.DataArray
An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels``
An MPAS-Ocean field with dimensions ``nCells`` and ``nVertLevels``
(possibly among others)

Returns
Expand All @@ -240,6 +240,38 @@ def interp_mpas_to_transect_cells(ds_transect, da):
return da_cells


def interp_mpas_edges_to_transect_cells(ds_transect, da):
"""
Interpolate an MPAS-Ocean DataArray with dimensions ``nEdges`` by
``nVertLevels`` to transect cells

Parameters
----------
ds_transect : xarray.Dataset
A dataset that defines an MPAS-Ocean transect, the results of calling
``find_transect_levels_and_weights()``

da : xarray.DataArray
An MPAS-Ocean field with dimensions ``nEdges`` and ``nVertLevels``
(possibly among others)

Returns
-------
da_cells : xarray.DataArray
The data array interpolated to transect cells with dimensions
``nSegments`` and ``nHalfLevels`` (in addition to whatever
dimensions were in ``da`` besides ``nEdges`` and ``nVertLevels``)
"""

edge_indices = ds_transect.edgeIndices
level_indices = ds_transect.levelIndices

da_cells = da.isel(nEdges=edge_indices, nVertLevels=level_indices)
da_cells = da_cells.where(ds_transect.validCells)

return da_cells


def interp_mpas_to_transect_nodes(ds_transect, da):
"""
Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by
Expand Down
Loading