Skip to content

Commit

Permalink
Merge pull request #48 from fusion-energy/adding_tally_finding_utils
Browse files Browse the repository at this point in the history
Adding axis label finding function
  • Loading branch information
shimwell authored Feb 23, 2023
2 parents 2f04be0 + f3b4dc8 commit b536de9
Showing 1 changed file with 59 additions and 7 deletions.
66 changes: 59 additions & 7 deletions src/regular_mesh_plotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,40 @@ def get_side_extent(self, side: str, view_direction: str = "x", bb=None):
return avail_extents[(side, view_direction)]


def reshape_data(self, dataset, view_direction):
reshaped_ds = dataset.reshape(self.dimension, order="F")

if view_direction == "x":
# vertical axis is z, horizontal axis is -y
transposed_ds = reshaped_ds.transpose(0, 1, 2)

elif view_direction == "-x":
# vertical axis is z, horizontal axis is y
transposed_ds = reshaped_ds.transpose(0, 1, 2)

elif view_direction == "y":
# vertical axis is z, horizontal axis is x
transposed_ds = reshaped_ds.transpose(1, 2, 0)

elif view_direction == "-y":
# vertical axis is z, horizontal axis is -x
transposed_ds = reshaped_ds.transpose(1, 2, 0)

elif view_direction == "z":
# vertical axis is y, horizontal axis is -x
transposed_ds = reshaped_ds.transpose(2, 0, 1)

elif view_direction == "-z":
# vertical axis is y, horizontal axis is x
transposed_ds = reshaped_ds.transpose(2, 0, 1)

else:
msg = "view_direction of {view_direction} is not one of the acceptable options ({supported_view_dirs})"
raise ValueError(msg)

return transposed_ds


def slice_of_data(
self,
dataset: np.ndarray,
Expand Down Expand Up @@ -114,34 +148,28 @@ def slice_of_data(
if volume_normalization:
dataset = dataset.flatten() / self.volumes.flatten()

reshaped_ds = dataset.reshape(self.dimension, order="F")
transposed_ds = self.reshape_data(dataset, view_direction)[slice_index]

if view_direction == "x":
# vertical axis is z, horizontal axis is -y
transposed_ds = reshaped_ds.transpose(0, 1, 2)[slice_index]
rotated_ds = np.rot90(transposed_ds, 1)
aligned_ds = np.fliplr(rotated_ds)
elif view_direction == "-x":
# vertical axis is z, horizontal axis is y
transposed_ds = reshaped_ds.transpose(0, 1, 2)[slice_index]
aligned_ds = np.rot90(transposed_ds, 1)
elif view_direction == "y":
# vertical axis is z, horizontal axis is x
transposed_ds = reshaped_ds.transpose(1, 2, 0)[slice_index]
aligned_ds = np.flipud(transposed_ds)
elif view_direction == "-y":
# vertical axis is z, horizontal axis is -x
transposed_ds = reshaped_ds.transpose(1, 2, 0)[slice_index]
aligned_ds = np.flipud(transposed_ds)
aligned_ds = np.fliplr(aligned_ds)
elif view_direction == "z":
# vertical axis is y, horizontal axis is -x
transposed_ds = reshaped_ds.transpose(2, 0, 1)[slice_index]
aligned_ds = np.rot90(transposed_ds, 1)
aligned_ds = np.fliplr(aligned_ds)
elif view_direction == "-z":
# vertical axis is y, horizontal axis is x
transposed_ds = reshaped_ds.transpose(2, 0, 1)[slice_index]
aligned_ds = np.rot90(transposed_ds, 1)
else:
msg = "view_direction of {view_direction} is not one of the acceptable options ({supported_view_dirs})"
Expand All @@ -150,9 +178,33 @@ def slice_of_data(
return aligned_ds


def get_axis_labels(self, view_direction):
"""Returns two axis label values for the x and y value. Takes
view_direction into account."""

if view_direction == "x":
xlabel = "Y [cm]"
ylabel = "Z [cm]"
if view_direction == "y":
xlabel = "X [cm]"
ylabel = "Z [cm]"
if view_direction == "z":
xlabel = "X [cm]"
ylabel = "Y [cm]"
return xlabel, ylabel


openmc.RegularMesh.reshape_data = reshape_data
openmc.mesh.RegularMesh.reshape_data = reshape_data

openmc.RegularMesh.get_axis_labels = get_axis_labels
openmc.mesh.RegularMesh.get_axis_labels = get_axis_labels

openmc.RegularMesh.slice_of_data = slice_of_data
openmc.mesh.RegularMesh.slice_of_data = slice_of_data

openmc.RegularMesh.get_mpl_plot_extent = get_mpl_plot_extent
openmc.mesh.RegularMesh.get_mpl_plot_extent = get_mpl_plot_extent

openmc.RegularMesh.get_side_extent = get_side_extent
openmc.mesh.RegularMesh.get_side_extent = get_side_extent

0 comments on commit b536de9

Please sign in to comment.