Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Select function in segmentation resampling workflow #450

Merged
merged 5 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions smriprep/workflows/surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,11 +1170,11 @@
def init_segs_to_native_wf(
*,
image_type: ty.Literal['T1w', 'T2w'] = 'T1w',
segmentation: ty.Literal['aseg', 'aparc_aseg', 'wmparc'] = 'aseg',
segmentation: ty.Literal['aseg', 'aparc_aseg', 'aparc_a2009s', 'aparc_dkt'] | str = 'aseg',
name: str = 'segs_to_native_wf',
) -> Workflow:
"""
Get a segmentation from FreeSurfer conformed space into native T1w space.
Get a segmentation from FreeSurfer conformed space into native anatomical space.

Workflow Graph
.. workflow::
Expand Down Expand Up @@ -1219,30 +1219,15 @@

lta = pe.Node(ConcatenateXFMs(out_fmt='fs'), name='lta', run_without_submitting=True)

# Resample from T1.mgz to T1w.nii.gz, applying any offset in fsnative2anat_xfm,
# Resample from Freesurfer anat to native anat, applying any offset in fsnative2anat_xfm,
# and convert to NIfTI while we're at it
resample = pe.Node(
fs.ApplyVolTransform(transformed_file='seg.nii.gz', interp='nearest'),
name='resample',
)

if segmentation.startswith('aparc'):
if segmentation == 'aparc_aseg':

def _sel(x):
return [parc for parc in x if 'aparc+' in parc][0] # noqa

elif segmentation == 'aparc_a2009s':

def _sel(x):
return [parc for parc in x if 'a2009s+' in parc][0] # noqa

elif segmentation == 'aparc_dkt':

def _sel(x):
return [parc for parc in x if 'DKTatlas+' in parc][0] # noqa

segmentation = (segmentation, _sel)
select_seg = pe.Node(niu.Function(function=_select_seg), name='select_seg')
select_seg.inputs.segmentation = segmentation

anat = 'T2' if image_type == 'T2w' else 'T1'

Expand All @@ -1254,7 +1239,8 @@
('fsnative2anat_xfm', 'in_xfms')]),
(fssource, lta, [(anat, 'moving')]),
(inputnode, resample, [('in_file', 'target_file')]),
(fssource, resample, [(segmentation, 'source_file')]),
(fssource, select_seg, [(segmentation, 'in_files')]),
(select_seg, resample, [('out', 'source_file')]),
(lta, resample, [('out_xfm', 'lta_file')]),
(resample, outputnode, [('transformed_file', 'out_file')]),
]) # fmt:skip
Expand Down Expand Up @@ -1678,3 +1664,17 @@

ret = tuple(all_surfs[surface] for surface in surfaces)
return ret if len(ret) > 1 else ret[0]


def _select_seg(in_files, segmentation):
if isinstance(in_files, str):
return in_files

seg_mapping = {'aparc_aseg': 'aparc+', 'aparc_a2009s': 'a2009s+', 'aparc_dkt': 'DKTatlas+'}
if segmentation in seg_mapping:
segmentation = seg_mapping[segmentation]

for fl in in_files:
if segmentation in fl:
return fl
raise FileNotFoundError(f'No segmentation containing "{segmentation}" was found.')

Check warning on line 1680 in smriprep/workflows/surfaces.py

View check run for this annotation

Codecov / codecov/patch

smriprep/workflows/surfaces.py#L1680

Added line #L1680 was not covered by tests
15 changes: 14 additions & 1 deletion smriprep/workflows/tests/test_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from smriprep.interfaces.tests.data import load as load_test_data

from ..surfaces import init_anat_ribbon_wf, init_gifti_surfaces_wf
from ..surfaces import _select_seg, init_anat_ribbon_wf, init_gifti_surfaces_wf


def test_ribbon_workflow(tmp_path: Path):
Expand Down Expand Up @@ -53,3 +53,16 @@ def test_ribbon_workflow(tmp_path: Path):
assert np.allclose(ribbon.affine, expected.affine)
# Mask data is binary, so we can use np.array_equal
assert np.array_equal(ribbon.dataobj, expected.dataobj)


@pytest.mark.parametrize(
('in_files', 'segmentation', 'expected'),
[
('aparc+aseg.mgz', 'aparc_aseg', 'aparc+aseg.mgz'),
(['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_aseg', 'aparc+aseg.mgz'),
(['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_a2009s', 'a2009s+aseg.mgz'),
('wmparc.mgz', 'wmparc.mgz', 'wmparc.mgz'),
],
)
def test_select_seg(in_files, segmentation, expected):
assert _select_seg(in_files, segmentation) == expected