diff --git a/examples/plot_two_tallies_combined.py b/examples/plot_two_tallies_combined.py index e8383d9..10d6c87 100644 --- a/examples/plot_two_tallies_combined.py +++ b/examples/plot_two_tallies_combined.py @@ -38,18 +38,18 @@ # work with older versions of openmc source_n = openmc.Source() source_p = openmc.Source() - + source_n.space = openmc.stats.Point((200, 0, 0)) source_n.angle = openmc.stats.Isotropic() source_n.energy = openmc.stats.Discrete([0.1e6], [1]) source_n.strength = 1 -source_n.particle='neutron' +source_n.particle = "neutron" source_p.space = openmc.stats.Point((-200, 0, 0)) source_p.angle = openmc.stats.Isotropic() source_p.energy = openmc.stats.Discrete([10e6], [1]) source_p.strength = 10 -source_p.particle='photon' +source_p.particle = "photon" my_settings.source = [source_n, source_p] @@ -101,7 +101,7 @@ ) plot.title.set_text("neutron heating") plot.figure.savefig("neutron_regular_mesh_plotter.png") -print('written file neutron_regular_mesh_plotter.png') +print("written file neutron_regular_mesh_plotter.png") plot = plot_mesh_tally( tally=[my_mesh_tally_photon], @@ -110,7 +110,7 @@ ) plot.title.set_text("photon heating") plot.figure.savefig("photon_regular_mesh_plotter.png") -print('written file photon_regular_mesh_plotter.png') +print("written file photon_regular_mesh_plotter.png") plot = plot_mesh_tally( tally=[my_mesh_tally_photon, my_mesh_tally_neutron], @@ -119,4 +119,4 @@ ) plot.title.set_text("photon and neutron heating") plot.figure.savefig("photon_and_neutron_regular_mesh_plotter.png") -print('written file photon_and_neutron_regular_mesh_plotter.png') +print("written file photon_and_neutron_regular_mesh_plotter.png") diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index 5f8e69f..19e56bd 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -108,12 +108,16 @@ def plot_mesh_tally( # TODO check the tallies use the same mesh mesh_ids.append(mesh.id) if not all(i == mesh_ids[0] for i in mesh_ids): - raise ValueError(f'mesh ids {mesh_ids} are different, please use same mesh when combining tallies') + raise ValueError( + f"mesh ids {mesh_ids} are different, please use same mesh when combining tallies" + ) else: mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh if isinstance(mesh, openmc.CylindricalMesh): - raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}, try the openmc_cylindrical_mesh_plotter package available at https://github.com/fusion-energy/openmc_cylindrical_mesh_plotter/") + raise NotImplemented( + f"Only RegularMesh are supported, not {type(mesh)}, try the openmc_cylindrical_mesh_plotter package available at https://github.com/fusion-energy/openmc_cylindrical_mesh_plotter/" + ) if not isinstance(mesh, openmc.RegularMesh): raise NotImplemented(f"Only RegularMesh are supported, not {type(mesh)}") @@ -145,7 +149,7 @@ def plot_mesh_tally( default_imshow_kwargs.update(kwargs) if isinstance(tally, typing.Sequence): - data = np.zeros(shape=(40,40)) + data = np.zeros(shape=(40, 40)) for one_tally in tally: new_data = _get_tally_data( scaling_factor, @@ -155,9 +159,9 @@ def plot_mesh_tally( value, volume_normalization, score, - slice_index + slice_index, ) - data=data+new_data + data = data + new_data else: # single tally data = _get_tally_data( scaling_factor, @@ -167,7 +171,7 @@ def plot_mesh_tally( value, volume_normalization, score, - slice_index + slice_index, ) im = axes.imshow(data, extent=(x_min, x_max, y_min, y_max), **default_imshow_kwargs) @@ -293,17 +297,10 @@ def get_index_where(self, value: float, basis: str = "xy"): return slice_index + def _get_tally_data( - scaling_factor, - mesh, - basis, - tally, - value, - volume_normalization, - score, - slice_index + scaling_factor, mesh, basis, tally, value, volume_normalization, score, slice_index ): - # if score is not specified and tally has a single score then we know which score to use if score is None: if len(tally.scores) == 1: @@ -314,8 +311,6 @@ def _get_tally_data( tally_slice = tally.get_slice(scores=[score]) - - if 1 in mesh.dimension: index_of_2d = mesh.dimension.index(1) axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] @@ -338,7 +333,6 @@ def _get_tally_data( # if len(tally_data.shape) == 3: if mesh.n_dimension == 3: - if basis == "xz": slice_data = tally_data[:, slice_index, :] data = np.flip(np.rot90(slice_data, -1)) @@ -360,4 +354,4 @@ def _get_tally_data( if scaling_factor: data = data * scaling_factor - return data \ No newline at end of file + return data