diff --git a/afids_utils/afids.py b/afids_utils/afids.py index 9aa70db0..20aa1f1a 100644 --- a/afids_utils/afids.py +++ b/afids_utils/afids.py @@ -349,11 +349,11 @@ def afids(self): raise ValueError("Mismatched coordinate systems") # Compute distances between AfidSets - afids = [ + distances = [ AfidDistance(afid_set1_position, afid_set2_position) for afid_set1_position, afid_set2_position in zip( self.afid_set1.afids, self.afid_set2.afids ) ] - return afids + return distances diff --git a/afids_utils/metrics.py b/afids_utils/metrics.py index d0c6eb91..15e948b6 100644 --- a/afids_utils/metrics.py +++ b/afids_utils/metrics.py @@ -1,17 +1,17 @@ """Methods for computing various metrics pertaining to AFIDs""" from __future__ import annotations -from afids_utils.afids import AfidPosition, AfidSet +from afids_utils.afids import AfidDistanceSet, AfidPosition, AfidSet -def compute_mean_afid_sets(afid_sets: list[AfidSet]) -> AfidSet: - """Compute the mean spatial coordinates of corresponding AFIDs across a - list of ``AfidSet`` objects. +def mean_afid_sets(afid_sets: list[AfidSet]) -> AfidSet: + """Calculate the average spatial coordinates for corresponding AFIDs + within a list of ``AfidSet`` objects. Parameters ---------- afid_sets - List of AfidSets to compute mean from + List of ``AfidSet`` to compute mean from Returns ------- @@ -21,9 +21,14 @@ def compute_mean_afid_sets(afid_sets: list[AfidSet]) -> AfidSet: Raises ------ ValueError - If there are different coordinate systems in provided list of - ``AfidSet`` objects + If input list does not consist of all ``AfidSet`` objects or if there + are different coordinate systems in provided list of ``AfidSet`` + objects """ + # Check to make sure input datatype is correct + if not all(list(map(lambda x: isinstance(x, AfidSet), afid_sets))): + raise ValueError("Input is not a collection of AfidSet objects") + # Check if coordinate systems are all the same if not all( afid_set.coord_system == afid_sets[0].coord_system @@ -53,3 +58,75 @@ def compute_mean_afid_sets(afid_sets: list[AfidSet]) -> AfidSet: ) return mean_afid_set + + +def mean_distances( + afid_distance_sets: list[AfidDistanceSet], +) -> list[float]: + """Calculate the average distance from a collection of + ``AfidDistanceSet`` objects. Ensure that one of the ``AfidSet`` objects + used to compute each ``AfidDistanceSet`` is consistent across all sets in + the list. + + Parameters + ---------- + afid_distance_sets + List of ``AfidDistanceSet`` objects to compute mean distance from + + Returns + ------- + list[float] + Dictionary object describing average spatial component and Euclidean + distances + + Raises + ------ + ValueError + If no single common ``AfidSet`` used to compute ``AfidDistance`` or + list does not consist of all ``AfidDistanceSet`` objects + """ + # Check to make sure all input types are correct + if not all( + list(map(lambda x: isinstance(x, AfidDistanceSet), afid_distance_sets)) + ): + raise ValueError( + "Input is not a collection of AfidDistanceSet objects" + ) + + # Check for common AfidSet + if afid_distance_sets[0].afid_set1 in [ + afid_distance_sets[1].afid_set1, + afid_distance_sets[1].afid_set2, + ]: + common_set = afid_distance_sets[0].afid_set1 + elif afid_distance_sets[0].afid_set2 in [ + afid_distance_sets[1].afid_set1, + afid_distance_sets[1].afid_set2, + ]: + common_set = afid_distance_sets[0].afid_set2 + else: + raise ValueError( + "No single common AfidSet found within AfidDistanceSet objects" + ) + + for idx in range(2, len(afid_distance_sets)): + if common_set not in [ + afid_distance_sets[idx].afid_set1, + afid_distance_sets[idx].afid_set2, + ]: + raise ValueError( + "No single common AfidSet found within AfidDistanceSet objects" + ) + + # Compute mean distance for each AFID + num_pairs = len(afid_distance_sets) + mean_distance = [ + sum( + afid_distance_set.afids[idx].distance + for afid_distance_set in afid_distance_sets + ) + / num_pairs + for idx in range(len(afid_distance_sets[0].afids)) + ] + + return mean_distance diff --git a/afids_utils/tests/strategies.py b/afids_utils/tests/strategies.py index 14567871..4900d0e6 100644 --- a/afids_utils/tests/strategies.py +++ b/afids_utils/tests/strategies.py @@ -112,7 +112,7 @@ def afid_sets( randomize_header: bool = True, ) -> AfidSet: slicer_version = draw(st.from_regex(r"\d+\.\d+")) - coord_system = draw(st.sampled_from(["RAS", "LPS", "0", "1"])) + coord_system = draw(st.sampled_from(["RAS", "LPS"])) # Set (in)valid number of Afid coordinates in a list afid_pos = [] diff --git a/afids_utils/tests/test_metrics.py b/afids_utils/tests/test_metrics.py index c053bc62..8c605534 100644 --- a/afids_utils/tests/test_metrics.py +++ b/afids_utils/tests/test_metrics.py @@ -1,37 +1,101 @@ from __future__ import annotations +from copy import deepcopy + import pytest from hypothesis import given import afids_utils.metrics as af_metrics import afids_utils.tests.strategies as af_st -from afids_utils.afids import AfidPosition, AfidSet +from afids_utils.afids import AfidDistanceSet, AfidPosition, AfidSet class TestMeanAfidSet: @given( afid_set1=af_st.afid_sets(randomize_header=False), - afid_set2=af_st.afid_sets(randomize_header=False), ) - def test_mismatched_coords(self, afid_set1: AfidSet, afid_set2: AfidSet): + def test_mismatched_coords(self, afid_set1: AfidSet): # Manually set the coord system to mistmatch + afid_set2 = deepcopy(afid_set1) afid_set2.coord_system = "LPS" with pytest.raises(ValueError, match=r"Mismatched coordinate.*"): - af_metrics.compute_mean_afid_sets([afid_set1, afid_set2]) + af_metrics.mean_afid_sets([afid_set1, afid_set2]) @given( afid_set1=af_st.afid_sets(randomize_header=False), - afid_set2=af_st.afid_sets(randomize_header=False), ) - def test_valid_afid_sets(self, afid_set1: AfidSet, afid_set2: AfidSet): - mean_afid_set = af_metrics.compute_mean_afid_sets( - [afid_set1, afid_set2] - ) + def test_valid_afid_sets(self, afid_set1: AfidSet): + mean_afid_set = af_metrics.mean_afid_sets([afid_set1, afid_set1]) # Check internals assert isinstance(mean_afid_set.afids, list) and all( list( - map(lambda x: isinstance(x, AfidPosition), mean_afid_set.afids) + map( + lambda afid: isinstance(afid, AfidPosition), + mean_afid_set.afids, + ) ) ) + + def test_invalid_input_type(self): + with pytest.raises(ValueError, match=r".*collection of AfidSet.*"): + af_metrics.mean_afid_sets(["random string"]) + + +class TestMeanDistances: + @given( + afid_set1=af_st.afid_sets(randomize_header=False), + afid_set2=af_st.afid_sets(randomize_header=False), + ) + def test_valid_afid_sets(self, afid_set1: AfidSet, afid_set2: AfidSet): + afid_distance_sets = [ + AfidDistanceSet(afid_set1, afid_set2), + AfidDistanceSet(afid_set2, afid_set2), + ] + mean_distances = af_metrics.mean_distances(afid_distance_sets) + + # Check list objects + assert isinstance(mean_distances, list) and all( + list(map(lambda dist: isinstance(dist, float), mean_distances)) + ) + assert all(list(map(lambda dist: dist >= 0, mean_distances))) + + def test_invalid_input_type(self): + with pytest.raises( + ValueError, match=r".*collection of AfidDistanceSet.*" + ): + af_metrics.mean_distances(["random string"]) + + @given( + afid_set1=af_st.afid_sets(randomize_header=False), + afid_set2=af_st.afid_sets(min_value=1.0, randomize_header=False), + ) + def test_no_common_afid_sets_short( + self, afid_set1: AfidSet, afid_set2: AfidSet + ): + afid_distance_sets = [ + AfidDistanceSet(afid_set1, afid_set1), + AfidDistanceSet(afid_set2, afid_set2), + ] + + with pytest.raises(ValueError, match=r"No single common AfidSet.*"): + af_metrics.mean_distances(afid_distance_sets) + + @given( + afid_set1=af_st.afid_sets(randomize_header=False), + afid_set2=af_st.afid_sets(min_value=1.0, randomize_header=False), + ) + def test_no_common_afid_sets_long( + self, + afid_set1: AfidSet, + afid_set2: AfidSet, + ): + afid_distance_sets = [ + AfidDistanceSet(afid_set1, afid_set1), + AfidDistanceSet(afid_set1, afid_set2), + AfidDistanceSet(afid_set2, afid_set2), + ] + + with pytest.raises(ValueError, match=r"No single common AfidSet.*"): + af_metrics.mean_distances(afid_distance_sets)