diff --git a/src/regular_mesh_plotter/core.py b/src/regular_mesh_plotter/core.py index 33192d0..614209a 100644 --- a/src/regular_mesh_plotter/core.py +++ b/src/regular_mesh_plotter/core.py @@ -74,6 +74,40 @@ def get_side_extent(self, side: str, view_direction: str = "x", bb=None): return avail_extents[(side, view_direction)] +def reshape_data(self, dataset, view_direction): + reshaped_ds = dataset.reshape(self.dimension, order="F") + + if view_direction == "x": + # vertical axis is z, horizontal axis is -y + transposed_ds = reshaped_ds.transpose(0, 1, 2) + + elif view_direction == "-x": + # vertical axis is z, horizontal axis is y + transposed_ds = reshaped_ds.transpose(0, 1, 2) + + elif view_direction == "y": + # vertical axis is z, horizontal axis is x + transposed_ds = reshaped_ds.transpose(1, 2, 0) + + elif view_direction == "-y": + # vertical axis is z, horizontal axis is -x + transposed_ds = reshaped_ds.transpose(1, 2, 0) + + elif view_direction == "z": + # vertical axis is y, horizontal axis is -x + transposed_ds = reshaped_ds.transpose(2, 0, 1) + + elif view_direction == "-z": + # vertical axis is y, horizontal axis is x + transposed_ds = reshaped_ds.transpose(2, 0, 1) + + else: + msg = "view_direction of {view_direction} is not one of the acceptable options ({supported_view_dirs})" + raise ValueError(msg) + + return transposed_ds + + def slice_of_data( self, dataset: np.ndarray, @@ -114,34 +148,28 @@ def slice_of_data( if volume_normalization: dataset = dataset.flatten() / self.volumes.flatten() - reshaped_ds = dataset.reshape(self.dimension, order="F") + transposed_ds = self.reshape_data(dataset, view_direction)[slice_index] if view_direction == "x": # vertical axis is z, horizontal axis is -y - transposed_ds = reshaped_ds.transpose(0, 1, 2)[slice_index] rotated_ds = np.rot90(transposed_ds, 1) aligned_ds = np.fliplr(rotated_ds) elif view_direction == "-x": # vertical axis is z, horizontal axis is y - transposed_ds = reshaped_ds.transpose(0, 1, 2)[slice_index] aligned_ds = np.rot90(transposed_ds, 1) elif view_direction == "y": # vertical axis is z, horizontal axis is x - transposed_ds = reshaped_ds.transpose(1, 2, 0)[slice_index] aligned_ds = np.flipud(transposed_ds) elif view_direction == "-y": # vertical axis is z, horizontal axis is -x - transposed_ds = reshaped_ds.transpose(1, 2, 0)[slice_index] aligned_ds = np.flipud(transposed_ds) aligned_ds = np.fliplr(aligned_ds) elif view_direction == "z": # vertical axis is y, horizontal axis is -x - transposed_ds = reshaped_ds.transpose(2, 0, 1)[slice_index] aligned_ds = np.rot90(transposed_ds, 1) aligned_ds = np.fliplr(aligned_ds) elif view_direction == "-z": # vertical axis is y, horizontal axis is x - transposed_ds = reshaped_ds.transpose(2, 0, 1)[slice_index] aligned_ds = np.rot90(transposed_ds, 1) else: msg = "view_direction of {view_direction} is not one of the acceptable options ({supported_view_dirs})" @@ -150,9 +178,33 @@ def slice_of_data( return aligned_ds +def get_axis_labels(self, view_direction): + """Returns two axis label values for the x and y value. Takes + view_direction into account.""" + + if view_direction == "x": + xlabel = "Y [cm]" + ylabel = "Z [cm]" + if view_direction == "y": + xlabel = "X [cm]" + ylabel = "Z [cm]" + if view_direction == "z": + xlabel = "X [cm]" + ylabel = "Y [cm]" + return xlabel, ylabel + + +openmc.RegularMesh.reshape_data = reshape_data +openmc.mesh.RegularMesh.reshape_data = reshape_data + +openmc.RegularMesh.get_axis_labels = get_axis_labels +openmc.mesh.RegularMesh.get_axis_labels = get_axis_labels + openmc.RegularMesh.slice_of_data = slice_of_data openmc.mesh.RegularMesh.slice_of_data = slice_of_data + openmc.RegularMesh.get_mpl_plot_extent = get_mpl_plot_extent openmc.mesh.RegularMesh.get_mpl_plot_extent = get_mpl_plot_extent + openmc.RegularMesh.get_side_extent = get_side_extent openmc.mesh.RegularMesh.get_side_extent = get_side_extent