diff --git a/examples/plot_minimal_2d_example.py b/examples/plot_minimal_2d_example.py new file mode 100644 index 0000000..f290368 --- /dev/null +++ b/examples/plot_minimal_2d_example.py @@ -0,0 +1,83 @@ +import openmc +from matplotlib.colors import LogNorm +from openmc_regular_mesh_plotter import plot_mesh_tally + + +# MATERIALS +mat_1 = openmc.Material() +mat_1.add_element("Li", 1) +mat_1.set_density("g/cm3", 0.45) +my_materials = openmc.Materials([mat_1]) + +# GEOMETRY +# surfaces +inner_surface = openmc.Sphere(r=200) +outer_surface = openmc.Sphere(r=400, boundary_type="vacuum") +# regions +inner_region = -inner_surface +outer_region = -outer_surface & +inner_surface +# cells +inner_cell = openmc.Cell(region=inner_region) +outer_cell = openmc.Cell(region=outer_region) +outer_cell.fill = mat_1 +my_geometry = openmc.Geometry([inner_cell, outer_cell]) + +# SIMULATION SETTINGS +my_settings = openmc.Settings() +my_settings.batches = 10 +my_settings.inactive = 0 +my_settings.particles = 5000 +my_settings.run_mode = "fixed source" +# Create a DT point source +try: + source = openmc.IndependentSource() +except: + # work with older versions of openmc + source = openmc.Source() +source.space = openmc.stats.Point((100, 0, 0)) +source.angle = openmc.stats.Isotropic() +source.energy = openmc.stats.Discrete([14e6], [1]) +my_settings.source = source + +# Tallies +my_tallies = openmc.Tallies() +mesh = openmc.RegularMesh().from_domain( + my_geometry, # the corners of the mesh are being set automatically to surround the geometry + dimension=[1, 40, 40], +) +mesh_filter = openmc.MeshFilter(mesh) +mesh_tally_1 = openmc.Tally(name="mesh_tally") +mesh_tally_1.filters = [mesh_filter] +mesh_tally_1.scores = ["heating"] +my_tallies.append(mesh_tally_1) + +model = openmc.model.Model(my_geometry, my_materials, my_settings, my_tallies) +sp_filename = model.run() + +# post process simulation result +statepoint = openmc.StatePoint(sp_filename) + +# extracts the mesh tally by name +my_mesh_tally = statepoint.get_tally(name="mesh_tally") + +# default tally units for heating are in eV per source neutron +# for this example plot we want Mega Joules per second per cm3 or Mjcm^-3s^-1 +neutrons_per_second = 1e21 +eV_to_joules = 1.60218e-19 +joules_to_mega_joules = 1e-6 +scaling_factor = neutrons_per_second * eV_to_joules * joules_to_mega_joules +# note that volume_normalization is enabled so this will also change the units to divide by the volume of each mesh voxel +# alternatively you could set volume_normalization to false and divide by the mesh.volume[0][0][0] in the scaling factor +# in a regular mesh all the voxels have the same volume so the [0][0][0] just picks the first volume + +plot = plot_mesh_tally( + basis="yz", # as the mesh dimention is [1,40,40] only the yz basis can be plotted + tally=my_mesh_tally, + outline=True, # enables an outline around the geometry + geometry=my_geometry, # needed for outline + norm=LogNorm(), # log scale + colorbar=False, +) + +plot.figure.savefig("example_openmc_2d_regular_mesh_plotter.png") +plot.title.set_text("") diff --git a/examples/plot_minimal_example.py b/examples/plot_minimal_example.py index 661b29b..04ca4af 100644 --- a/examples/plot_minimal_example.py +++ b/examples/plot_minimal_example.py @@ -72,23 +72,10 @@ plot = plot_mesh_tally( tally=my_mesh_tally, - basis="xz", - # slice_index=11, # middle value of slice selected automatically, but you can set the slide index if preferred - axis_units="m", # set to meters otherwise this defaults to cm - score="heating", # as we just have one score this could be missed out and found automatically - value="mean", # set to mean but could also be set to std_dev outline=True, # enables an outline around the geometry - geometry=my_geometry, - outline_by="material", - colorbar_kwargs={"label": "Heating MJ/cm3/s"}, - outline_kwargs={ - "colors": "grey", - "linewidths": 2, - }, # setting the outline color and thickness, otherwise this defaults to black and 1 - pixels=6000000, # this controls the resolution of the outline + geometry=my_geometry, # needed for outline norm=LogNorm(), # log scale - scaling_factor=scaling_factor, # multiplies the tally result by scaling_factor - volume_normalization=True, + colorbar=False, ) plot.figure.savefig("example_openmc_regular_mesh_plotter.png") diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index 26ee548..f4f2c1e 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -98,9 +98,9 @@ def plot_mesh_tally( raise NotImplemented( f"Only RegularMesh are currently supported not {type(mesh)}" ) - if mesh.n_dimension != 3: - msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported" - raise NotImplementedError(msg) + # if mesh.n_dimension != 3: + # msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported" + # raise NotImplementedError(msg) # if score is not specified and tally has a single score then we know which score to use if score is None: @@ -112,24 +112,59 @@ def plot_mesh_tally( tally_slice = tally.get_slice(scores=[score]) + # if mesh.n_dimension == 3: + + if 1 in mesh.dimension: + index_of_2d = mesh.dimension.index(1) + axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] + + # todo check if 1 appears twice or three times, raise value error if so + tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze() - if slice_index is None: - basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] - slice_index = int(tally_data.shape[basis_to_index] / 2) - - if basis == "xz": - slice_data = tally_data[:, slice_index, :] - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]" - elif basis == "yz": - slice_data = tally_data[slice_index, :, :] - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" - else: # basis == 'xy' - slice_data = tally_data[:, :, slice_index] - data = np.rot90(slice_data, -3) - xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" + basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] + if len(tally_data.shape) == 3: + if slice_index is None: + slice_index = int(tally_data.shape[basis_to_index] / 2) + + if basis == "xz": + slice_data = tally_data[:, slice_index, :] + data = np.flip(np.rot90(slice_data, -1)) + xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]" + elif basis == "yz": + slice_data = tally_data[slice_index, :, :] + data = np.flip(np.rot90(slice_data, -1)) + xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" + else: # basis == 'xy' + slice_data = tally_data[:, :, slice_index] + data = np.rot90(slice_data, -3) + xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" + # elif mesh.n_dimension == 2: + elif len(tally_data.shape) == 2: + print("got here") + if basis_to_index == index_of_2d: + print("good basis selected", basis) + + slice_data = tally_data[:, :] + if basis == "xz": + data = np.flip(np.rot90(slice_data, -1)) + xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]" + elif basis == "yz": + data = np.flip(np.rot90(slice_data, -1)) + xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" + else: # basis == 'xy' + data = np.rot90(slice_data, -3) + xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" + + else: + raise ValueError( + "The selected tally has a mesh that has 1 dimension in the " + f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " + f"of {basis}." + ) + + else: + raise ValueError("mesh n_dimension") if volume_normalization: # in a regular mesh all volumes are the same so we just divide by the first @@ -166,25 +201,34 @@ def plot_mesh_tally( center_of_mesh = mesh.bounding_box.center if basis == "xy": zarr = np.linspace(z0, z1, nz + 1) - center_of_mesh_slice = [ - center_of_mesh[0], - center_of_mesh[1], - (zarr[slice_index] + zarr[slice_index + 1]) / 2, - ] + if len(tally_data.shape) == 3: + center_of_mesh_slice = [ + center_of_mesh[0], + center_of_mesh[1], + (zarr[slice_index] + zarr[slice_index + 1]) / 2, + ] + else: # 2 + center_of_mesh_slice = mesh.bounding_box.center if basis == "xz": yarr = np.linspace(y0, y1, ny + 1) - center_of_mesh_slice = [ - center_of_mesh[0], - (yarr[slice_index] + yarr[slice_index + 1]) / 2, - center_of_mesh[2], - ] + if len(tally_data.shape) == 3: + center_of_mesh_slice = [ + center_of_mesh[0], + (yarr[slice_index] + yarr[slice_index + 1]) / 2, + center_of_mesh[2], + ] + else: # 2 + center_of_mesh_slice = mesh.bounding_box.center if basis == "yz": xarr = np.linspace(x0, x1, nx + 1) - center_of_mesh_slice = [ - (xarr[slice_index] + xarr[slice_index + 1]) / 2, - center_of_mesh[1], - center_of_mesh[2], - ] + if len(tally_data.shape) == 3: + center_of_mesh_slice = [ + (xarr[slice_index] + xarr[slice_index + 1]) / 2, + center_of_mesh[1], + center_of_mesh[2], + ] + else: # 2 + center_of_mesh_slice = mesh.bounding_box.center model = openmc.Model() model.geometry = geometry diff --git a/tests/test_units.py b/tests/test_units.py index 900f689..c4528a7 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -1,9 +1,11 @@ import openmc from matplotlib.colors import LogNorm from openmc_regular_mesh_plotter import plot_mesh_tally +import pytest -def test_plot_mesh_tally(): +@pytest.fixture() +def model(): mat1 = openmc.Material() mat1.add_nuclide("Li6", 1, percent_type="ao") mats = openmc.Materials([mat1]) @@ -39,17 +41,26 @@ def test_plot_mesh_tally(): sett.run_mode = "fixed source" sett.source = source - mesh = openmc.RegularMesh().from_domain(geom, dimension=[10, 20, 30]) + model = openmc.Model(geom, mats, sett) + + return model + + +def test_plot_3d_mesh_tally(model): + geometry = model.geometry + + mesh = openmc.RegularMesh().from_domain(geometry, dimension=[10, 20, 30]) mesh_filter = openmc.MeshFilter(mesh) mesh_tally = openmc.Tally(name="mesh-tal") mesh_tally.filters = [mesh_filter] mesh_tally.scores = ["flux"] tallies = openmc.Tallies([mesh_tally]) - model = openmc.Model(geom, mats, sett, tallies) + model.tallies = tallies + sp_filename = model.run() - statepoint = openmc.StatePoint(sp_filename) - tally_result = statepoint.get_tally(name="mesh-tal") + with openmc.StatePoint(sp_filename) as statepoint: + tally_result = statepoint.get_tally(name="mesh-tal") plot = plot_mesh_tally( tally=tally_result, basis="xy", slice_index=29 # max value of slice selected @@ -64,7 +75,7 @@ def test_plot_mesh_tally(): tally=tally_result, basis="yz", axis_units="m", - slice_index=9, # max value of slice selected + # slice_index=9, # max value of slice selected value="std_dev", ) plot.figure.savefig("x.png") @@ -81,7 +92,7 @@ def test_plot_mesh_tally(): score="flux", value="mean", outline=True, - geometry=geom, + geometry=geometry, outline_by="material", colorbar_kwargs={"label": "neutron flux"}, norm=LogNorm(vmin=1e-6, vmax=max(tally_result.mean.flatten())), @@ -91,3 +102,46 @@ def test_plot_mesh_tally(): assert plot.get_xlim() == (-1000.0, 500) # note that units are in mm assert plot.get_ylim() == (-3000.0, 3500.0) plot.figure.savefig("z.png") + + +def test_plot_2d_mesh_tally(model): + geometry = model.geometry + + mesh = openmc.RegularMesh().from_domain(geometry, dimension=[1, 20, 30]) + mesh_filter = openmc.MeshFilter(mesh) + mesh_tally = openmc.Tally(name="mesh-tal") + mesh_tally.filters = [mesh_filter] + mesh_tally.scores = ["flux"] + tallies = openmc.Tallies([mesh_tally]) + + model.tallies = tallies + + sp_filename = model.run() + with openmc.StatePoint(sp_filename) as statepoint: + tally_result = statepoint.get_tally(name="mesh-tal") + + plot = plot_mesh_tally( + tally=tally_result, basis="yz", slice_index=29 # max value of slice selected + ) + # axis_units defaults to cm + assert plot.xaxis.get_label().get_text() == "y [cm]" + assert plot.yaxis.get_label().get_text() == "z [cm]" + assert plot.get_xlim() == (-200.0, 250.0) + assert plot.get_ylim() == (-300.0, 350.0) + plot.figure.savefig("t.png") + + plot = plot_mesh_tally( + tally=tally_result, + basis="yz", + axis_units="m", + # slice_index=9, # max value of slice selected + value="std_dev", + ) + plot.figure.savefig("x.png") + assert plot.xaxis.get_label().get_text() == "y [m]" + assert plot.yaxis.get_label().get_text() == "z [m]" + assert plot.get_xlim() == (-2.0, 2.5) # note that units are in m + assert plot.get_ylim() == (-3.0, 3.5) + + +# todo catch errors when 2d mesh used and 1d axis selected for plotting'