From 1f94d0ec1297f08d0c37a152d7d0d9de6c3ce5d2 Mon Sep 17 00:00:00 2001 From: Jonathan Shimwell Date: Fri, 24 Nov 2023 14:16:43 +0000 Subject: [PATCH 1/6] squeezing end of tally shape only --- src/openmc_regular_mesh_plotter/core.py | 47 +++++++++++-------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index 17c7b42..a19765b 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -21,6 +21,10 @@ _default_outline_kwargs = {"colors": "black", "linestyles": "solid", "linewidths": 1} +def _squeeze_end_of_array(array, dims_required=3): + while len(array.shape) > dims_required: + array = np.squeeze(array, axis=len(array.shape)-1) + return array def plot_mesh_tally( tally: "openmc.Tally", @@ -112,19 +116,29 @@ def plot_mesh_tally( tally_slice = tally.get_slice(scores=[score]) - # if mesh.n_dimension == 3: + basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] + print('mesh.dimension', mesh.dimension) if 1 in mesh.dimension: index_of_2d = mesh.dimension.index(1) axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] + if axis_of_2d in basis: # checks if the axis is being plotted, e.g is 'x' in 'xy' + raise ValueError( + "The selected tally has a mesh that has 1 dimension in the " + f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " + f"of {basis}." + ) - # todo check if 1 appears twice or three times, raise value error if so + # TODO check if 1 appears twice or three times, raise value error if so - tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze() + tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value)#.squeeze() - basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] - if len(tally_data.shape) == 3: + tally_data = _squeeze_end_of_array(tally_data, dims_required=3) + + # if len(tally_data.shape) == 3: + if mesh.n_dimension == 3: if slice_index is None: + # finds the mid index slice_index = int(tally_data.shape[basis_to_index] / 2) if basis == "xz": @@ -137,31 +151,12 @@ def plot_mesh_tally( xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" else: # basis == 'xy' slice_data = tally_data[:, :, slice_index] + print('shape slice_data', slice_data.shape) data = np.rot90(slice_data, -3) xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" - # elif mesh.n_dimension == 2: - elif len(tally_data.shape) == 2: - if basis_to_index == index_of_2d: - slice_data = tally_data[:, :] - if basis == "xz": - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]" - elif basis == "yz": - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" - else: # basis == 'xy' - data = np.rot90(slice_data, -3) - xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" - - else: - raise ValueError( - "The selected tally has a mesh that has 1 dimension in the " - f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " - f"of {basis}." - ) else: - raise ValueError("mesh n_dimension") + raise ValueError(f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported") if volume_normalization: # in a regular mesh all volumes are the same so we just divide by the first From b6c54c2bd7d7a7ec7b452a42be77b89d106ba5d1 Mon Sep 17 00:00:00 2001 From: shimwell Date: Fri, 24 Nov 2023 14:18:47 +0000 Subject: [PATCH 2/6] [skip ci] Apply formatting changes --- src/openmc_regular_mesh_plotter/core.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index a19765b..5b94a0a 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -21,11 +21,13 @@ _default_outline_kwargs = {"colors": "black", "linestyles": "solid", "linewidths": 1} + def _squeeze_end_of_array(array, dims_required=3): while len(array.shape) > dims_required: - array = np.squeeze(array, axis=len(array.shape)-1) + array = np.squeeze(array, axis=len(array.shape) - 1) return array + def plot_mesh_tally( tally: "openmc.Tally", basis: str = "xy", @@ -118,11 +120,13 @@ def plot_mesh_tally( basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] - print('mesh.dimension', mesh.dimension) + print("mesh.dimension", mesh.dimension) if 1 in mesh.dimension: index_of_2d = mesh.dimension.index(1) axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] - if axis_of_2d in basis: # checks if the axis is being plotted, e.g is 'x' in 'xy' + if ( + axis_of_2d in basis + ): # checks if the axis is being plotted, e.g is 'x' in 'xy' raise ValueError( "The selected tally has a mesh that has 1 dimension in the " f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " @@ -131,7 +135,9 @@ def plot_mesh_tally( # TODO check if 1 appears twice or three times, raise value error if so - tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value)#.squeeze() + tally_data = tally_slice.get_reshaped_data( + expand_dims=True, value=value + ) # .squeeze() tally_data = _squeeze_end_of_array(tally_data, dims_required=3) @@ -151,12 +157,14 @@ def plot_mesh_tally( xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" else: # basis == 'xy' slice_data = tally_data[:, :, slice_index] - print('shape slice_data', slice_data.shape) + print("shape slice_data", slice_data.shape) data = np.rot90(slice_data, -3) xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" else: - raise ValueError(f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported") + raise ValueError( + f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported" + ) if volume_normalization: # in a regular mesh all volumes are the same so we just divide by the first From 8ac55f2ceecf47bf3aa898aed43ec941041d9577 Mon Sep 17 00:00:00 2001 From: Jonathan Shimwell Date: Fri, 24 Nov 2023 14:41:34 +0000 Subject: [PATCH 3/6] removed prints --- src/openmc_regular_mesh_plotter/core.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index a19765b..2784b42 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -100,11 +100,8 @@ def plot_mesh_tally( mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh if not isinstance(mesh, openmc.RegularMesh): raise NotImplemented( - f"Only RegularMesh are currently supported not {type(mesh)}" + f"Only RegularMesh are supported not {type(mesh)}" ) - # if mesh.n_dimension != 3: - # msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported" - # raise NotImplementedError(msg) # if score is not specified and tally has a single score then we know which score to use if score is None: @@ -118,7 +115,6 @@ def plot_mesh_tally( basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] - print('mesh.dimension', mesh.dimension) if 1 in mesh.dimension: index_of_2d = mesh.dimension.index(1) axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] @@ -151,7 +147,6 @@ def plot_mesh_tally( xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" else: # basis == 'xy' slice_data = tally_data[:, :, slice_index] - print('shape slice_data', slice_data.shape) data = np.rot90(slice_data, -3) xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" From 50faaa24cc232950e3ae5c238fea88c72feecf37 Mon Sep 17 00:00:00 2001 From: shimwell Date: Fri, 24 Nov 2023 14:44:27 +0000 Subject: [PATCH 4/6] [skip ci] Apply formatting changes --- src/openmc_regular_mesh_plotter/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index d35ccfd..65256d5 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -101,9 +101,7 @@ def plot_mesh_tally( mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh if not isinstance(mesh, openmc.RegularMesh): - raise NotImplemented( - f"Only RegularMesh are supported not {type(mesh)}" - ) + raise NotImplemented(f"Only RegularMesh are supported not {type(mesh)}") # if score is not specified and tally has a single score then we know which score to use if score is None: From 237195da950025ad3360bd01dcb727528ee9d4be Mon Sep 17 00:00:00 2001 From: Jonathan Shimwell Date: Mon, 27 Nov 2023 11:03:24 +0000 Subject: [PATCH 5/6] changed slice index --- tests/test_units.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_units.py b/tests/test_units.py index c4528a7..682a5d5 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -121,7 +121,11 @@ def test_plot_2d_mesh_tally(model): tally_result = statepoint.get_tally(name="mesh-tal") plot = plot_mesh_tally( - tally=tally_result, basis="yz", slice_index=29 # max value of slice selected + tally=tally_result, basis="yz", slice_index=0 # max value of slice selected + ) + + plot = plot_mesh_tally( + tally=tally_result, basis="yz" ) # axis_units defaults to cm assert plot.xaxis.get_label().get_text() == "y [cm]" From 97c27726e05225c6e264aa61eab14edeb193449e Mon Sep 17 00:00:00 2001 From: shimwell Date: Mon, 27 Nov 2023 11:06:44 +0000 Subject: [PATCH 6/6] [skip ci] Apply formatting changes --- tests/test_units.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_units.py b/tests/test_units.py index 682a5d5..13a07b9 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -124,9 +124,7 @@ def test_plot_2d_mesh_tally(model): tally=tally_result, basis="yz", slice_index=0 # max value of slice selected ) - plot = plot_mesh_tally( - tally=tally_result, basis="yz" - ) + plot = plot_mesh_tally(tally=tally_result, basis="yz") # axis_units defaults to cm assert plot.xaxis.get_label().get_text() == "y [cm]" assert plot.yaxis.get_label().get_text() == "z [cm]"