From 0185d4f9e48a49dc17b5014072125893eec7ccf1 Mon Sep 17 00:00:00 2001 From: Jonathan Shimwell Date: Thu, 23 Feb 2023 18:36:17 +0000 Subject: [PATCH 1/2] added transpose function --- src/regular_mesh_plotter/core.py | 66 ++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/src/regular_mesh_plotter/core.py b/src/regular_mesh_plotter/core.py index 2050a10..219c4c5 100644 --- a/src/regular_mesh_plotter/core.py +++ b/src/regular_mesh_plotter/core.py @@ -74,6 +74,41 @@ def get_side_extent(self, side: str, view_direction: str = "x", bb=None): return avail_extents[(side, view_direction)] +def reshape_data(self, dataset, view_direction): + + reshaped_ds = dataset.reshape(self.dimension, order="F") + + if view_direction == "x": + # vertical axis is z, horizontal axis is -y + transposed_ds = reshaped_ds.transpose(0, 1, 2) + + elif view_direction == "-x": + # vertical axis is z, horizontal axis is y + transposed_ds = reshaped_ds.transpose(0, 1, 2) + + elif view_direction == "y": + # vertical axis is z, horizontal axis is x + transposed_ds = reshaped_ds.transpose(1, 2, 0) + + elif view_direction == "-y": + # vertical axis is z, horizontal axis is -x + transposed_ds = reshaped_ds.transpose(1, 2, 0) + + elif view_direction == "z": + # vertical axis is y, horizontal axis is -x + transposed_ds = reshaped_ds.transpose(2, 0, 1) + + elif view_direction == "-z": + # vertical axis is y, horizontal axis is x + transposed_ds = reshaped_ds.transpose(2, 0, 1) + + else: + msg = "view_direction of {view_direction} is not one of the acceptable options ({supported_view_dirs})" + raise ValueError(msg) + + return transposed_ds + + def slice_of_data( self, dataset: np.ndarray, @@ -114,34 +149,28 @@ def slice_of_data( if volume_normalization: dataset = dataset.flatten() / self.volumes.flatten() - reshaped_ds = dataset.reshape(self.dimension, order="F") + transposed_ds = self.reshape_data(dataset, view_direction)[slice_index] if view_direction == "x": # vertical axis is z, horizontal axis is -y - transposed_ds = reshaped_ds.transpose(0, 1, 2)[slice_index] rotated_ds = np.rot90(transposed_ds, 1) aligned_ds = np.fliplr(rotated_ds) elif view_direction == "-x": # vertical axis is z, horizontal axis is y - transposed_ds = reshaped_ds.transpose(0, 1, 2)[slice_index] aligned_ds = np.rot90(transposed_ds, 1) elif view_direction == "y": # vertical axis is z, horizontal axis is x - transposed_ds = reshaped_ds.transpose(1, 2, 0)[slice_index] aligned_ds = np.flipud(transposed_ds) elif view_direction == "-y": # vertical axis is z, horizontal axis is -x - transposed_ds = reshaped_ds.transpose(1, 2, 0)[slice_index] aligned_ds = np.flipud(transposed_ds) aligned_ds = np.fliplr(aligned_ds) elif view_direction == "z": # vertical axis is y, horizontal axis is -x - transposed_ds = reshaped_ds.transpose(2, 0, 1)[slice_index] aligned_ds = np.rot90(transposed_ds, 1) aligned_ds = np.fliplr(aligned_ds) elif view_direction == "-z": # vertical axis is y, horizontal axis is x - transposed_ds = reshaped_ds.transpose(2, 0, 1)[slice_index] aligned_ds = np.rot90(transposed_ds, 1) else: msg = "view_direction of {view_direction} is not one of the acceptable options ({supported_view_dirs})" @@ -150,9 +179,32 @@ def slice_of_data( return aligned_ds +def get_axis_labels(self, view_direction): + """Returns two axis label values for the x and y value. Takes + view_direction into account.""" + + if view_direction == "x": + xlabel = "Y [cm]" + ylabel = "Z [cm]" + if view_direction == "y": + xlabel = "X [cm]" + ylabel = "Z [cm]" + if view_direction == "z": + xlabel = "X [cm]" + ylabel = "Y [cm]" + return xlabel, ylabel + +openmc.RegularMesh.reshape_data = reshape_data +openmc.mesh.RegularMesh.reshape_data = reshape_data + +openmc.RegularMesh.get_axis_labels = get_axis_labels +openmc.mesh.RegularMesh.get_axis_labels = get_axis_labels + openmc.RegularMesh.slice_of_data = slice_of_data openmc.mesh.RegularMesh.slice_of_data = slice_of_data + openmc.RegularMesh.get_mpl_plot_extent = get_mpl_plot_extent openmc.mesh.RegularMesh.get_mpl_plot_extent = get_mpl_plot_extent + openmc.RegularMesh.get_side_extent = get_side_extent openmc.mesh.RegularMesh.get_side_extent = get_side_extent From f3b4dc8f35f6455db9f8fec037429ae1d46d7ffe Mon Sep 17 00:00:00 2001 From: shimwell Date: Thu, 23 Feb 2023 18:38:09 +0000 Subject: [PATCH 2/2] [skip ci] Apply formatting changes --- src/regular_mesh_plotter/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/regular_mesh_plotter/core.py b/src/regular_mesh_plotter/core.py index f2687ef..614209a 100644 --- a/src/regular_mesh_plotter/core.py +++ b/src/regular_mesh_plotter/core.py @@ -75,7 +75,6 @@ def get_side_extent(self, side: str, view_direction: str = "x", bb=None): def reshape_data(self, dataset, view_direction): - reshaped_ds = dataset.reshape(self.dimension, order="F") if view_direction == "x": @@ -194,6 +193,7 @@ def get_axis_labels(self, view_direction): ylabel = "Y [cm]" return xlabel, ylabel + openmc.RegularMesh.reshape_data = reshape_data openmc.mesh.RegularMesh.reshape_data = reshape_data