Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved 2d mesh plotting by squeezing end of tally shape only #59

Merged
merged 8 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 28 additions & 32 deletions src/openmc_regular_mesh_plotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
_default_outline_kwargs = {"colors": "black", "linestyles": "solid", "linewidths": 1}


def _squeeze_end_of_array(array, dims_required=3):
while len(array.shape) > dims_required:
array = np.squeeze(array, axis=len(array.shape) - 1)
return array


def plot_mesh_tally(
tally: "openmc.Tally",
basis: str = "xy",
Expand Down Expand Up @@ -95,12 +101,7 @@ def plot_mesh_tally(

mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh
if not isinstance(mesh, openmc.RegularMesh):
raise NotImplemented(
f"Only RegularMesh are currently supported not {type(mesh)}"
)
# if mesh.n_dimension != 3:
# msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported"
# raise NotImplementedError(msg)
raise NotImplemented(f"Only RegularMesh are supported not {type(mesh)}")

# if score is not specified and tally has a single score then we know which score to use
if score is None:
Expand All @@ -112,19 +113,32 @@ def plot_mesh_tally(

tally_slice = tally.get_slice(scores=[score])

# if mesh.n_dimension == 3:
basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]

if 1 in mesh.dimension:
index_of_2d = mesh.dimension.index(1)
axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d]
if (
axis_of_2d in basis
): # checks if the axis is being plotted, e.g is 'x' in 'xy'
raise ValueError(
"The selected tally has a mesh that has 1 dimension in the "
f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
f"of {basis}."
)

# todo check if 1 appears twice or three times, raise value error if so
# TODO check if 1 appears twice or three times, raise value error if so

tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze()
tally_data = tally_slice.get_reshaped_data(
expand_dims=True, value=value
) # .squeeze()

basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]
if len(tally_data.shape) == 3:
tally_data = _squeeze_end_of_array(tally_data, dims_required=3)

# if len(tally_data.shape) == 3:
if mesh.n_dimension == 3:
if slice_index is None:
# finds the mid index
slice_index = int(tally_data.shape[basis_to_index] / 2)

if basis == "xz":
Expand All @@ -139,29 +153,11 @@ def plot_mesh_tally(
slice_data = tally_data[:, :, slice_index]
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"
# elif mesh.n_dimension == 2:
elif len(tally_data.shape) == 2:
if basis_to_index == index_of_2d:
slice_data = tally_data[:, :]
if basis == "xz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
elif basis == "yz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
else: # basis == 'xy'
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"

else:
raise ValueError(
"The selected tally has a mesh that has 1 dimension in the "
f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
f"of {basis}."
)

else:
raise ValueError("mesh n_dimension")
raise ValueError(
f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported"
)

if volume_normalization:
# in a regular mesh all volumes are the same so we just divide by the first
Expand Down
4 changes: 3 additions & 1 deletion tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def test_plot_2d_mesh_tally(model):
tally_result = statepoint.get_tally(name="mesh-tal")

plot = plot_mesh_tally(
tally=tally_result, basis="yz", slice_index=29 # max value of slice selected
tally=tally_result, basis="yz", slice_index=0 # max value of slice selected
)

plot = plot_mesh_tally(tally=tally_result, basis="yz")
# axis_units defaults to cm
assert plot.xaxis.get_label().get_text() == "y [cm]"
assert plot.yaxis.get_label().get_text() == "z [cm]"
Expand Down