diff --git a/CHANGELOG.md b/CHANGELOG.md index 93345f90..e9f870ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,17 @@ The main categories for changes in this file are: A `Deprecated` section could be added if needed for soon-to-be removed features. +## v.2.0.0 +Date: + +### Added +* Particle: Add member function to calculate the transverse mass of a particle +* Oscar: Add transverse mass cut to methods +* Jetscape: Add transverse mass cut to methods + +### Changed +* Filter: Perform general clean up to reduce code duplications + ## v1.3.0-Newton Date: 2024-07-25 diff --git a/docs/source/classes/Particle/index.rst b/docs/source/classes/Particle/index.rst index d899ce0f..ca72ae79 100644 --- a/docs/source/classes/Particle/index.rst +++ b/docs/source/classes/Particle/index.rst @@ -18,6 +18,7 @@ Particle .. automethod:: Particle.proper_time .. automethod:: Particle.compute_mass_from_energy_momentum .. automethod:: Particle.compute_charge_from_pdg +.. automethod:: Particle.mT .. automethod:: Particle.is_meson .. automethod:: Particle.is_baryon .. automethod:: Particle.is_hadron diff --git a/src/sparkx/Filter.py b/src/sparkx/Filter.py index 9eb8cead..8080c0a4 100644 --- a/src/sparkx/Filter.py +++ b/src/sparkx/Filter.py @@ -12,6 +12,53 @@ import warnings +def __ensure_tuple_is_valid_else_raise_error(value_tuple, allow_none=False): + """ + Validates a tuple for specific conditions. + + This function checks if the input is a tuple of length two, where both + elements are either numbers or None. If allow_none is set to False, the + function raises an error if any of the elements is None. If allow_none is + set to True, the function raises an error if both elements are None. + + Parameters + ---------- + value_tuple : tuple + The tuple to be validated. Expected to be of length two. + allow_none : bool, optional + Determines whether None values are allowed in the tuple. Default is False. + + Raises + ------ + TypeError + If the input value is not a tuple or if it's not of length two. + ValueError + If non-numeric value is found in the tuple, or if None values are not + allowed and found, or if both elements are None when allow_none is True. + """ + if not isinstance(value_tuple, tuple) or len(value_tuple) != 2: + raise TypeError('Input value must be a tuple of length two') + + elif any(val is not None and not isinstance(val, (int, float)) for val in value_tuple): + raise ValueError('Non-numeric value found in given tuple') + + elif (value_tuple[0] is not None and value_tuple[1] is not None) and \ + (value_tuple[0] >= value_tuple[1]): + warn_msg = ( + 'Lower limit {} is greater than upper limit {}. ' + 'Switched order is assumed in the following.' + ).format(value_tuple[0], value_tuple[1]) + warnings.warn(warn_msg) + + elif not allow_none: + if (value_tuple[0] is None or value_tuple[1] is None): + raise ValueError('At least one value in the tuple is None') + + elif allow_none: + if value_tuple[0] is None and value_tuple[1] is None: + raise ValueError('At least one cut limit must be set to a number') + + def charged_particles(particle_list): """ Keep only charged particles in particle_list. @@ -323,49 +370,46 @@ def spacetime_cut(particle_list, dim, cut_value_tuple): list of lists Filtered list of lists containing particle objects for each event """ - if not isinstance(cut_value_tuple, tuple) or len(cut_value_tuple) != 2: - raise TypeError('Input value must be a tuple of length two') - elif any(val is not None and not isinstance(val, (int, float)) for val in cut_value_tuple): - raise ValueError('Non-numeric value found in cut_value_tuple') - elif cut_value_tuple[0] is None and cut_value_tuple[1] is None: - raise ValueError('At least one cut limit must be a number') - elif dim == "t" and cut_value_tuple[0] < 0: - raise ValueError('Time boundary must be positive or zero.') + if not isinstance(cut_value_tuple, tuple): + raise TypeError('Input value must be a tuple containing either ' + + 'positive numbers or None of length two') + + __ensure_tuple_is_valid_else_raise_error(cut_value_tuple, allow_none=True) + if dim not in ("x", "y", "z", "t"): raise ValueError('Only "t, x, y and z are possible dimensions.') if cut_value_tuple[0] is None: - if (dim != "t"): - lower_cut = float('-inf') - else: - lower_cut = 0.0 + lower_cut = float('-inf') else: lower_cut = cut_value_tuple[0] + if cut_value_tuple[1] is None: upper_cut = float('inf') else: upper_cut = cut_value_tuple[1] - if upper_cut < lower_cut: - raise ValueError('The upper cut is smaller than the lower cut!') + # Ensure cut values are in the correct order + lim_max = max(upper_cut, lower_cut) + lim_min = min(upper_cut, lower_cut) updated_particle_list = [] for i in range(0, len(particle_list)): if (dim == "t"): particle_list_tmp = [elem for elem in particle_list[i] if ( - lower_cut <= elem.t <= upper_cut and not np.isnan(elem.t))] + lim_min <= elem.t <= lim_max and not np.isnan(elem.t))] elif (dim == "x"): particle_list_tmp = [elem for elem in particle_list[i] if ( - lower_cut <= elem.x <= upper_cut and not np.isnan(elem.x))] + lim_min <= elem.x <= lim_max and not np.isnan(elem.x))] elif (dim == "y"): particle_list_tmp = [elem for elem in particle_list[i] if ( - lower_cut <= elem.y <= upper_cut and not np.isnan(elem.y))] + lim_min <= elem.y <= lim_max and not np.isnan(elem.y))] else: particle_list_tmp = [elem for elem in particle_list[i] if ( - lower_cut <= elem.z <= upper_cut and not np.isnan(elem.z))] + lim_min <= elem.z <= lim_max and not np.isnan(elem.z))] updated_particle_list.append(particle_list_tmp) - particle_list = updated_particle_list - return particle_list + + return updated_particle_list def pt_cut(particle_list, cut_value_tuple): @@ -390,37 +434,106 @@ def pt_cut(particle_list, cut_value_tuple): list of lists Filtered list of lists containing particle objects for each event """ - if not isinstance(cut_value_tuple, tuple) or len(cut_value_tuple) != 2: + if not isinstance(cut_value_tuple, tuple): raise TypeError('Input value must be a tuple containing either ' + 'positive numbers or None of length two') - elif any(val is not None and not isinstance(val, (int, float)) for val in cut_value_tuple): - raise ValueError('Non-numeric value found in cut_value_tuple') - elif (cut_value_tuple[0] is not None and cut_value_tuple[0] < 0) or \ - (cut_value_tuple[1] is not None and cut_value_tuple[1] < 0): + + __ensure_tuple_is_valid_else_raise_error(cut_value_tuple, allow_none=True) + + # Check if the cut limits are positive if they are not None + if (cut_value_tuple[0] is not None and cut_value_tuple[0] < 0) or \ + (cut_value_tuple[1] is not None and cut_value_tuple[1] < 0): raise ValueError('The cut limits must be positive or None') - elif cut_value_tuple[0] is None and cut_value_tuple[1] is None: - raise ValueError('At least one cut limit must be a number') + # Assign numerical values to the cut limits. Even though we check for + # non negative values, we send a left None value to -inf for numerical + # reasons. This is still consistent with the logic of the cut, as the + # lower cut applies to the absolute value of the pT, which is limited to + # positive values. if cut_value_tuple[0] is None: - lower_cut = 0.0 + lower_cut = float('-inf') else: lower_cut = cut_value_tuple[0] + if cut_value_tuple[1] is None: upper_cut = float('inf') else: upper_cut = cut_value_tuple[1] - if upper_cut < lower_cut: - raise ValueError('The upper cut is smaller than the lower cut!') + # Ensure cut values are in the correct order + lim_max = max(upper_cut, lower_cut) + lim_min = min(upper_cut, lower_cut) updated_particle_list = [] for i in range(0, len(particle_list)): particle_list_tmp = [elem for elem in particle_list[i] if - (lower_cut <= elem.pt_abs() <= upper_cut + (lim_min <= elem.pt_abs() <= lim_max and not np.isnan(elem.pt_abs()))] updated_particle_list.append(particle_list_tmp) - particle_list = updated_particle_list - return particle_list + + return updated_particle_list + + +def mT_cut(particle_list, cut_value_tuple): + """ + Apply transverse mass cut to all events by passing an acceptance range by + ::code`cut_value_tuple`. All particles outside this range will + be removed. + + Parameters + ---------- + particle_list: + List with lists containing particle objects for the events + + cut_value_tuple : tuple + Tuple with the upper and lower limits of the mT acceptance + range :code:`(cut_min, cut_max)`. If one of the limits is not + required, set it to :code:`None`, i.e. :code:`(None, cut_max)` + or :code:`(cut_min, None)`. + + Returns + ------- + list of lists + Filtered list of lists containing particle objects for each event + """ + if not isinstance(cut_value_tuple, tuple): + raise TypeError('Input value must be a tuple containing either ' + + 'positive numbers or None of length two') + + __ensure_tuple_is_valid_else_raise_error(cut_value_tuple, allow_none=True) + + # Check if the cut limits are positive if they are not None + if (cut_value_tuple[0] is not None and cut_value_tuple[0] < 0) or \ + (cut_value_tuple[1] is not None and cut_value_tuple[1] < 0): + raise ValueError('The cut limits must be positive or None') + + # Assign numerical values to the cut limits. Even though we check for + # non negative values, we send a left None value to -inf for numerical + # reasons. This is still consistent with the logic of the cut, as the + # lower cut applies to the absolute value of the pT, which is limited to + # positive values. + if cut_value_tuple[0] is None: + lower_cut = float('-inf') + else: + lower_cut = cut_value_tuple[0] + + if cut_value_tuple[1] is None: + upper_cut = float('inf') + else: + upper_cut = cut_value_tuple[1] + + # Ensure cut values are in the correct order + lim_max = max(upper_cut, lower_cut) + lim_min = min(upper_cut, lower_cut) + + updated_particle_list = [] + for i in range(0, len(particle_list)): + particle_list_tmp = [elem for elem in particle_list[i] if + (lim_min <= elem.mT() <= lim_max + and not np.isnan(elem.mT()))] + updated_particle_list.append(particle_list_tmp) + + return updated_particle_list def rapidity_cut(particle_list, cut_value): @@ -449,32 +562,9 @@ def rapidity_cut(particle_list, cut_value): Filtered list of lists containing particle objects for each event """ if isinstance(cut_value, tuple): - if len(cut_value) != 2: - raise TypeError('If input value is a tuple, then it must contain ' + - 'two numbers') - elif any(not isinstance(val, (int, float)) for val in cut_value): - raise ValueError('Non-numeric value found in cut_value') - - if cut_value[0] > cut_value[1]: - warn_msg = 'Lower limit {} is greater that upper limit {}. Switched order is assumed in the following.'.format( - cut_value[0], cut_value[1]) - warnings.warn(warn_msg) - - elif not isinstance(cut_value, (int, float)): - raise TypeError('Input value must be a number or a tuple ' + - 'with the cut limits (cut_min, cut_max)') + __ensure_tuple_is_valid_else_raise_error(cut_value, allow_none=False) - if isinstance(cut_value, (int, float)): - # cut symmetrically around 0 - limit = np.abs(cut_value) - - updated_particle_list = [] - for i in range(0, len(particle_list)): - particle_list_tmp = [elem for elem in particle_list[i] if - (-limit <= elem.momentum_rapidity_Y() <= limit - and not np.isnan(elem.momentum_rapidity_Y()))] - updated_particle_list.append(particle_list_tmp) - elif isinstance(cut_value, tuple): + # Ensure cut values are in the correct order lim_max = max(cut_value[0], cut_value[1]) lim_min = min(cut_value[0], cut_value[1]) @@ -486,8 +576,22 @@ def rapidity_cut(particle_list, cut_value): elem.momentum_rapidity_Y()))] updated_particle_list.append(particle_list_tmp) - particle_list = updated_particle_list - return particle_list + elif isinstance(cut_value, (int, float)): + # cut symmetrically around 0 + limit = np.abs(cut_value) + + updated_particle_list = [] + for i in range(0, len(particle_list)): + particle_list_tmp = [elem for elem in particle_list[i] if + (-limit <= elem.momentum_rapidity_Y() <= limit + and not np.isnan(elem.momentum_rapidity_Y()))] + updated_particle_list.append(particle_list_tmp) + + else: + raise TypeError('Input value must be a number or a tuple ' + + 'with the cut limits (cut_min, cut_max)') + + return updated_particle_list def pseudorapidity_cut(particle_list, cut_value): @@ -516,44 +620,35 @@ def pseudorapidity_cut(particle_list, cut_value): Filtered list of lists containing particle objects for each event """ if isinstance(cut_value, tuple): - if len(cut_value) != 2: - raise TypeError('If input value is a tuple, then it must contain ' + - 'two numbers') - elif any(not isinstance(val, (int, float)) for val in cut_value): - raise ValueError('Non-numeric value found in cut_value') - - if cut_value[0] > cut_value[1]: - warn_msg = 'Lower limit {} is greater that upper limit {}. Switched order is assumed in the following.'.format( - cut_value[0], cut_value[1]) - warnings.warn(warn_msg) - - elif not isinstance(cut_value, (int, float)): - raise TypeError('Input value must be a number or a tuple ' + - 'with the cut limits (cut_min, cut_max)') + __ensure_tuple_is_valid_else_raise_error(cut_value, allow_none=False) - if isinstance(cut_value, (int, float)): - # cut symmetrically around 0 - limit = np.abs(cut_value) + # Ensure cut values are in the correct order + lim_max = max(cut_value[0], cut_value[1]) + lim_min = min(cut_value[0], cut_value[1]) updated_particle_list = [] for i in range(0, len(particle_list)): particle_list_tmp = [elem for elem in particle_list[i] if - (-limit <= elem.pseudorapidity() <= limit + (lim_min <= elem.pseudorapidity() <= lim_max and not np.isnan(elem.pseudorapidity()))] updated_particle_list.append(particle_list_tmp) - elif isinstance(cut_value, tuple): - lim_max = max(cut_value[0], cut_value[1]) - lim_min = min(cut_value[0], cut_value[1]) + + elif isinstance(cut_value, (int, float)): + # cut symmetrically around 0 + limit = np.abs(cut_value) updated_particle_list = [] for i in range(0, len(particle_list)): particle_list_tmp = [elem for elem in particle_list[i] if - (lim_min <= elem.pseudorapidity() <= lim_max + (-limit <= elem.pseudorapidity() <= limit and not np.isnan(elem.pseudorapidity()))] updated_particle_list.append(particle_list_tmp) - particle_list = updated_particle_list - return particle_list + else: + raise TypeError('Input value must be a number or a tuple ' + + 'with the cut limits (cut_min, cut_max)') + + return updated_particle_list def spatial_rapidity_cut(particle_list, cut_value): @@ -582,44 +677,35 @@ def spatial_rapidity_cut(particle_list, cut_value): Filtered list of lists containing particle objects for each event """ if isinstance(cut_value, tuple): - if len(cut_value) != 2: - raise TypeError('If input value is a tuple, then it must contain ' + - 'two numbers') - elif any(not isinstance(val, (int, float)) for val in cut_value): - raise ValueError('Non-numeric value found in cut_value') - - if cut_value[0] > cut_value[1]: - warn_msg = 'Lower limit {} is greater that upper limit {}. Switched order is assumed in the following.'.format( - cut_value[0], cut_value[1]) - warnings.warn(warn_msg) - - elif not isinstance(cut_value, (int, float)): - raise TypeError('Input value must be a number or a tuple ' + - 'with the cut limits (cut_min, cut_max)') + __ensure_tuple_is_valid_else_raise_error(cut_value, allow_none=False) - if isinstance(cut_value, (int, float)): - # cut symmetrically around 0 - limit = np.abs(cut_value) + # Ensure cut values are in the correct order + lim_max = max(cut_value[0], cut_value[1]) + lim_min = min(cut_value[0], cut_value[1]) updated_particle_list = [] for i in range(0, len(particle_list)): particle_list_tmp = [elem for elem in particle_list[i] if - (-limit <= elem.spatial_rapidity() <= limit + (lim_min <= elem.spatial_rapidity() <= lim_max and not np.isnan(elem.spatial_rapidity()))] updated_particle_list.append(particle_list_tmp) - elif isinstance(cut_value, tuple): - lim_max = max(cut_value[0], cut_value[1]) - lim_min = min(cut_value[0], cut_value[1]) + + elif isinstance(cut_value, (int, float)): + # cut symmetrically around 0 + limit = np.abs(cut_value) updated_particle_list = [] for i in range(0, len(particle_list)): particle_list_tmp = [elem for elem in particle_list[i] if - (lim_min <= elem.spatial_rapidity() <= lim_max + (-limit <= elem.spatial_rapidity() <= limit and not np.isnan(elem.spatial_rapidity()))] updated_particle_list.append(particle_list_tmp) - particle_list = updated_particle_list - return particle_list + else: + raise TypeError('Input value must be a number or a tuple ' + + 'with the cut limits (cut_min, cut_max)') + + return updated_particle_list def multiplicity_cut(particle_list, min_multiplicity): diff --git a/src/sparkx/Jetscape.py b/src/sparkx/Jetscape.py index 5edd0b4a..4aad7d64 100644 --- a/src/sparkx/Jetscape.py +++ b/src/sparkx/Jetscape.py @@ -374,6 +374,8 @@ def __apply_kwargs_filters(self, event, filters_dict): event, filters_dict['lower_event_energy_cut']) elif i == 'pt_cut': event = pt_cut(event, filters_dict['pt_cut']) + elif i == 'mT_cut': + event = mT_cut(event, filters_dict['mT_cut']) elif i == 'rapidity_cut': event = rapidity_cut(event, filters_dict['rapidity_cut']) elif i == 'pseudorapidity_cut': @@ -739,6 +741,32 @@ def pt_cut(self, cut_value_tuple): return self + def mT_cut(self, cut_value_tuple): + """ + Apply transverse mass cut to all events by passing an acceptance + range by ::code`cut_value_tuple`. All particles outside this range will + be removed. + + Parameters + ---------- + cut_value_tuple : tuple + Tuple with the upper and lower limits of the mT acceptance + range :code:`(cut_min, cut_max)`. If one of the limits is not + required, set it to :code:`None`, i.e. :code:`(None, cut_max)` + or :code:`(cut_min, None)`. + + Returns + ------- + self : Oscar object + Containing only particles complying with the transverse mass + cut for all events + """ + + self.particle_list_ = mT_cut(self.particle_list_, cut_value_tuple) + self.__update_num_output_per_event_after_filter() + + return self + def rapidity_cut(self, cut_value): """ Apply rapidity cut to all events and remove all particles with rapidity diff --git a/src/sparkx/Oscar.py b/src/sparkx/Oscar.py index ab2635d8..5ccc18b9 100644 --- a/src/sparkx/Oscar.py +++ b/src/sparkx/Oscar.py @@ -431,6 +431,8 @@ def __apply_kwargs_filters(self, event, filters_dict): filters_dict['spacetime_cut'][1]) elif i == 'pt_cut': event = pt_cut(event, filters_dict['pt_cut']) + elif i == 'mT_cut': + event = mT_cut(event, filters_dict['mT_cut']) elif i == 'rapidity_cut': event = rapidity_cut(event, filters_dict['rapidity_cut']) elif i == 'pseudorapidity_cut': @@ -921,6 +923,32 @@ def pt_cut(self, cut_value_tuple): return self + def mT_cut(self, cut_value_tuple): + """ + Apply transverse mass cut to all events by passing an acceptance + range by ::code`cut_value_tuple`. All particles outside this range will + be removed. + + Parameters + ---------- + cut_value_tuple : tuple + Tuple with the upper and lower limits of the mT acceptance + range :code:`(cut_min, cut_max)`. If one of the limits is not + required, set it to :code:`None`, i.e. :code:`(None, cut_max)` + or :code:`(cut_min, None)`. + + Returns + ------- + self : Oscar object + Containing only particles complying with the transverse mass + cut for all events + """ + + self.particle_list_ = mT_cut(self.particle_list_, cut_value_tuple) + self.__update_num_output_per_event_after_filter() + + return self + def rapidity_cut(self, cut_value): """ Apply rapidity cut to all events and remove all particles with rapidity diff --git a/src/sparkx/Particle.py b/src/sparkx/Particle.py index 6f3bf681..3359abf1 100755 --- a/src/sparkx/Particle.py +++ b/src/sparkx/Particle.py @@ -149,6 +149,8 @@ class Particle: Compute mass from energy momentum relation compute_charge_from_pdg: Compute charge from PDG code + mT: + Compute transverse mass is_meson: Is the particle a meson? is_baryon: @@ -1051,11 +1053,12 @@ def compute_mass_from_energy_momentum(self): self.pz): return np.nan else: - if np.abs(self.E**2. - self.p_abs()**2.) > 1e-16 and\ - self.E**2. - self.p_abs()**2. > 0.: + if abs(self.E) >= abs(self.p_abs()): return np.sqrt(self.E**2. - self.p_abs()**2.) else: - return 0. + warnings.warn('|E| >= |p| not fulfilled or not within numerical precision! ' + 'The mass is set to nan.') + return np.nan def compute_charge_from_pdg(self): """ @@ -1076,6 +1079,29 @@ def compute_charge_from_pdg(self): return np.nan return PDGID(self.pdg).charge + def mT(self): + """ + Compute the transverse mass :math:`m_{T}=\\sqrt{E^2-p_z^2}` of the particle. + + Returns + ------- + float + transverse mass + + Notes + ----- + If one of the needed particle quantities is not given, then `np.nan` + is returned. + """ + if np.isnan(self.E) or np.isnan(self.pz): + return np.nan + elif abs(self.E) >= abs(self.pz): + return np.sqrt(self.E**2. - self.pz**2.) + else: + warnings.warn('|E| >= |pz| not fulfilled or not within numerical precision! ' + 'The transverse mass is set to nan.') + return np.nan + def is_meson(self): """ Is the particle a meson? diff --git a/tests/test_Filter.py b/tests/test_Filter.py index 0b0267db..39434cf6 100644 --- a/tests/test_Filter.py +++ b/tests/test_Filter.py @@ -233,45 +233,40 @@ def particle_list_positions(): def test_spacetime_cut(particle_list_positions): test_cases = [ # Test cases for valid input - ('t', (0.5, 1.5), [[particle_list_positions[0][0]]]), - ('t', (1.5, 2.5), [[]]), - ('x', - (-0.5, - 0.5), + ('t', (0.5, 1.5), None, None, [[particle_list_positions[0][0]]]), + ('t', (1.5, 2.5), None, None, [[]]), + ('x',(-0.5, 0.5), None, None, [[particle_list_positions[0][0], particle_list_positions[0][2], particle_list_positions[0][3]]]), - ('y', (0.5, None), [[particle_list_positions[0][2]]]), - ('z', - (None, - 0.5), + ('y', (0.5, None), None, None, [[particle_list_positions[0][2]]]), + ('z', (None, 0.5), None, None, [[particle_list_positions[0][0], particle_list_positions[0][1], particle_list_positions[0][2]]]), # Test cases for error conditions - ('t', (None, None), ValueError), - ('t', (1.5, 0.5), ValueError), - ('t', (0.5,), TypeError), - ('t', ('a', 1.5), ValueError), - ('w', (0.5, 1.5), ValueError), - ('x', (1.5, 0.5), ValueError), + ('t', (None, None), None, ValueError, None), + ('t', (1.5, 0.5), UserWarning, None, [[particle_list_positions[0][0]]]), + ('t', (0.5,), None, TypeError, None), + ('t', ('a', 1.5), None, ValueError, None), + ('w', (0.5, 1.5), None, ValueError, None), + ('x', (1.5, 0.5), UserWarning, None, [[particle_list_positions[0][1]]]), ] - for dim, cut_value_tuple, expected_result in test_cases: - if isinstance( - expected_result, - type) and issubclass( - expected_result, - Exception): - # If expected_result is an Exception, we expect an error to be - # raised - with pytest.raises(expected_result): - spacetime_cut(particle_list_positions, dim, cut_value_tuple) + for dim, cut_value_tuple, expected_warning, expected_error, expected_result in test_cases: + if expected_warning: + with pytest.warns(expected_warning): + result = spacetime_cut(particle_list_positions, dim, cut_value_tuple) + assert result == expected_result + + elif expected_error: + with pytest.raises(expected_error): + result = spacetime_cut(particle_list_positions, dim, cut_value_tuple) + else: # Apply the spacetime cut - result = spacetime_cut( - particle_list_positions, dim, cut_value_tuple) + result = spacetime_cut(particle_list_positions, dim, cut_value_tuple) # Assert the result matches the expected outcome assert result == expected_result @@ -290,29 +285,29 @@ def particle_list_pt(): def test_pt_cut(particle_list_pt): test_cases = [ # Test cases for valid input - ((0.5, 1.5), [[particle_list_pt[0][1]]]), - ((2.5, None), [[particle_list_pt[0][3], particle_list_pt[0][4]]]), - ((None, 3.5), [[particle_list_pt[0][0], particle_list_pt[0] + ((0.5, 1.5), None, None, [[particle_list_pt[0][1]]]), + ((2.5, None), None, None, [[particle_list_pt[0][3], particle_list_pt[0][4]]]), + ((None, 3.5), None, None, [[particle_list_pt[0][0], particle_list_pt[0] [1], particle_list_pt[0][2], particle_list_pt[0][3]]]), # Test cases for error conditions - ((None, None), ValueError), - ((-1, 3), ValueError), - (('a', 3), ValueError), - ((3, 2), ValueError), - ((None, None, None), TypeError), + ((None, None), None, ValueError, None), + ((-1, 3), None, ValueError, None), + (('a', 3), None, ValueError, None), + ((3, 2), UserWarning, None, [[particle_list_pt[0][2], particle_list_pt[0][3]]]), + ((None, None, None), None, TypeError, None), ] - for cut_value_tuple, expected_result in test_cases: - if isinstance( - expected_result, - type) and issubclass( - expected_result, - Exception): - # If expected_result is an Exception, we expect an error to be - # raised - with pytest.raises(expected_result): - pt_cut(particle_list_pt, cut_value_tuple) + for cut_value_tuple, expected_warning, expected_error, expected_result in test_cases: + if expected_warning: + with pytest.warns(expected_warning): + result = pt_cut(particle_list_pt, cut_value_tuple) + assert result == expected_result + + elif expected_error: + with pytest.raises(expected_error): + result = pt_cut(particle_list_pt, cut_value_tuple) + else: # Apply the pt_cut result = pt_cut(particle_list_pt, cut_value_tuple) @@ -320,6 +315,57 @@ def test_pt_cut(particle_list_pt): assert result == expected_result +@pytest.fixture +def particle_list_mT(): + particle_list = [] + mT_E_pz_pairs = [ + (3, 5, 4), + (4, 5, 3), + (5, 13, 12), + (6, 10, 8), + (7, 25, 24), + ] + for m_T, energy, p_z in mT_E_pz_pairs: + p = Particle() + p.E = energy + p.pz = p_z + particle_list.append(p) + return [particle_list] + + +def test_mT_cut(particle_list_mT): + test_cases = [ + # Test cases for valid input + ((2.5, 3.5), None, None, [[particle_list_mT[0][0]]]), + ((5.5, None), None, None, [[particle_list_mT[0][3], particle_list_mT[0][4]]]), + ((None, 6.5), None, None, [[particle_list_mT[0][0], particle_list_mT[0][1], + particle_list_mT[0][2], particle_list_mT[0][3]]]), + + # Test cases for error conditions + ((None, None), None, ValueError, None), + ((-1, 6), None, ValueError, None), + (('a', 5), None, ValueError, None), + ((5.7, 3.3), UserWarning, None, [[particle_list_mT[0][1], particle_list_mT[0][2]]]), + ((None, None, None), None, TypeError, None), + ] + + for cut_value_tuple, expected_warning, expected_error, expected_result in test_cases: + if expected_warning: + with pytest.warns(expected_warning): + result = mT_cut(particle_list_mT, cut_value_tuple) + assert result == expected_result + + elif expected_error: + with pytest.raises(expected_error): + result = mT_cut(particle_list_mT, cut_value_tuple) + + else: + # Apply the mT_cut + result = mT_cut(particle_list_mT, cut_value_tuple) + # Assert the result matches the expected outcome + assert result == expected_result + + @pytest.fixture def particle_list_momentum_rapidity(): particle_list = [] diff --git a/tests/test_Particle.py b/tests/test_Particle.py index ae7963d3..8faafc03 100644 --- a/tests/test_Particle.py +++ b/tests/test_Particle.py @@ -712,6 +712,16 @@ def test_compute_mass_from_energy_momentum_valid_values(): assert np.isclose(result, expected_result) +def test_compute_mass_from_energy_momentum_invalid_values(): + p = Particle() + p.E = 3.0 + p.px = 1.0 + p.py = 2.0 + p.pz = 5.0 + + with pytest.warns(UserWarning, match=r"|E| >= |p| not fulfilled or not within numerical precision! The mass is set to nan."): + assert np.isnan(p.compute_mass_from_energy_momentum()) + def test_compute_mass_from_energy_momentum_missing_values(): p = Particle() @@ -733,6 +743,43 @@ def test_compute_mass_from_energy_momentum_zero_energy(): assert np.isclose(result, 0.0) +def test_mT_missing_values(): + p = Particle() + # Leave some values as np.nan + + result = p.mT() + + assert np.isnan(result) + +def test_mT_invalid_values(): + p = Particle() + p.E = 11.2 + p.pz = 11.3 + + with pytest.warns(UserWarning, match=r"|E| >= |pz| not fulfilled or not within numerical precision! The transverse mass is set to nan."): + assert np.isnan(p.mT()) + + +def test_mT_valid_values(): + p = Particle() + + # Test for zero + p.E = 0.0 + p.pz = 0.0 + assert np.isclose(p.mT(), 0.0) + + # Test for a few random values + for i in range(10): + energy = np.random.uniform(5.0, 10.0) + p_z = np.random.uniform(0.1, 5.0) + + p.E = energy + p.pz = p_z + + result = p.mT() + expected_result = np.sqrt(energy**2 - p_z**2) + + assert np.isclose(result, expected_result) def test_compute_charge_from_pdg_valid_values(): p = Particle()