Skip to content

Commit

Permalink
format conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
shimwell committed Nov 27, 2023
2 parents eccc915 + 10a9ab7 commit ce984dc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 26 deletions.
12 changes: 6 additions & 6 deletions examples/plot_two_tallies_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@
# work with older versions of openmc
source_n = openmc.Source()
source_p = openmc.Source()

source_n.space = openmc.stats.Point((200, 0, 0))
source_n.angle = openmc.stats.Isotropic()
source_n.energy = openmc.stats.Discrete([0.1e6], [1])
source_n.strength = 1
source_n.particle='neutron'
source_n.particle = "neutron"

source_p.space = openmc.stats.Point((-200, 0, 0))
source_p.angle = openmc.stats.Isotropic()
source_p.energy = openmc.stats.Discrete([10e6], [1])
source_p.strength = 10
source_p.particle='photon'
source_p.particle = "photon"

my_settings.source = [source_n, source_p]

Expand Down Expand Up @@ -101,7 +101,7 @@
)
plot.title.set_text("neutron heating")
plot.figure.savefig("neutron_regular_mesh_plotter.png")
print('written file neutron_regular_mesh_plotter.png')
print("written file neutron_regular_mesh_plotter.png")

plot = plot_mesh_tally(
tally=[my_mesh_tally_photon],
Expand All @@ -110,7 +110,7 @@
)
plot.title.set_text("photon heating")
plot.figure.savefig("photon_regular_mesh_plotter.png")
print('written file photon_regular_mesh_plotter.png')
print("written file photon_regular_mesh_plotter.png")

plot = plot_mesh_tally(
tally=[my_mesh_tally_photon, my_mesh_tally_neutron],
Expand All @@ -119,4 +119,4 @@
)
plot.title.set_text("photon and neutron heating")
plot.figure.savefig("photon_and_neutron_regular_mesh_plotter.png")
print('written file photon_and_neutron_regular_mesh_plotter.png')
print("written file photon_and_neutron_regular_mesh_plotter.png")
31 changes: 12 additions & 19 deletions src/openmc_regular_mesh_plotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,16 @@ def plot_mesh_tally(
# TODO check the tallies use the same mesh
mesh_ids.append(mesh.id)
if not all(i == mesh_ids[0] for i in mesh_ids):
raise ValueError(f'mesh ids {mesh_ids} are different, please use same mesh when combining tallies')
raise ValueError(
f"mesh ids {mesh_ids} are different, please use same mesh when combining tallies"
)
else:
mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh

if isinstance(mesh, openmc.CylindricalMesh):
raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}, try the openmc_cylindrical_mesh_plotter package available at https://github.com/fusion-energy/openmc_cylindrical_mesh_plotter/")
raise NotImplemented(
f"Only RegularMesh are supported, not {type(mesh)}, try the openmc_cylindrical_mesh_plotter package available at https://github.com/fusion-energy/openmc_cylindrical_mesh_plotter/"
)
if not isinstance(mesh, openmc.RegularMesh):
raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}")

Expand Down Expand Up @@ -145,7 +149,6 @@ def plot_mesh_tally(
default_imshow_kwargs.update(kwargs)

if isinstance(tally, typing.Sequence):

for counter, one_tally in enumerate(tally):
new_data = _get_tally_data(
scaling_factor,
Expand All @@ -155,11 +158,11 @@ def plot_mesh_tally(
value,
volume_normalization,
score,
slice_index
slice_index,
)
if counter == 0:
data = np.zeros(shape=new_data.shape)
data=data+new_data
data = data + new_data
else: # single tally
data = _get_tally_data(
scaling_factor,
Expand All @@ -169,7 +172,7 @@ def plot_mesh_tally(
value,
volume_normalization,
score,
slice_index
slice_index,
)

im = axes.imshow(data, extent=(x_min, x_max, y_min, y_max), **default_imshow_kwargs)
Expand Down Expand Up @@ -295,17 +298,10 @@ def get_index_where(self, value: float, basis: str = "xy"):

return slice_index


def _get_tally_data(
scaling_factor,
mesh,
basis,
tally,
value,
volume_normalization,
score,
slice_index
scaling_factor, mesh, basis, tally, value, volume_normalization, score, slice_index
):

# if score is not specified and tally has a single score then we know which score to use
if score is None:
if len(tally.scores) == 1:
Expand All @@ -316,8 +312,6 @@ def _get_tally_data(

tally_slice = tally.get_slice(scores=[score])



if 1 in mesh.dimension:
index_of_2d = mesh.dimension.index(1)
axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d]
Expand All @@ -340,7 +334,6 @@ def _get_tally_data(

# if len(tally_data.shape) == 3:
if mesh.n_dimension == 3:

if basis == "xz":
slice_data = tally_data[:, slice_index, :]
data = np.flip(np.rot90(slice_data, -1))
Expand All @@ -362,4 +355,4 @@ def _get_tally_data(

if scaling_factor:
data = data * scaling_factor
return data
return data
5 changes: 4 additions & 1 deletion tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def test_plot_two_mesh_tallies(model):
tally_result_2 = statepoint.get_tally(name="mesh-tal-2")

plot = plot_mesh_tally(
tally=[tally_result_1, tally_result_2], basis="yz", slice_index=0 # max value of slice selected
tally=[tally_result_1, tally_result_2],
basis="yz",
slice_index=0, # max value of slice selected
)

plot = plot_mesh_tally(tally=[tally_result_1, tally_result_2], basis="yz")
Expand All @@ -194,4 +196,5 @@ def test_plot_two_mesh_tallies(model):
assert plot.get_xlim() == (-2.0, 2.5) # note that units are in m
assert plot.get_ylim() == (-3.0, 3.5)


# todo catch errors when 2d mesh used and 1d axis selected for plotting'

0 comments on commit ce984dc

Please sign in to comment.