diff --git a/README.md b/README.md
index a05870a..b6a1674 100644
--- a/README.md
+++ b/README.md
@@ -21,11 +21,10 @@ Features
:arrow_right_hook: Customisable by passing keywords to underlying matplotlib functions colorbar, contour and imshow
-:arrow_right_hook: supports further customisable by supporting ```matplotlib.rc```
-```python
-import matplotlib
-matplotlib.rc('font', **{'family' : 'normal', 'size' : 22})
-```
+:arrow_right_hook: supports further customisations throught ```matplotlib.rc()```
+
+:heavy_plus_sign: Add tally results together to get combined plot.
+
|||
# Local install
diff --git a/examples/plot_sweep_through_slice_indexes.py b/examples/plot_sweep_through_slice_indexes.py
index cc39561..9c321f2 100644
--- a/examples/plot_sweep_through_slice_indexes.py
+++ b/examples/plot_sweep_through_slice_indexes.py
@@ -7,6 +7,10 @@
from matplotlib.colors import LogNorm
from openmc_regular_mesh_plotter import plot_mesh_tally
from matplotlib import cm
+import matplotlib
+
+# sets the font for the axis
+matplotlib.rc("font", **{"family": "normal", "size": 22})
# MATERIALS
diff --git a/examples/plot_two_tallies_combined.py b/examples/plot_two_tallies_combined.py
new file mode 100644
index 0000000..e8383d9
--- /dev/null
+++ b/examples/plot_two_tallies_combined.py
@@ -0,0 +1,122 @@
+import openmc
+from matplotlib.colors import LogNorm
+from openmc_regular_mesh_plotter import plot_mesh_tally
+
+
+# MATERIALS
+mat_1 = openmc.Material()
+mat_1.add_element("Li", 1)
+mat_1.set_density("g/cm3", 0.1)
+my_materials = openmc.Materials([mat_1])
+
+# GEOMETRY
+# surfaces
+inner_surface = openmc.Sphere(r=200)
+outer_surface = openmc.Sphere(r=400, boundary_type="vacuum")
+# regions
+inner_region = -inner_surface
+outer_region = -outer_surface & +inner_surface
+# cells
+inner_cell = openmc.Cell(region=inner_region)
+outer_cell = openmc.Cell(region=outer_region)
+outer_cell.fill = mat_1
+my_geometry = openmc.Geometry([inner_cell, outer_cell])
+
+# SIMULATION SETTINGS
+my_settings = openmc.Settings()
+my_settings.batches = 10
+my_settings.inactive = 0
+my_settings.particles = 50000
+my_settings.run_mode = "fixed source"
+# my_settings.photon_transport = True # could be enabled but we have a photon source instead which converges a photon head deposition quicker
+
+# Create a neutron and photon source
+try:
+ source_n = openmc.IndependentSource()
+ source_p = openmc.IndependentSource()
+except:
+ # work with older versions of openmc
+ source_n = openmc.Source()
+ source_p = openmc.Source()
+
+source_n.space = openmc.stats.Point((200, 0, 0))
+source_n.angle = openmc.stats.Isotropic()
+source_n.energy = openmc.stats.Discrete([0.1e6], [1])
+source_n.strength = 1
+source_n.particle='neutron'
+
+source_p.space = openmc.stats.Point((-200, 0, 0))
+source_p.angle = openmc.stats.Isotropic()
+source_p.energy = openmc.stats.Discrete([10e6], [1])
+source_p.strength = 10
+source_p.particle='photon'
+
+my_settings.source = [source_n, source_p]
+
+# Tallies
+mesh = openmc.RegularMesh().from_domain(
+ my_geometry, # the corners of the mesh are being set automatically to surround the geometry
+ dimension=[40, 40, 40],
+)
+
+mesh_filter = openmc.MeshFilter(mesh)
+neutron_filter = openmc.ParticleFilter("neutron")
+photon_filter = openmc.ParticleFilter("photon")
+
+mesh_tally_1 = openmc.Tally(name="mesh_tally_neutron")
+mesh_tally_1.filters = [mesh_filter, neutron_filter]
+mesh_tally_1.scores = ["heating"]
+
+mesh_tally_2 = openmc.Tally(name="mesh_tally_photon")
+mesh_tally_2.filters = [mesh_filter, photon_filter]
+mesh_tally_2.scores = ["heating"]
+
+my_tallies = openmc.Tallies([mesh_tally_1, mesh_tally_2])
+
+
+model = openmc.model.Model(my_geometry, my_materials, my_settings, my_tallies)
+sp_filename = model.run()
+
+# post process simulation result
+statepoint = openmc.StatePoint(sp_filename)
+
+# extracts the mesh tally by name
+my_mesh_tally_photon = statepoint.get_tally(name="mesh_tally_photon")
+my_mesh_tally_neutron = statepoint.get_tally(name="mesh_tally_neutron")
+
+# default tally units for heating are in eV per source neutron
+# for this example plot we want Mega Joules per second per cm3 or Mjcm^-3s^-1
+neutrons_per_second = 1e21
+eV_to_joules = 1.60218e-19
+joules_to_mega_joules = 1e-6
+scaling_factor = neutrons_per_second * eV_to_joules * joules_to_mega_joules
+# note that volume_normalization is enabled so this will also change the units to divide by the volume of each mesh voxel
+# alternatively you could set volume_normalization to false and divide by the mesh.volume[0][0][0] in the scaling factor
+# in a regular mesh all the voxels have the same volume so the [0][0][0] just picks the first volume
+
+plot = plot_mesh_tally(
+ tally=[my_mesh_tally_neutron],
+ colorbar=True,
+ # norm=LogNorm()
+)
+plot.title.set_text("neutron heating")
+plot.figure.savefig("neutron_regular_mesh_plotter.png")
+print('written file neutron_regular_mesh_plotter.png')
+
+plot = plot_mesh_tally(
+ tally=[my_mesh_tally_photon],
+ colorbar=True,
+ # norm=LogNorm()
+)
+plot.title.set_text("photon heating")
+plot.figure.savefig("photon_regular_mesh_plotter.png")
+print('written file photon_regular_mesh_plotter.png')
+
+plot = plot_mesh_tally(
+ tally=[my_mesh_tally_photon, my_mesh_tally_neutron],
+ colorbar=True,
+ # norm=LogNorm()
+)
+plot.title.set_text("photon and neutron heating")
+plot.figure.savefig("photon_and_neutron_regular_mesh_plotter.png")
+print('written file photon_and_neutron_regular_mesh_plotter.png')
diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py
index d0f5374..5f8e69f 100644
--- a/src/openmc_regular_mesh_plotter/core.py
+++ b/src/openmc_regular_mesh_plotter/core.py
@@ -1,7 +1,7 @@
import math
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Optional
+import typing
import openmc
import numpy as np
import openmc
@@ -29,20 +29,20 @@ def _squeeze_end_of_array(array, dims_required=3):
def plot_mesh_tally(
- tally: "openmc.Tally",
+ tally: typing.Union["openmc.Tally", typing.Sequence["openmc.Tally"]],
basis: str = "xy",
- slice_index: Optional[int] = None,
- score: Optional[str] = None,
- axes: Optional[str] = None,
+ slice_index: typing.Optional[int] = None,
+ score: typing.Optional[str] = None,
+ axes: typing.Optional[str] = None,
axis_units: str = "cm",
value: str = "mean",
outline: bool = False,
outline_by: str = "cell",
- geometry: Optional["openmc.Geometry"] = None,
+ geometry: typing.Optional["openmc.Geometry"] = None,
pixels: int = 40000,
colorbar: bool = True,
volume_normalization: bool = True,
- scaling_factor: Optional[float] = None,
+ scaling_factor: typing.Optional[float] = None,
colorbar_kwargs: dict = {},
outline_kwargs: dict = _default_outline_kwargs,
**kwargs,
@@ -101,72 +101,21 @@ def plot_mesh_tally(
cv.check_type("volume_normalization", volume_normalization, bool)
cv.check_type("outline", outline, bool)
- mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh
- if not isinstance(mesh, openmc.RegularMesh):
- raise NotImplemented(f"Only RegularMesh are supported not {type(mesh)}")
-
- # if score is not specified and tally has a single score then we know which score to use
- if score is None:
- if len(tally.scores) == 1:
- score = tally.scores[0]
- else:
- msg = "score was not specified and there are multiple scores in the tally."
- raise ValueError(msg)
-
- tally_slice = tally.get_slice(scores=[score])
-
- basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]
-
- if 1 in mesh.dimension:
- index_of_2d = mesh.dimension.index(1)
- axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d]
- if (
- axis_of_2d in basis
- ): # checks if the axis is being plotted, e.g is 'x' in 'xy'
- raise ValueError(
- "The selected tally has a mesh that has 1 dimension in the "
- f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
- f"of {basis}."
- )
-
- # TODO check if 1 appears twice or three times, raise value error if so
-
- tally_data = tally_slice.get_reshaped_data(
- expand_dims=True, value=value
- ) # .squeeze()
-
- tally_data = _squeeze_end_of_array(tally_data, dims_required=3)
-
- # if len(tally_data.shape) == 3:
- if mesh.n_dimension == 3:
- if slice_index is None:
- # finds the mid index
- slice_index = int(tally_data.shape[basis_to_index] / 2)
-
- if basis == "xz":
- slice_data = tally_data[:, slice_index, :]
- data = np.flip(np.rot90(slice_data, -1))
- xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
- elif basis == "yz":
- slice_data = tally_data[slice_index, :, :]
- data = np.flip(np.rot90(slice_data, -1))
- xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
- else: # basis == 'xy'
- slice_data = tally_data[:, :, slice_index]
- data = np.rot90(slice_data, -3)
- xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"
-
+ if isinstance(tally, typing.Sequence):
+ mesh_ids = []
+ for one_tally in tally:
+ mesh = one_tally.find_filter(filter_type=openmc.MeshFilter).mesh
+ # TODO check the tallies use the same mesh
+ mesh_ids.append(mesh.id)
+ if not all(i == mesh_ids[0] for i in mesh_ids):
+ raise ValueError(f'mesh ids {mesh_ids} are different, please use same mesh when combining tallies')
else:
- raise ValueError(
- f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported"
- )
-
- if volume_normalization:
- # in a regular mesh all volumes are the same so we just divide by the first
- data = data / mesh.volumes[0][0][0]
+ mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh
- if scaling_factor:
- data = data * scaling_factor
+ if isinstance(mesh, openmc.CylindricalMesh):
+ raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}, try the openmc_cylindrical_mesh_plotter package available at https://github.com/fusion-energy/openmc_cylindrical_mesh_plotter/")
+ if not isinstance(mesh, openmc.RegularMesh):
+ raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}")
axis_scaling_factor = {"km": 0.00001, "m": 0.01, "cm": 1, "mm": 10}[axis_units]
@@ -174,15 +123,53 @@ def plot_mesh_tally(
i * axis_scaling_factor for i in mesh.bounding_box.extent[basis]
]
+ if basis == "xz":
+ xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
+ elif basis == "yz":
+ xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
+ else: # basis == 'xy'
+ xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"
+
if axes is None:
fig, axes = plt.subplots()
axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel)
+ basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]
+ if slice_index is None:
+ # finds the mid index
+ slice_index = int(mesh.dimension[basis_to_index] / 2)
+
# zero values with logscale produce noise / fuzzy on the time but setting interpolation to none solves this
default_imshow_kwargs = {"interpolation": "none"}
default_imshow_kwargs.update(kwargs)
+ if isinstance(tally, typing.Sequence):
+ data = np.zeros(shape=(40,40))
+ for one_tally in tally:
+ new_data = _get_tally_data(
+ scaling_factor,
+ mesh,
+ basis,
+ one_tally,
+ value,
+ volume_normalization,
+ score,
+ slice_index
+ )
+ data=data+new_data
+ else: # single tally
+ data = _get_tally_data(
+ scaling_factor,
+ mesh,
+ basis,
+ tally,
+ value,
+ volume_normalization,
+ score,
+ slice_index
+ )
+
im = axes.imshow(data, extent=(x_min, x_max, y_min, y_max), **default_imshow_kwargs)
if colorbar:
@@ -198,36 +185,28 @@ def plot_mesh_tally(
x1, y1, z1 = mesh.upper_right
nx, ny, nz = mesh.dimension
center_of_mesh = mesh.bounding_box.center
+
if basis == "xy":
zarr = np.linspace(z0, z1, nz + 1)
- if len(tally_data.shape) == 3:
- center_of_mesh_slice = [
- center_of_mesh[0],
- center_of_mesh[1],
- (zarr[slice_index] + zarr[slice_index + 1]) / 2,
- ]
- else: # 2
- center_of_mesh_slice = mesh.bounding_box.center
+ center_of_mesh_slice = [
+ center_of_mesh[0],
+ center_of_mesh[1],
+ (zarr[slice_index] + zarr[slice_index + 1]) / 2,
+ ]
if basis == "xz":
yarr = np.linspace(y0, y1, ny + 1)
- if len(tally_data.shape) == 3:
- center_of_mesh_slice = [
- center_of_mesh[0],
- (yarr[slice_index] + yarr[slice_index + 1]) / 2,
- center_of_mesh[2],
- ]
- else: # 2
- center_of_mesh_slice = mesh.bounding_box.center
+ center_of_mesh_slice = [
+ center_of_mesh[0],
+ (yarr[slice_index] + yarr[slice_index + 1]) / 2,
+ center_of_mesh[2],
+ ]
if basis == "yz":
xarr = np.linspace(x0, x1, nx + 1)
- if len(tally_data.shape) == 3:
- center_of_mesh_slice = [
- (xarr[slice_index] + xarr[slice_index + 1]) / 2,
- center_of_mesh[1],
- center_of_mesh[2],
- ]
- else: # 2
- center_of_mesh_slice = mesh.bounding_box.center
+ center_of_mesh_slice = [
+ (xarr[slice_index] + xarr[slice_index + 1]) / 2,
+ center_of_mesh[1],
+ center_of_mesh[2],
+ ]
model = openmc.Model()
model.geometry = geometry
@@ -313,3 +292,72 @@ def get_index_where(self, value: float, basis: str = "xy"):
slice_index = (np.abs(voxel_axis_vals - value)).argmin()
return slice_index
+
+def _get_tally_data(
+ scaling_factor,
+ mesh,
+ basis,
+ tally,
+ value,
+ volume_normalization,
+ score,
+ slice_index
+):
+
+ # if score is not specified and tally has a single score then we know which score to use
+ if score is None:
+ if len(tally.scores) == 1:
+ score = tally.scores[0]
+ else:
+ msg = "score was not specified and there are multiple scores in the tally."
+ raise ValueError(msg)
+
+ tally_slice = tally.get_slice(scores=[score])
+
+
+
+ if 1 in mesh.dimension:
+ index_of_2d = mesh.dimension.index(1)
+ axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d]
+ if (
+ axis_of_2d in basis
+ ): # checks if the axis is being plotted, e.g is 'x' in 'xy'
+ raise ValueError(
+ "The selected tally has a mesh that has 1 dimension in the "
+ f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
+ f"of {basis}."
+ )
+
+ # TODO check if 1 appears twice or three times, raise value error if so
+
+ tally_data = tally_slice.get_reshaped_data(
+ expand_dims=True, value=value
+ ) # .squeeze()
+
+ tally_data = _squeeze_end_of_array(tally_data, dims_required=3)
+
+ # if len(tally_data.shape) == 3:
+ if mesh.n_dimension == 3:
+
+ if basis == "xz":
+ slice_data = tally_data[:, slice_index, :]
+ data = np.flip(np.rot90(slice_data, -1))
+ elif basis == "yz":
+ slice_data = tally_data[slice_index, :, :]
+ data = np.flip(np.rot90(slice_data, -1))
+ else: # basis == 'xy'
+ slice_data = tally_data[:, :, slice_index]
+ data = np.rot90(slice_data, -3)
+
+ else:
+ raise ValueError(
+ f"mesh n_dimension is not 3 but is {mesh.n_dimension} which is not supported"
+ )
+
+ if volume_normalization:
+ # in a regular mesh all volumes are the same so we just divide by the first
+ data = data / mesh.volumes[0][0][0]
+
+ if scaling_factor:
+ data = data * scaling_factor
+ return data
\ No newline at end of file