diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index 17c7b42..65256d5 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -22,6 +22,12 @@ _default_outline_kwargs = {"colors": "black", "linestyles": "solid", "linewidths": 1} +def _squeeze_end_of_array(array, dims_required=3): + while len(array.shape) > dims_required: + array = np.squeeze(array, axis=len(array.shape) - 1) + return array + + def plot_mesh_tally( tally: "openmc.Tally", basis: str = "xy", @@ -95,12 +101,7 @@ def plot_mesh_tally( mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh if not isinstance(mesh, openmc.RegularMesh): - raise NotImplemented( - f"Only RegularMesh are currently supported not {type(mesh)}" - ) - # if mesh.n_dimension != 3: - # msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported" - # raise NotImplementedError(msg) + raise NotImplemented(f"Only RegularMesh are supported not {type(mesh)}") # if score is not specified and tally has a single score then we know which score to use if score is None: @@ -112,19 +113,32 @@ def plot_mesh_tally( tally_slice = tally.get_slice(scores=[score]) - # if mesh.n_dimension == 3: + basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] if 1 in mesh.dimension: index_of_2d = mesh.dimension.index(1) axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] + if ( + axis_of_2d in basis + ): # checks if the axis is being plotted, e.g is 'x' in 'xy' + raise ValueError( + "The selected tally has a mesh that has 1 dimension in the " + f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " + f"of {basis}." + ) - # todo check if 1 appears twice or three times, raise value error if so + # TODO check if 1 appears twice or three times, raise value error if so - tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze() + tally_data = tally_slice.get_reshaped_data( + expand_dims=True, value=value + ) # .squeeze() - basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] - if len(tally_data.shape) == 3: + tally_data = _squeeze_end_of_array(tally_data, dims_required=3) + + # if len(tally_data.shape) == 3: + if mesh.n_dimension == 3: if slice_index is None: + # finds the mid index slice_index = int(tally_data.shape[basis_to_index] / 2) if basis == "xz": @@ -139,29 +153,11 @@ def plot_mesh_tally( slice_data = tally_data[:, :, slice_index] data = np.rot90(slice_data, -3) xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" - # elif mesh.n_dimension == 2: - elif len(tally_data.shape) == 2: - if basis_to_index == index_of_2d: - slice_data = tally_data[:, :] - if basis == "xz": - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]" - elif basis == "yz": - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" - else: # basis == 'xy' - data = np.rot90(slice_data, -3) - xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" - - else: - raise ValueError( - "The selected tally has a mesh that has 1 dimension in the " - f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " - f"of {basis}." - ) else: - raise ValueError("mesh n_dimension") + raise ValueError( + f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported" + ) if volume_normalization: # in a regular mesh all volumes are the same so we just divide by the first diff --git a/tests/test_units.py b/tests/test_units.py index c4528a7..13a07b9 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -121,8 +121,10 @@ def test_plot_2d_mesh_tally(model): tally_result = statepoint.get_tally(name="mesh-tal") plot = plot_mesh_tally( - tally=tally_result, basis="yz", slice_index=29 # max value of slice selected + tally=tally_result, basis="yz", slice_index=0 # max value of slice selected ) + + plot = plot_mesh_tally(tally=tally_result, basis="yz") # axis_units defaults to cm assert plot.xaxis.get_label().get_text() == "y [cm]" assert plot.yaxis.get_label().get_text() == "z [cm]"