diff --git a/csdmpy/dimension/__init__.py b/csdmpy/dimension/__init__.py index 28c01ea..2690859 100644 --- a/csdmpy/dimension/__init__.py +++ b/csdmpy/dimension/__init__.py @@ -172,7 +172,7 @@ def __str__(self): def __eq__(self, other): """Overrides the default implementation.""" other = other.subtype if isinstance(other, Dimension) else other - return True if self.subtype == other else False + return self.subtype == other def __mul__(self, other): """Multiply the Dimension object by a right scalar.""" diff --git a/csdmpy/dimension/base.py b/csdmpy/dimension/base.py index d8f946e..988f0c4 100644 --- a/csdmpy/dimension/base.py +++ b/csdmpy/dimension/base.py @@ -3,6 +3,8 @@ import warnings from copy import deepcopy +import numpy as np + from csdmpy.utils import validate @@ -23,7 +25,7 @@ def __init__(self, label, application, description): def __eq__(self, other): """Check if two objects are equal""" check = [getattr(self, _) == getattr(other, _) for _ in __class__.__slots__] - return False if False in check else True + return np.all(check) @property def label(self): diff --git a/csdmpy/dimension/linear.py b/csdmpy/dimension/linear.py index 1d5870d..d42ae05 100644 --- a/csdmpy/dimension/linear.py +++ b/csdmpy/dimension/linear.py @@ -7,6 +7,7 @@ from csdmpy.dimension.quantitative import ReciprocalDimension from csdmpy.units import frequency_ratio from csdmpy.units import ScalarQuantity +from csdmpy.utils import assert_params from csdmpy.utils import check_and_assign_bool from csdmpy.utils import check_scalar_object from csdmpy.utils import validate @@ -70,7 +71,10 @@ def __eq__(self, other): other = other.subtype if hasattr(other, "subtype") else other if not isinstance(other, LinearDimension): return False - check = [getattr(self, _) == getattr(other, _) for _ in __class__.__slots__[:4]] + + non_quantitative = ["reciprocal", "_complex_fft"] + quantitative = ["_count", "_increment"] + check = assert_params(self, other, quantitative, non_quantitative) check += [super().__eq__(other)] return np.all(check) diff --git a/csdmpy/dimension/monotonic.py b/csdmpy/dimension/monotonic.py index ab59264..2a2eca8 100644 --- a/csdmpy/dimension/monotonic.py +++ b/csdmpy/dimension/monotonic.py @@ -9,6 +9,7 @@ from csdmpy.units import frequency_ratio from csdmpy.units import scalar_quantity_format from csdmpy.units import ScalarQuantity +from csdmpy.utils import assert_params from csdmpy.utils import check_scalar_object @@ -63,12 +64,10 @@ def __eq__(self, other): if not isinstance(other, MonotonicDimension): return False - check = [ - self._count == other._count, - np.all(self._coordinates == other._coordinates), - self.reciprocal == other.reciprocal, - super().__eq__(other), - ] + non_quantitative = ["reciprocal"] + quantitative = ["_count", "_coordinates"] + check = assert_params(self, other, quantitative, non_quantitative) + check += [super().__eq__(other)] return np.all(check) def __repr__(self): diff --git a/csdmpy/dimension/quantitative.py b/csdmpy/dimension/quantitative.py index 87e9801..1441e4a 100644 --- a/csdmpy/dimension/quantitative.py +++ b/csdmpy/dimension/quantitative.py @@ -8,6 +8,7 @@ from csdmpy.units import check_quantity_name from csdmpy.units import ScalarQuantity from csdmpy.utils import _axis_label +from csdmpy.utils import assert_params from csdmpy.utils import type_error from csdmpy.utils import validate @@ -64,7 +65,14 @@ def __init__( self._equivalencies = None def __eq__(self, other): - check = [getattr(self, _) == getattr(other, _) for _ in __class__.__slots__] + non_quantitative = [ + "_quantity_name", + "_unit", + "_equivalent_unit", + "_equivalencies", + ] + quantitative = ["_coordinates_offset", "_origin_offset", "_period"] + check = assert_params(self, other, quantitative, non_quantitative) check += [super().__eq__(other)] return np.all(check) diff --git a/csdmpy/utils.py b/csdmpy/utils.py index b024f82..d2a88bc 100644 --- a/csdmpy/utils.py +++ b/csdmpy/utils.py @@ -2,6 +2,7 @@ from copy import deepcopy import numpy as np +from astropy import units as u from astropy.units.quantity import Quantity from .units import ScalarQuantity @@ -431,3 +432,28 @@ def np_check_pads(pads, n_dims): return tuple([pads[0] for _ in range(n_dims)]) return pads + + +def assert_params(obj_1, obj_2, quantitative, non_quantitative): + """Assert close quantitaive and exact non-quantitative values""" + check = [getattr(obj_1, _) == getattr(obj_2, _) for _ in non_quantitative] + for _ in quantitative: + a = getattr(obj_1, _) + b = getattr(obj_2, _) + if a is None or b is None: + check += [a == b] + else: + check += [is_quantity_equal(a, b)] + return check + + +def is_quantity_equal(quantity_a, quantity_b): + """Check if two quantities are within number precision.""" + try: + if isinstance(quantity_a, u.Quantity): + quantity_a = quantity_a.to(quantity_b.unit).value + quantity_b = quantity_b.value + res = np.allclose(quantity_a, quantity_b, equal_nan=True) + except u.UnitConversionError: + res = False + return res diff --git a/tests/dimension_test.py b/tests/dimension_test.py index 38246f2..7302aee 100644 --- a/tests/dimension_test.py +++ b/tests/dimension_test.py @@ -707,6 +707,19 @@ def test_dimension_scale(): assert type(dim2.quantity_name) is str +def test_dimension_equality_within_precision(): + """Test dimension equality within precision""" + dim_1 = cp.as_dimension([1, 10], type="monotonic", unit="s") + dim_2 = cp.as_dimension([1 - 1e-8, 10 + 1e-12], type="monotonic", unit="s") + + assert dim_2 == dim_1 + + dim_1 = cp.as_dimension([1, 10], type="monotonic", unit="s") + dim_2 = cp.as_dimension([1 - 1e-8, 10 + 1e-12], type="monotonic", unit="m") + + assert dim_2 != dim_1 + + def test_attribute_unit_update_linear(): """Test attribute units""" d1 = cp.as_dimension(