Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sass/add mT to particle #264

Merged
merged 11 commits into from
Aug 6, 2024
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ The main categories for changes in this file are:

A `Deprecated` section could be added if needed for soon-to-be removed features.

## v.2.0.0
Date:

### Added
* Particle: Add member function to calculate the transverse mass of a particle
* Oscar: Add transverse mass cut to methods
* Jetscape: Add transverse mass cut to methods

### Changed
* Filter: Perform general clean up to reduce code duplications

## v1.3.0-Newton
Date: 2024-07-25

Expand Down
1 change: 1 addition & 0 deletions docs/source/classes/Particle/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Particle
.. automethod:: Particle.proper_time
.. automethod:: Particle.compute_mass_from_energy_momentum
.. automethod:: Particle.compute_charge_from_pdg
.. automethod:: Particle.mT
.. automethod:: Particle.is_meson
.. automethod:: Particle.is_baryon
.. automethod:: Particle.is_hadron
Expand Down
302 changes: 194 additions & 108 deletions src/sparkx/Filter.py

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions src/sparkx/Jetscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ def __apply_kwargs_filters(self, event, filters_dict):
event, filters_dict['lower_event_energy_cut'])
elif i == 'pt_cut':
event = pt_cut(event, filters_dict['pt_cut'])
elif i == 'mT_cut':
event = mT_cut(event, filters_dict['mT_cut'])
elif i == 'rapidity_cut':
event = rapidity_cut(event, filters_dict['rapidity_cut'])
elif i == 'pseudorapidity_cut':
Expand Down Expand Up @@ -739,6 +741,32 @@ def pt_cut(self, cut_value_tuple):

return self

def mT_cut(self, cut_value_tuple):
"""
Apply transverse mass cut to all events by passing an acceptance
range by ::code`cut_value_tuple`. All particles outside this range will
be removed.

Parameters
----------
cut_value_tuple : tuple
Tuple with the upper and lower limits of the mT acceptance
range :code:`(cut_min, cut_max)`. If one of the limits is not
required, set it to :code:`None`, i.e. :code:`(None, cut_max)`
or :code:`(cut_min, None)`.

Returns
-------
self : Oscar object
Containing only particles complying with the transverse mass
cut for all events
"""

self.particle_list_ = mT_cut(self.particle_list_, cut_value_tuple)
self.__update_num_output_per_event_after_filter()

return self

def rapidity_cut(self, cut_value):
"""
Apply rapidity cut to all events and remove all particles with rapidity
Expand Down
28 changes: 28 additions & 0 deletions src/sparkx/Oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ def __apply_kwargs_filters(self, event, filters_dict):
filters_dict['spacetime_cut'][1])
elif i == 'pt_cut':
event = pt_cut(event, filters_dict['pt_cut'])
elif i == 'mT_cut':
event = mT_cut(event, filters_dict['mT_cut'])
elif i == 'rapidity_cut':
event = rapidity_cut(event, filters_dict['rapidity_cut'])
elif i == 'pseudorapidity_cut':
Expand Down Expand Up @@ -921,6 +923,32 @@ def pt_cut(self, cut_value_tuple):

return self

def mT_cut(self, cut_value_tuple):
"""
Apply transverse mass cut to all events by passing an acceptance
range by ::code`cut_value_tuple`. All particles outside this range will
be removed.

Parameters
----------
cut_value_tuple : tuple
Tuple with the upper and lower limits of the mT acceptance
range :code:`(cut_min, cut_max)`. If one of the limits is not
required, set it to :code:`None`, i.e. :code:`(None, cut_max)`
or :code:`(cut_min, None)`.

Returns
-------
self : Oscar object
Containing only particles complying with the transverse mass
cut for all events
"""

self.particle_list_ = mT_cut(self.particle_list_, cut_value_tuple)
self.__update_num_output_per_event_after_filter()

return self

def rapidity_cut(self, cut_value):
"""
Apply rapidity cut to all events and remove all particles with rapidity
Expand Down
32 changes: 29 additions & 3 deletions src/sparkx/Particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ class Particle:
Compute mass from energy momentum relation
compute_charge_from_pdg:
Compute charge from PDG code
mT:
Compute transverse mass
is_meson:
Is the particle a meson?
is_baryon:
Expand Down Expand Up @@ -1051,11 +1053,12 @@ def compute_mass_from_energy_momentum(self):
self.pz):
return np.nan
else:
if np.abs(self.E**2. - self.p_abs()**2.) > 1e-16 and\
self.E**2. - self.p_abs()**2. > 0.:
if abs(self.E) >= abs(self.p_abs()):
return np.sqrt(self.E**2. - self.p_abs()**2.)
else:
return 0.
warnings.warn('|E| >= |p| not fulfilled or not within numerical precision! '
'The mass is set to nan.')
return np.nan

def compute_charge_from_pdg(self):
"""
Expand All @@ -1076,6 +1079,29 @@ def compute_charge_from_pdg(self):
return np.nan
return PDGID(self.pdg).charge

def mT(self):
"""
Compute the transverse mass :math:`m_{T}=\\sqrt{E^2-p_z^2}` of the particle.

Returns
-------
float
transverse mass

Notes
-----
If one of the needed particle quantities is not given, then `np.nan`
is returned.
"""
if np.isnan(self.E) or np.isnan(self.pz):
return np.nan
elif abs(self.E) >= abs(self.pz):
return np.sqrt(self.E**2. - self.pz**2.)
else:
warnings.warn('|E| >= |pz| not fulfilled or not within numerical precision! '
'The transverse mass is set to nan.')
return np.nan

def is_meson(self):
"""
Is the particle a meson?
Expand Down
136 changes: 91 additions & 45 deletions tests/test_Filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,45 +233,40 @@ def particle_list_positions():
def test_spacetime_cut(particle_list_positions):
test_cases = [
# Test cases for valid input
('t', (0.5, 1.5), [[particle_list_positions[0][0]]]),
('t', (1.5, 2.5), [[]]),
('x',
(-0.5,
0.5),
('t', (0.5, 1.5), None, None, [[particle_list_positions[0][0]]]),
('t', (1.5, 2.5), None, None, [[]]),
('x',(-0.5, 0.5), None, None,
[[particle_list_positions[0][0],
particle_list_positions[0][2],
particle_list_positions[0][3]]]),
('y', (0.5, None), [[particle_list_positions[0][2]]]),
('z',
(None,
0.5),
('y', (0.5, None), None, None, [[particle_list_positions[0][2]]]),
('z', (None, 0.5), None, None,
[[particle_list_positions[0][0],
particle_list_positions[0][1],
particle_list_positions[0][2]]]),

# Test cases for error conditions
('t', (None, None), ValueError),
('t', (1.5, 0.5), ValueError),
('t', (0.5,), TypeError),
('t', ('a', 1.5), ValueError),
('w', (0.5, 1.5), ValueError),
('x', (1.5, 0.5), ValueError),
('t', (None, None), None, ValueError, None),
('t', (1.5, 0.5), UserWarning, None, [[particle_list_positions[0][0]]]),
('t', (0.5,), None, TypeError, None),
('t', ('a', 1.5), None, ValueError, None),
('w', (0.5, 1.5), None, ValueError, None),
('x', (1.5, 0.5), UserWarning, None, [[particle_list_positions[0][1]]]),
]

for dim, cut_value_tuple, expected_result in test_cases:
if isinstance(
expected_result,
type) and issubclass(
expected_result,
Exception):
# If expected_result is an Exception, we expect an error to be
# raised
with pytest.raises(expected_result):
spacetime_cut(particle_list_positions, dim, cut_value_tuple)
for dim, cut_value_tuple, expected_warning, expected_error, expected_result in test_cases:
if expected_warning:
with pytest.warns(expected_warning):
result = spacetime_cut(particle_list_positions, dim, cut_value_tuple)
assert result == expected_result

elif expected_error:
with pytest.raises(expected_error):
result = spacetime_cut(particle_list_positions, dim, cut_value_tuple)

else:
# Apply the spacetime cut
result = spacetime_cut(
particle_list_positions, dim, cut_value_tuple)
result = spacetime_cut(particle_list_positions, dim, cut_value_tuple)
# Assert the result matches the expected outcome
assert result == expected_result

Expand All @@ -290,36 +285,87 @@ def particle_list_pt():
def test_pt_cut(particle_list_pt):
test_cases = [
# Test cases for valid input
((0.5, 1.5), [[particle_list_pt[0][1]]]),
((2.5, None), [[particle_list_pt[0][3], particle_list_pt[0][4]]]),
((None, 3.5), [[particle_list_pt[0][0], particle_list_pt[0]
((0.5, 1.5), None, None, [[particle_list_pt[0][1]]]),
((2.5, None), None, None, [[particle_list_pt[0][3], particle_list_pt[0][4]]]),
((None, 3.5), None, None, [[particle_list_pt[0][0], particle_list_pt[0]
[1], particle_list_pt[0][2], particle_list_pt[0][3]]]),

# Test cases for error conditions
((None, None), ValueError),
((-1, 3), ValueError),
(('a', 3), ValueError),
((3, 2), ValueError),
((None, None, None), TypeError),
((None, None), None, ValueError, None),
((-1, 3), None, ValueError, None),
(('a', 3), None, ValueError, None),
((3, 2), UserWarning, None, [[particle_list_pt[0][2], particle_list_pt[0][3]]]),
((None, None, None), None, TypeError, None),
]

for cut_value_tuple, expected_result in test_cases:
if isinstance(
expected_result,
type) and issubclass(
expected_result,
Exception):
# If expected_result is an Exception, we expect an error to be
# raised
with pytest.raises(expected_result):
pt_cut(particle_list_pt, cut_value_tuple)
for cut_value_tuple, expected_warning, expected_error, expected_result in test_cases:
if expected_warning:
with pytest.warns(expected_warning):
result = pt_cut(particle_list_pt, cut_value_tuple)
assert result == expected_result

elif expected_error:
with pytest.raises(expected_error):
result = pt_cut(particle_list_pt, cut_value_tuple)

else:
# Apply the pt_cut
result = pt_cut(particle_list_pt, cut_value_tuple)
# Assert the result matches the expected outcome
assert result == expected_result


@pytest.fixture
def particle_list_mT():
particle_list = []
mT_E_pz_pairs = [
(3, 5, 4),
(4, 5, 3),
(5, 13, 12),
(6, 10, 8),
(7, 25, 24),
]
for m_T, energy, p_z in mT_E_pz_pairs:
p = Particle()
p.E = energy
p.pz = p_z
particle_list.append(p)
return [particle_list]


def test_mT_cut(particle_list_mT):
test_cases = [
# Test cases for valid input
((2.5, 3.5), None, None, [[particle_list_mT[0][0]]]),
((5.5, None), None, None, [[particle_list_mT[0][3], particle_list_mT[0][4]]]),
((None, 6.5), None, None, [[particle_list_mT[0][0], particle_list_mT[0][1],
particle_list_mT[0][2], particle_list_mT[0][3]]]),

# Test cases for error conditions
((None, None), None, ValueError, None),
((-1, 6), None, ValueError, None),
(('a', 5), None, ValueError, None),
((5.7, 3.3), UserWarning, None, [[particle_list_mT[0][1], particle_list_mT[0][2]]]),
((None, None, None), None, TypeError, None),
]

for cut_value_tuple, expected_warning, expected_error, expected_result in test_cases:
if expected_warning:
with pytest.warns(expected_warning):
result = mT_cut(particle_list_mT, cut_value_tuple)
assert result == expected_result

elif expected_error:
with pytest.raises(expected_error):
result = mT_cut(particle_list_mT, cut_value_tuple)

else:
# Apply the mT_cut
result = mT_cut(particle_list_mT, cut_value_tuple)
# Assert the result matches the expected outcome
assert result == expected_result


@pytest.fixture
def particle_list_momentum_rapidity():
particle_list = []
Expand Down
47 changes: 47 additions & 0 deletions tests/test_Particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,16 @@ def test_compute_mass_from_energy_momentum_valid_values():

assert np.isclose(result, expected_result)

def test_compute_mass_from_energy_momentum_invalid_values():
p = Particle()
p.E = 3.0
p.px = 1.0
p.py = 2.0
p.pz = 5.0

with pytest.warns(UserWarning, match=r"|E| >= |p| not fulfilled or not within numerical precision! The mass is set to nan."):
assert np.isnan(p.compute_mass_from_energy_momentum())


def test_compute_mass_from_energy_momentum_missing_values():
p = Particle()
Expand All @@ -733,6 +743,43 @@ def test_compute_mass_from_energy_momentum_zero_energy():

assert np.isclose(result, 0.0)

def test_mT_missing_values():
p = Particle()
# Leave some values as np.nan

result = p.mT()

assert np.isnan(result)

def test_mT_invalid_values():
p = Particle()
p.E = 11.2
p.pz = 11.3

with pytest.warns(UserWarning, match=r"|E| >= |pz| not fulfilled or not within numerical precision! The transverse mass is set to nan."):
assert np.isnan(p.mT())


def test_mT_valid_values():
p = Particle()

# Test for zero
p.E = 0.0
p.pz = 0.0
assert np.isclose(p.mT(), 0.0)

# Test for a few random values
for i in range(10):
energy = np.random.uniform(5.0, 10.0)
p_z = np.random.uniform(0.1, 5.0)

p.E = energy
p.pz = p_z

result = p.mT()
expected_result = np.sqrt(energy**2 - p_z**2)

assert np.isclose(result, expected_result)

def test_compute_charge_from_pdg_valid_values():
p = Particle()
Expand Down
Loading