diff --git a/examples/plot_two_tallies_combined.py b/examples/plot_two_tallies_combined.py index e8383d9..10d6c87 100644 --- a/examples/plot_two_tallies_combined.py +++ b/examples/plot_two_tallies_combined.py @@ -38,18 +38,18 @@ # work with older versions of openmc source_n = openmc.Source() source_p = openmc.Source() - + source_n.space = openmc.stats.Point((200, 0, 0)) source_n.angle = openmc.stats.Isotropic() source_n.energy = openmc.stats.Discrete([0.1e6], [1]) source_n.strength = 1 -source_n.particle='neutron' +source_n.particle = "neutron" source_p.space = openmc.stats.Point((-200, 0, 0)) source_p.angle = openmc.stats.Isotropic() source_p.energy = openmc.stats.Discrete([10e6], [1]) source_p.strength = 10 -source_p.particle='photon' +source_p.particle = "photon" my_settings.source = [source_n, source_p] @@ -101,7 +101,7 @@ ) plot.title.set_text("neutron heating") plot.figure.savefig("neutron_regular_mesh_plotter.png") -print('written file neutron_regular_mesh_plotter.png') +print("written file neutron_regular_mesh_plotter.png") plot = plot_mesh_tally( tally=[my_mesh_tally_photon], @@ -110,7 +110,7 @@ ) plot.title.set_text("photon heating") plot.figure.savefig("photon_regular_mesh_plotter.png") -print('written file photon_regular_mesh_plotter.png') +print("written file photon_regular_mesh_plotter.png") plot = plot_mesh_tally( tally=[my_mesh_tally_photon, my_mesh_tally_neutron], @@ -119,4 +119,4 @@ ) plot.title.set_text("photon and neutron heating") plot.figure.savefig("photon_and_neutron_regular_mesh_plotter.png") -print('written file photon_and_neutron_regular_mesh_plotter.png') +print("written file photon_and_neutron_regular_mesh_plotter.png") diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index 7247e91..6f219f8 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -108,12 +108,16 @@ def plot_mesh_tally( # TODO check the tallies use the same mesh mesh_ids.append(mesh.id) if not all(i == mesh_ids[0] for i in mesh_ids): - raise ValueError(f'mesh ids {mesh_ids} are different, please use same mesh when combining tallies') + raise ValueError( + f"mesh ids {mesh_ids} are different, please use same mesh when combining tallies" + ) else: mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh if isinstance(mesh, openmc.CylindricalMesh): - raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}, try the openmc_cylindrical_mesh_plotter package available at https://github.com/fusion-energy/openmc_cylindrical_mesh_plotter/") + raise NotImplemented( + f"Only RegularMesh are supported, not {type(mesh)}, try the openmc_cylindrical_mesh_plotter package available at https://github.com/fusion-energy/openmc_cylindrical_mesh_plotter/" + ) if not isinstance(mesh, openmc.RegularMesh): raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}") @@ -145,7 +149,6 @@ def plot_mesh_tally( default_imshow_kwargs.update(kwargs) if isinstance(tally, typing.Sequence): - for counter, one_tally in enumerate(tally): new_data = _get_tally_data( scaling_factor, @@ -155,11 +158,11 @@ def plot_mesh_tally( value, volume_normalization, score, - slice_index + slice_index, ) if counter == 0: data = np.zeros(shape=new_data.shape) - data=data+new_data + data = data + new_data else: # single tally data = _get_tally_data( scaling_factor, @@ -169,7 +172,7 @@ def plot_mesh_tally( value, volume_normalization, score, - slice_index + slice_index, ) im = axes.imshow(data, extent=(x_min, x_max, y_min, y_max), **default_imshow_kwargs) @@ -295,17 +298,10 @@ def get_index_where(self, value: float, basis: str = "xy"): return slice_index + def _get_tally_data( - scaling_factor, - mesh, - basis, - tally, - value, - volume_normalization, - score, - slice_index + scaling_factor, mesh, basis, tally, value, volume_normalization, score, slice_index ): - # if score is not specified and tally has a single score then we know which score to use if score is None: if len(tally.scores) == 1: @@ -316,8 +312,6 @@ def _get_tally_data( tally_slice = tally.get_slice(scores=[score]) - - if 1 in mesh.dimension: index_of_2d = mesh.dimension.index(1) axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] @@ -340,7 +334,6 @@ def _get_tally_data( # if len(tally_data.shape) == 3: if mesh.n_dimension == 3: - if basis == "xz": slice_data = tally_data[:, slice_index, :] data = np.flip(np.rot90(slice_data, -1)) @@ -362,4 +355,4 @@ def _get_tally_data( if scaling_factor: data = data * scaling_factor - return data \ No newline at end of file + return data diff --git a/tests/test_units.py b/tests/test_units.py index e4fceb1..b441cb8 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -170,7 +170,9 @@ def test_plot_two_mesh_tallies(model): tally_result_2 = statepoint.get_tally(name="mesh-tal-2") plot = plot_mesh_tally( - tally=[tally_result_1, tally_result_2], basis="yz", slice_index=0 # max value of slice selected + tally=[tally_result_1, tally_result_2], + basis="yz", + slice_index=0, # max value of slice selected ) plot = plot_mesh_tally(tally=[tally_result_1, tally_result_2], basis="yz") @@ -194,4 +196,5 @@ def test_plot_two_mesh_tallies(model): assert plot.get_xlim() == (-2.0, 2.5) # note that units are in m assert plot.get_ylim() == (-3.0, 3.5) + # todo catch errors when 2d mesh used and 1d axis selected for plotting'