Skip to content

Commit

Permalink
Merge pull request #97 from deepanshs/djs/dim_equality_with_tol
Browse files Browse the repository at this point in the history
Check dimension equality within tolerance
  • Loading branch information
deepanshs authored Feb 9, 2024
2 parents 5a99998 + a94e84f commit 0ca2006
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 10 deletions.
2 changes: 1 addition & 1 deletion csdmpy/dimension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __str__(self):
def __eq__(self, other):
"""Overrides the default implementation."""
other = other.subtype if isinstance(other, Dimension) else other
return True if self.subtype == other else False
return self.subtype == other

def __mul__(self, other):
"""Multiply the Dimension object by a right scalar."""
Expand Down
4 changes: 3 additions & 1 deletion csdmpy/dimension/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import warnings
from copy import deepcopy

import numpy as np

from csdmpy.utils import validate


Expand All @@ -23,7 +25,7 @@ def __init__(self, label, application, description):
def __eq__(self, other):
"""Check if two objects are equal"""
check = [getattr(self, _) == getattr(other, _) for _ in __class__.__slots__]
return False if False in check else True
return np.all(check)

@property
def label(self):
Expand Down
6 changes: 5 additions & 1 deletion csdmpy/dimension/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from csdmpy.dimension.quantitative import ReciprocalDimension
from csdmpy.units import frequency_ratio
from csdmpy.units import ScalarQuantity
from csdmpy.utils import assert_params
from csdmpy.utils import check_and_assign_bool
from csdmpy.utils import check_scalar_object
from csdmpy.utils import validate
Expand Down Expand Up @@ -70,7 +71,10 @@ def __eq__(self, other):
other = other.subtype if hasattr(other, "subtype") else other
if not isinstance(other, LinearDimension):
return False
check = [getattr(self, _) == getattr(other, _) for _ in __class__.__slots__[:4]]

non_quantitative = ["reciprocal", "_complex_fft"]
quantitative = ["_count", "_increment"]
check = assert_params(self, other, quantitative, non_quantitative)
check += [super().__eq__(other)]
return np.all(check)

Expand Down
11 changes: 5 additions & 6 deletions csdmpy/dimension/monotonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from csdmpy.units import frequency_ratio
from csdmpy.units import scalar_quantity_format
from csdmpy.units import ScalarQuantity
from csdmpy.utils import assert_params
from csdmpy.utils import check_scalar_object


Expand Down Expand Up @@ -63,12 +64,10 @@ def __eq__(self, other):
if not isinstance(other, MonotonicDimension):
return False

check = [
self._count == other._count,
np.all(self._coordinates == other._coordinates),
self.reciprocal == other.reciprocal,
super().__eq__(other),
]
non_quantitative = ["reciprocal"]
quantitative = ["_count", "_coordinates"]
check = assert_params(self, other, quantitative, non_quantitative)
check += [super().__eq__(other)]
return np.all(check)

def __repr__(self):
Expand Down
10 changes: 9 additions & 1 deletion csdmpy/dimension/quantitative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from csdmpy.units import check_quantity_name
from csdmpy.units import ScalarQuantity
from csdmpy.utils import _axis_label
from csdmpy.utils import assert_params
from csdmpy.utils import type_error
from csdmpy.utils import validate

Expand Down Expand Up @@ -64,7 +65,14 @@ def __init__(
self._equivalencies = None

def __eq__(self, other):
check = [getattr(self, _) == getattr(other, _) for _ in __class__.__slots__]
non_quantitative = [
"_quantity_name",
"_unit",
"_equivalent_unit",
"_equivalencies",
]
quantitative = ["_coordinates_offset", "_origin_offset", "_period"]
check = assert_params(self, other, quantitative, non_quantitative)
check += [super().__eq__(other)]
return np.all(check)

Expand Down
26 changes: 26 additions & 0 deletions csdmpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from copy import deepcopy

import numpy as np
from astropy import units as u
from astropy.units.quantity import Quantity

from .units import ScalarQuantity
Expand Down Expand Up @@ -431,3 +432,28 @@ def np_check_pads(pads, n_dims):
return tuple([pads[0] for _ in range(n_dims)])

return pads


def assert_params(obj_1, obj_2, quantitative, non_quantitative):
"""Assert close quantitaive and exact non-quantitative values"""
check = [getattr(obj_1, _) == getattr(obj_2, _) for _ in non_quantitative]
for _ in quantitative:
a = getattr(obj_1, _)
b = getattr(obj_2, _)
if a is None or b is None:
check += [a == b]
else:
check += [is_quantity_equal(a, b)]
return check


def is_quantity_equal(quantity_a, quantity_b):
"""Check if two quantities are within number precision."""
try:
if isinstance(quantity_a, u.Quantity):
quantity_a = quantity_a.to(quantity_b.unit).value
quantity_b = quantity_b.value
res = np.allclose(quantity_a, quantity_b, equal_nan=True)
except u.UnitConversionError:
res = False
return res
13 changes: 13 additions & 0 deletions tests/dimension_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,19 @@ def test_dimension_scale():
assert type(dim2.quantity_name) is str


def test_dimension_equality_within_precision():
"""Test dimension equality within precision"""
dim_1 = cp.as_dimension([1, 10], type="monotonic", unit="s")
dim_2 = cp.as_dimension([1 - 1e-8, 10 + 1e-12], type="monotonic", unit="s")

assert dim_2 == dim_1

dim_1 = cp.as_dimension([1, 10], type="monotonic", unit="s")
dim_2 = cp.as_dimension([1 - 1e-8, 10 + 1e-12], type="monotonic", unit="m")

assert dim_2 != dim_1


def test_attribute_unit_update_linear():
"""Test attribute units"""
d1 = cp.as_dimension(
Expand Down

0 comments on commit 0ca2006

Please sign in to comment.