Skip to content

Commit

Permalink
Merge pull request #57 from fusion-energy/supporting_2d_meshes
Browse files Browse the repository at this point in the history
Supporting 2d meshes
  • Loading branch information
shimwell authored Sep 11, 2023
2 parents 0b979af + 8778e4f commit e8ec9f5
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 56 deletions.
83 changes: 83 additions & 0 deletions examples/plot_minimal_2d_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import openmc
from matplotlib.colors import LogNorm
from openmc_regular_mesh_plotter import plot_mesh_tally


# MATERIALS
mat_1 = openmc.Material()
mat_1.add_element("Li", 1)
mat_1.set_density("g/cm3", 0.45)
my_materials = openmc.Materials([mat_1])

# GEOMETRY
# surfaces
inner_surface = openmc.Sphere(r=200)
outer_surface = openmc.Sphere(r=400, boundary_type="vacuum")
# regions
inner_region = -inner_surface
outer_region = -outer_surface & +inner_surface
# cells
inner_cell = openmc.Cell(region=inner_region)
outer_cell = openmc.Cell(region=outer_region)
outer_cell.fill = mat_1
my_geometry = openmc.Geometry([inner_cell, outer_cell])

# SIMULATION SETTINGS
my_settings = openmc.Settings()
my_settings.batches = 10
my_settings.inactive = 0
my_settings.particles = 5000
my_settings.run_mode = "fixed source"
# Create a DT point source
try:
source = openmc.IndependentSource()
except:
# work with older versions of openmc
source = openmc.Source()
source.space = openmc.stats.Point((100, 0, 0))
source.angle = openmc.stats.Isotropic()
source.energy = openmc.stats.Discrete([14e6], [1])
my_settings.source = source

# Tallies
my_tallies = openmc.Tallies()
mesh = openmc.RegularMesh().from_domain(
my_geometry, # the corners of the mesh are being set automatically to surround the geometry
dimension=[1, 40, 40],
)
mesh_filter = openmc.MeshFilter(mesh)
mesh_tally_1 = openmc.Tally(name="mesh_tally")
mesh_tally_1.filters = [mesh_filter]
mesh_tally_1.scores = ["heating"]
my_tallies.append(mesh_tally_1)

model = openmc.model.Model(my_geometry, my_materials, my_settings, my_tallies)
sp_filename = model.run()

# post process simulation result
statepoint = openmc.StatePoint(sp_filename)

# extracts the mesh tally by name
my_mesh_tally = statepoint.get_tally(name="mesh_tally")

# default tally units for heating are in eV per source neutron
# for this example plot we want Mega Joules per second per cm3 or Mjcm^-3s^-1
neutrons_per_second = 1e21
eV_to_joules = 1.60218e-19
joules_to_mega_joules = 1e-6
scaling_factor = neutrons_per_second * eV_to_joules * joules_to_mega_joules
# note that volume_normalization is enabled so this will also change the units to divide by the volume of each mesh voxel
# alternatively you could set volume_normalization to false and divide by the mesh.volume[0][0][0] in the scaling factor
# in a regular mesh all the voxels have the same volume so the [0][0][0] just picks the first volume

plot = plot_mesh_tally(
basis="yz", # as the mesh dimention is [1,40,40] only the yz basis can be plotted
tally=my_mesh_tally,
outline=True, # enables an outline around the geometry
geometry=my_geometry, # needed for outline
norm=LogNorm(), # log scale
colorbar=False,
)

plot.figure.savefig("example_openmc_2d_regular_mesh_plotter.png")
plot.title.set_text("")
17 changes: 2 additions & 15 deletions examples/plot_minimal_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,10 @@

plot = plot_mesh_tally(
tally=my_mesh_tally,
basis="xz",
# slice_index=11, # middle value of slice selected automatically, but you can set the slide index if preferred
axis_units="m", # set to meters otherwise this defaults to cm
score="heating", # as we just have one score this could be missed out and found automatically
value="mean", # set to mean but could also be set to std_dev
outline=True, # enables an outline around the geometry
geometry=my_geometry,
outline_by="material",
colorbar_kwargs={"label": "Heating MJ/cm3/s"},
outline_kwargs={
"colors": "grey",
"linewidths": 2,
}, # setting the outline color and thickness, otherwise this defaults to black and 1
pixels=6000000, # this controls the resolution of the outline
geometry=my_geometry, # needed for outline
norm=LogNorm(), # log scale
scaling_factor=scaling_factor, # multiplies the tally result by scaling_factor
volume_normalization=True,
colorbar=False,
)

plot.figure.savefig("example_openmc_regular_mesh_plotter.png")
Expand Down
112 changes: 78 additions & 34 deletions src/openmc_regular_mesh_plotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def plot_mesh_tally(
raise NotImplemented(
f"Only RegularMesh are currently supported not {type(mesh)}"
)
if mesh.n_dimension != 3:
msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported"
raise NotImplementedError(msg)
# if mesh.n_dimension != 3:
# msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported"
# raise NotImplementedError(msg)

# if score is not specified and tally has a single score then we know which score to use
if score is None:
Expand All @@ -112,24 +112,59 @@ def plot_mesh_tally(

tally_slice = tally.get_slice(scores=[score])

# if mesh.n_dimension == 3:

if 1 in mesh.dimension:
index_of_2d = mesh.dimension.index(1)
axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d]

# todo check if 1 appears twice or three times, raise value error if so

tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze()

if slice_index is None:
basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]
slice_index = int(tally_data.shape[basis_to_index] / 2)

if basis == "xz":
slice_data = tally_data[:, slice_index, :]
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
elif basis == "yz":
slice_data = tally_data[slice_index, :, :]
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
else: # basis == 'xy'
slice_data = tally_data[:, :, slice_index]
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"
basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]
if len(tally_data.shape) == 3:
if slice_index is None:
slice_index = int(tally_data.shape[basis_to_index] / 2)

if basis == "xz":
slice_data = tally_data[:, slice_index, :]
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
elif basis == "yz":
slice_data = tally_data[slice_index, :, :]
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
else: # basis == 'xy'
slice_data = tally_data[:, :, slice_index]
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"
# elif mesh.n_dimension == 2:
elif len(tally_data.shape) == 2:
print("got here")
if basis_to_index == index_of_2d:
print("good basis selected", basis)

slice_data = tally_data[:, :]
if basis == "xz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
elif basis == "yz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
else: # basis == 'xy'
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"

else:
raise ValueError(
"The selected tally has a mesh that has 1 dimension in the "
f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
f"of {basis}."
)

else:
raise ValueError("mesh n_dimension")

if volume_normalization:
# in a regular mesh all volumes are the same so we just divide by the first
Expand Down Expand Up @@ -166,25 +201,34 @@ def plot_mesh_tally(
center_of_mesh = mesh.bounding_box.center
if basis == "xy":
zarr = np.linspace(z0, z1, nz + 1)
center_of_mesh_slice = [
center_of_mesh[0],
center_of_mesh[1],
(zarr[slice_index] + zarr[slice_index + 1]) / 2,
]
if len(tally_data.shape) == 3:
center_of_mesh_slice = [
center_of_mesh[0],
center_of_mesh[1],
(zarr[slice_index] + zarr[slice_index + 1]) / 2,
]
else: # 2
center_of_mesh_slice = mesh.bounding_box.center
if basis == "xz":
yarr = np.linspace(y0, y1, ny + 1)
center_of_mesh_slice = [
center_of_mesh[0],
(yarr[slice_index] + yarr[slice_index + 1]) / 2,
center_of_mesh[2],
]
if len(tally_data.shape) == 3:
center_of_mesh_slice = [
center_of_mesh[0],
(yarr[slice_index] + yarr[slice_index + 1]) / 2,
center_of_mesh[2],
]
else: # 2
center_of_mesh_slice = mesh.bounding_box.center
if basis == "yz":
xarr = np.linspace(x0, x1, nx + 1)
center_of_mesh_slice = [
(xarr[slice_index] + xarr[slice_index + 1]) / 2,
center_of_mesh[1],
center_of_mesh[2],
]
if len(tally_data.shape) == 3:
center_of_mesh_slice = [
(xarr[slice_index] + xarr[slice_index + 1]) / 2,
center_of_mesh[1],
center_of_mesh[2],
]
else: # 2
center_of_mesh_slice = mesh.bounding_box.center

model = openmc.Model()
model.geometry = geometry
Expand Down
68 changes: 61 additions & 7 deletions tests/test_units.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import openmc
from matplotlib.colors import LogNorm
from openmc_regular_mesh_plotter import plot_mesh_tally
import pytest


def test_plot_mesh_tally():
@pytest.fixture()
def model():
mat1 = openmc.Material()
mat1.add_nuclide("Li6", 1, percent_type="ao")
mats = openmc.Materials([mat1])
Expand Down Expand Up @@ -39,17 +41,26 @@ def test_plot_mesh_tally():
sett.run_mode = "fixed source"
sett.source = source

mesh = openmc.RegularMesh().from_domain(geom, dimension=[10, 20, 30])
model = openmc.Model(geom, mats, sett)

return model


def test_plot_3d_mesh_tally(model):
geometry = model.geometry

mesh = openmc.RegularMesh().from_domain(geometry, dimension=[10, 20, 30])
mesh_filter = openmc.MeshFilter(mesh)
mesh_tally = openmc.Tally(name="mesh-tal")
mesh_tally.filters = [mesh_filter]
mesh_tally.scores = ["flux"]
tallies = openmc.Tallies([mesh_tally])

model = openmc.Model(geom, mats, sett, tallies)
model.tallies = tallies

sp_filename = model.run()
statepoint = openmc.StatePoint(sp_filename)
tally_result = statepoint.get_tally(name="mesh-tal")
with openmc.StatePoint(sp_filename) as statepoint:
tally_result = statepoint.get_tally(name="mesh-tal")

plot = plot_mesh_tally(
tally=tally_result, basis="xy", slice_index=29 # max value of slice selected
Expand All @@ -64,7 +75,7 @@ def test_plot_mesh_tally():
tally=tally_result,
basis="yz",
axis_units="m",
slice_index=9, # max value of slice selected
# slice_index=9, # max value of slice selected
value="std_dev",
)
plot.figure.savefig("x.png")
Expand All @@ -81,7 +92,7 @@ def test_plot_mesh_tally():
score="flux",
value="mean",
outline=True,
geometry=geom,
geometry=geometry,
outline_by="material",
colorbar_kwargs={"label": "neutron flux"},
norm=LogNorm(vmin=1e-6, vmax=max(tally_result.mean.flatten())),
Expand All @@ -91,3 +102,46 @@ def test_plot_mesh_tally():
assert plot.get_xlim() == (-1000.0, 500) # note that units are in mm
assert plot.get_ylim() == (-3000.0, 3500.0)
plot.figure.savefig("z.png")


def test_plot_2d_mesh_tally(model):
geometry = model.geometry

mesh = openmc.RegularMesh().from_domain(geometry, dimension=[1, 20, 30])
mesh_filter = openmc.MeshFilter(mesh)
mesh_tally = openmc.Tally(name="mesh-tal")
mesh_tally.filters = [mesh_filter]
mesh_tally.scores = ["flux"]
tallies = openmc.Tallies([mesh_tally])

model.tallies = tallies

sp_filename = model.run()
with openmc.StatePoint(sp_filename) as statepoint:
tally_result = statepoint.get_tally(name="mesh-tal")

plot = plot_mesh_tally(
tally=tally_result, basis="yz", slice_index=29 # max value of slice selected
)
# axis_units defaults to cm
assert plot.xaxis.get_label().get_text() == "y [cm]"
assert plot.yaxis.get_label().get_text() == "z [cm]"
assert plot.get_xlim() == (-200.0, 250.0)
assert plot.get_ylim() == (-300.0, 350.0)
plot.figure.savefig("t.png")

plot = plot_mesh_tally(
tally=tally_result,
basis="yz",
axis_units="m",
# slice_index=9, # max value of slice selected
value="std_dev",
)
plot.figure.savefig("x.png")
assert plot.xaxis.get_label().get_text() == "y [m]"
assert plot.yaxis.get_label().get_text() == "z [m]"
assert plot.get_xlim() == (-2.0, 2.5) # note that units are in m
assert plot.get_ylim() == (-3.0, 3.5)


# todo catch errors when 2d mesh used and 1d axis selected for plotting'

0 comments on commit e8ec9f5

Please sign in to comment.