Skip to content

Commit

Permalink
Add mean_afid_set function
Browse files Browse the repository at this point in the history
This commit adds a metric function that computes the mean spatial
position given a list of AfidSets. Adds in associated tests and
docstrings.

Additionally, makes use of lambda function rather than iterating across
list for tests.
  • Loading branch information
kaitj committed Sep 27, 2023
1 parent b446618 commit 4ed6d70
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 39 deletions.
68 changes: 42 additions & 26 deletions afids_utils/metrics.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
"""Methods for computing various metrics pertaining to AFIDs"""
from __future__ import annotations

from afids_utils.afids import AfidDistance, AfidPosition
from afids_utils.afids import AfidPosition, AfidSet


def compute_distance(
afid_position: AfidPosition, template_position: AfidPosition
) -> AfidDistance:
"""Compute distance between two AfidPositions, returning Euclidean
distance, as well as distances along each spatial dimension
def compute_mean_afid_sets(afid_sets: list[AfidSet]) -> AfidSet:
"""Compute the mean spatial coordinates of corresponding AFIDs across a
list of ``AfidSet`` objects.
Parameters
----------
afid_position
Input AfidPosition containing floating-point spatial coordinates
(x, y, z)
template_position
Template AfidPosition to compute distance against
afid_sets
List of AfidSets to compute mean from
Returns
-------
AfidDistance
Object containing distances along each spatial dimension
(x, y, z) and Euclidean distance respectively
AfidSet
Object containing mean spatial components for each AFID
Raises
------
ValueError
If there are different coordinate systems in provided list of
``AfidSet`` objects
"""
# Compute distances
x_dist, y_dist, z_dist = afid_position - template_position
euc = (x_dist**2 + y_dist**2 + z_dist**2) ** 0.5

return AfidDistance(
label=afid_position.label,
desc=afid_position.desc,
x=x_dist,
y=y_dist,
z=z_dist,
euc=euc,
# Check if coordinate systems are all the same
if not all(
afid_set.coord_system == afid_sets[0].coord_system
for afid_set in afid_sets
):
raise ValueError(
"Mismatched coordinate system in provided list of AfidSet"
)

num_sets = len(afid_sets)
mean_afid_set = AfidSet(
slicer_version="Unknown",
coord_system=afid_sets[0].coord_system,
afids=[
AfidPosition(
label=afid_sets[0].afids[idx].label,
x=sum(afid_set.afids[idx].x for afid_set in afid_sets)
/ num_sets,
y=sum(afid_set.afids[idx].y for afid_set in afid_sets)
/ num_sets,
z=sum(afid_set.afids[idx].z for afid_set in afid_sets)
/ num_sets,
desc=afid_sets[0].afids[idx].desc,
)
for idx in range(len(afid_sets[0].afids))
],
)

return mean_afid_set
11 changes: 8 additions & 3 deletions afids_utils/tests/test_afids.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,14 @@ def test_valid_afid_set(self, afid_set1: AfidSet, afid_set2: AfidSet):
)

# Check afids property is correct (list[AfidDistances])
assert isinstance(afid_distance_set.afids, list)
for idx in range(len(afid_set1.afids)):
assert isinstance(afid_distance_set.afids[idx], AfidDistance)
assert isinstance(afid_distance_set.afids, list) and all(
list(
map(
lambda x: isinstance(x, AfidDistance),
afid_distance_set.afids,
)
)
)

@given(
afid_set1=af_st.afid_sets(randomize_header=False),
Expand Down
38 changes: 28 additions & 10 deletions afids_utils/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
from __future__ import annotations

import pytest
from hypothesis import given

import afids_utils.metrics as af_metrics
import afids_utils.tests.strategies as af_st
from afids_utils.afids import AfidDistance, AfidPosition
from afids_utils.afids import AfidPosition, AfidSet


class TestMetrics:
@given(pos1=af_st.afid_positions(), pos2=af_st.afid_positions())
def test_compute_distance(self, pos1: AfidPosition, pos2: AfidPosition):
res = af_metrics.compute_distance(pos1, pos2)
assert isinstance(res, AfidDistance)
class TestMeanAfidSet:
@given(
afid_set1=af_st.afid_sets(randomize_header=False),
afid_set2=af_st.afid_sets(randomize_header=False),
)
def test_mismatched_coords(self, afid_set1: AfidSet, afid_set2: AfidSet):
# Manually set the coord system to mistmatch
afid_set2.coord_system = "LPS"

with pytest.raises(ValueError, match=r"Mismatched coordinate.*"):
af_metrics.compute_mean_afid_sets([afid_set1, afid_set2])

@given(
afid_set1=af_st.afid_sets(randomize_header=False),
afid_set2=af_st.afid_sets(randomize_header=False),
)
def test_valid_afid_sets(self, afid_set1: AfidSet, afid_set2: AfidSet):
mean_afid_set = af_metrics.compute_mean_afid_sets(
[afid_set1, afid_set2]
)

# Check internals
assert isinstance(res.x, float)
assert isinstance(res.y, float)
assert isinstance(res.z, float)
assert isinstance(res.euc, float) and res.euc >= 0
assert isinstance(mean_afid_set.afids, list) and all(
list(
map(lambda x: isinstance(x, AfidPosition), mean_afid_set.afids)
)
)

0 comments on commit 4ed6d70

Please sign in to comment.