Skip to content

Commit

Permalink
Raise error if atlas affines don't match across participants (#1075)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo authored Mar 12, 2024
1 parent f9a597d commit 9fe6603
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 163 deletions.
120 changes: 120 additions & 0 deletions xcp_d/interfaces/bids.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Adapted interfaces from Niworkflows."""

import os
import shutil
from json import loads
from pathlib import Path

import nibabel as nb
import numpy as np
from bids.layout import Config
from nipype import logging
from nipype.interfaces.base import (
Expand All @@ -16,6 +20,8 @@
from niworkflows.interfaces.bids import DerivativesDataSink as BaseDerivativesDataSink
from pkg_resources import resource_filename as pkgrf

from xcp_d.utils.bids import get_entity

# NOTE: Modified for xcpd's purposes
xcp_d_spec = loads(Path(pkgrf("xcp_d", "data/xcp_d_bids_config.json")).read_text())
bids_config = Config.load("bids")
Expand Down Expand Up @@ -190,3 +196,117 @@ def _run_interface(self, runtime):
)

return runtime


class _CopyAtlasInputSpec(BaseInterfaceInputSpec):
name_source = traits.Str(
desc="The source file's name.",
mandatory=False,
)
in_file = File(
exists=True,
desc="The atlas file to copy.",
mandatory=True,
)
output_dir = Directory(
exists=True,
desc="The output directory.",
mandatory=True,
)
atlas = traits.Str(
desc="The atlas name.",
mandatory=True,
)


class _CopyAtlasOutputSpec(TraitedSpec):
out_file = File(
exists=True,
desc="The copied atlas file.",
)


class CopyAtlas(SimpleInterface):
"""Copy atlas file to output directory.
Parameters
----------
name_source : :obj:`str`
The source name of the atlas file.
in_file : :obj:`str`
The atlas file to copy.
output_dir : :obj:`str`
The output directory.
atlas : :obj:`str`
The name of the atlas.
Returns
-------
out_file : :obj:`str`
The path to the copied atlas file.
Notes
-----
I can't use DerivativesDataSink because it has a problem with dlabel CIFTI files.
It gives the following error:
"AttributeError: 'Cifti2Header' object has no attribute 'set_data_dtype'"
I can't override the CIFTI atlas's data dtype ahead of time because setting it to int8 or int16
somehow converts all of the values in the data array to weird floats.
This could be a version-specific nibabel issue.
I've also updated this function to handle JSON and TSV files as well.
"""

input_spec = _CopyAtlasInputSpec
output_spec = _CopyAtlasOutputSpec

def _run_interface(self, runtime):
output_dir = self.inputs.output_dir
in_file = self.inputs.in_file
name_source = self.inputs.name_source
atlas = self.inputs.atlas

atlas_out_dir = os.path.join(output_dir, f"xcp_d/atlases/atlas-{atlas}")

if in_file.endswith(".json"):
out_basename = f"atlas-{atlas}_dseg.json"
elif in_file.endswith(".tsv"):
out_basename = f"atlas-{atlas}_dseg.tsv"
else:
extension = ".nii.gz" if name_source.endswith(".nii.gz") else ".dlabel.nii"
space = get_entity(name_source, "space")
res = get_entity(name_source, "res")
den = get_entity(name_source, "den")
cohort = get_entity(name_source, "cohort")

cohort_str = f"_cohort-{cohort}" if cohort else ""
res_str = f"_res-{res}" if res else ""
den_str = f"_den-{den}" if den else ""
if extension == ".dlabel.nii":
out_basename = f"space-{space}_atlas-{atlas}{den_str}{cohort_str}_dseg{extension}"
elif extension == ".nii.gz":
out_basename = f"space-{space}_atlas-{atlas}{res_str}{cohort_str}_dseg{extension}"

os.makedirs(atlas_out_dir, exist_ok=True)
out_file = os.path.join(atlas_out_dir, out_basename)

if out_file.endswith(".nii.gz") and os.path.isfile(out_file):
# Check that native-resolution atlas doesn't have a different resolution from the last
# run's atlas.
old_img = nb.load(out_file)
new_img = nb.load(in_file)
if not np.allclose(old_img.affine, new_img.affine):
raise ValueError(
f"Existing '{atlas}' atlas affine ({out_file}) is different from the input "
f"file affine ({in_file})."
)

# Don't copy the file if it exists, to prevent any race conditions between parallel
# processes.
if not os.path.isfile(out_file):
shutil.copyfile(in_file, out_file)

self._results["out_file"] = out_file

return runtime
84 changes: 84 additions & 0 deletions xcp_d/tests/test_interfaces_bids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for xcp_d.interfaces.bids."""

import os

import pytest

from xcp_d.interfaces import bids
from xcp_d.utils import atlas


def test_copy_atlas(tmp_path_factory):
"""Test xcp_d.interfaces.bids.CopyAtlas."""
tmpdir = tmp_path_factory.mktemp("test_copy_atlas")
os.makedirs(os.path.join(tmpdir, "xcp_d"), exist_ok=True)

# NIfTI
atlas_file, _, _ = atlas.get_atlas_nifti("Gordon")
name_source = "sub-01_task-A_run-01_space-MNI152NLin2009cAsym_res-2_desc-z_bold.nii.gz"
copyatlas = bids.CopyAtlas(
name_source=name_source, in_file=atlas_file, output_dir=tmpdir, atlas="Y"
)
result = copyatlas.run(cwd=tmpdir)
assert os.path.isfile(result.outputs.out_file)
assert (
os.path.basename(result.outputs.out_file)
== "space-MNI152NLin2009cAsym_atlas-Y_res-2_dseg.nii.gz"
)

# Check that the NIfTI file raises an error if the resolution varies
# Gordon atlas is 1mm, HCP is 2mm
atlas_file_diff_affine, _, _ = atlas.get_atlas_nifti("HCP")
with pytest.raises(ValueError, match="is different from the input file affine"):
copyatlas = bids.CopyAtlas(
name_source=name_source,
in_file=atlas_file_diff_affine,
output_dir=tmpdir,
atlas="Y",
)
copyatlas.run(cwd=tmpdir)

# CIFTI
atlas_file, atlas_labels_file, atlas_metadata_file = atlas.get_atlas_cifti("Gordon")
name_source = "sub-01_task-imagery_run-01_space-fsLR_den-91k_desc-denoised_bold.dtseries.nii"
copyatlas = bids.CopyAtlas(
name_source=name_source, in_file=atlas_file, output_dir=tmpdir, atlas="Y"
)
result = copyatlas.run(cwd=tmpdir)
assert os.path.isfile(result.outputs.out_file)
assert (
os.path.basename(result.outputs.out_file) == "space-fsLR_atlas-Y_den-91k_dseg.dlabel.nii"
)

# TSV
name_source = "sub-01_task-imagery_run-01_space-fsLR_den-91k_desc-denoised_bold.dtseries.nii"
copyatlas = bids.CopyAtlas(
name_source=name_source, in_file=atlas_labels_file, output_dir=tmpdir, atlas="Y"
)
result = copyatlas.run(cwd=tmpdir)
assert os.path.isfile(result.outputs.out_file)
assert os.path.basename(result.outputs.out_file) == "atlas-Y_dseg.tsv"

# JSON
name_source = "sub-01_task-imagery_run-01_space-fsLR_den-91k_desc-denoised_bold.dtseries.nii"
copyatlas = bids.CopyAtlas(
name_source=name_source, in_file=atlas_metadata_file, output_dir=tmpdir, atlas="Y"
)
result = copyatlas.run(cwd=tmpdir)
assert os.path.isfile(result.outputs.out_file)
assert os.path.basename(result.outputs.out_file) == "atlas-Y_dseg.json"

# Ensure that out_file isn't overwritten if it already exists
fake_in_file = os.path.join(tmpdir, "fake.json")
with open(fake_in_file, "w") as fo:
fo.write("fake")

copyatlas = bids.CopyAtlas(
name_source=name_source, in_file=fake_in_file, output_dir=tmpdir, atlas="Y"
)
result = copyatlas.run(cwd=tmpdir)
assert os.path.isfile(result.outputs.out_file)
assert os.path.basename(result.outputs.out_file) == "atlas-Y_dseg.json"
# The file should not be overwritten, so the contents shouldn't be "fake"
with open(result.outputs.out_file, "r") as fo:
assert fo.read() != "fake"
54 changes: 0 additions & 54 deletions xcp_d/tests/test_utils_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,57 +45,3 @@ def test_get_atlas_cifti():

with pytest.raises(FileNotFoundError, match="DNE"):
atlas.get_atlas_cifti("tofail")


def test_copy_atlas(tmp_path_factory):
"""Test xcp_d.utils.atlas.copy_atlas."""
tmpdir = tmp_path_factory.mktemp("test_copy_atlas")
os.makedirs(os.path.join(tmpdir, "xcp_d"), exist_ok=True)

# NIfTI
atlas_file, _, _ = atlas.get_atlas_nifti("Gordon")
name_source = "sub-01_task-A_run-01_space-MNI152NLin2009cAsym_res-2_desc-z_bold.nii.gz"
out_file = atlas.copy_atlas(
name_source=name_source, in_file=atlas_file, output_dir=tmpdir, atlas="Y"
)
assert os.path.isfile(out_file)
assert os.path.basename(out_file) == "space-MNI152NLin2009cAsym_atlas-Y_res-2_dseg.nii.gz"

# CIFTI
atlas_file, atlas_labels_file, atlas_metadata_file = atlas.get_atlas_cifti("Gordon")
name_source = "sub-01_task-imagery_run-01_space-fsLR_den-91k_desc-denoised_bold.dtseries.nii"
out_file = atlas.copy_atlas(
name_source=name_source, in_file=atlas_file, output_dir=tmpdir, atlas="Y"
)
assert os.path.isfile(out_file)
assert os.path.basename(out_file) == "space-fsLR_atlas-Y_den-91k_dseg.dlabel.nii"

# TSV
name_source = "sub-01_task-imagery_run-01_space-fsLR_den-91k_desc-denoised_bold.dtseries.nii"
out_file = atlas.copy_atlas(
name_source=name_source, in_file=atlas_labels_file, output_dir=tmpdir, atlas="Y"
)
assert os.path.isfile(out_file)
assert os.path.basename(out_file) == "atlas-Y_dseg.tsv"

# JSON
name_source = "sub-01_task-imagery_run-01_space-fsLR_den-91k_desc-denoised_bold.dtseries.nii"
out_file = atlas.copy_atlas(
name_source=name_source, in_file=atlas_metadata_file, output_dir=tmpdir, atlas="Y"
)
assert os.path.isfile(out_file)
assert os.path.basename(out_file) == "atlas-Y_dseg.json"

# Ensure that out_file isn't overwritten if it already exists
fake_in_file = os.path.join(tmpdir, "fake.json")
with open(fake_in_file, "w") as fo:
fo.write("fake")

out_file = atlas.copy_atlas(
name_source=name_source, in_file=fake_in_file, output_dir=tmpdir, atlas="Y"
)
assert os.path.isfile(out_file)
assert os.path.basename(out_file) == "atlas-Y_dseg.json"
# The file should not be overwritten, so the contents shouldn't be "fake"
with open(out_file, "r") as fo:
assert fo.read() != "fake"
65 changes: 0 additions & 65 deletions xcp_d/utils/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,68 +155,3 @@ def get_atlas_cifti(atlas):
)

return atlas_file, atlas_labels_file, atlas_metadata_file


def copy_atlas(name_source, in_file, output_dir, atlas):
"""Copy atlas file to output directory.
Parameters
----------
name_source : :obj:`str`
The source name of the atlas file.
in_file : :obj:`str`
The atlas file to copy.
output_dir : :obj:`str`
The output directory.
atlas : :obj:`str`
The name of the atlas.
Returns
-------
out_file : :obj:`str`
The path to the copied atlas file.
Notes
-----
I can't use DerivativesDataSink because it has a problem with dlabel CIFTI files.
It gives the following error:
"AttributeError: 'Cifti2Header' object has no attribute 'set_data_dtype'"
I can't override the CIFTI atlas's data dtype ahead of time because setting it to int8 or int16
somehow converts all of the values in the data array to weird floats.
This could be a version-specific nibabel issue.
I've also updated this function to handle JSON and TSV files as well.
"""
import os
import shutil

from xcp_d.utils.bids import get_entity

if in_file.endswith(".json"):
out_basename = f"atlas-{atlas}_dseg.json"
elif in_file.endswith(".tsv"):
out_basename = f"atlas-{atlas}_dseg.tsv"
else:
extension = ".nii.gz" if name_source.endswith(".nii.gz") else ".dlabel.nii"
space = get_entity(name_source, "space")
res = get_entity(name_source, "res")
den = get_entity(name_source, "den")
cohort = get_entity(name_source, "cohort")

cohort_str = f"_cohort-{cohort}" if cohort else ""
res_str = f"_res-{res}" if res else ""
den_str = f"_den-{den}" if den else ""
if extension == ".dlabel.nii":
out_basename = f"space-{space}_atlas-{atlas}{den_str}{cohort_str}_dseg{extension}"
elif extension == ".nii.gz":
out_basename = f"space-{space}_atlas-{atlas}{res_str}{cohort_str}_dseg{extension}"

atlas_out_dir = os.path.join(output_dir, f"xcp_d/atlases/atlas-{atlas}")
os.makedirs(atlas_out_dir, exist_ok=True)
out_file = os.path.join(atlas_out_dir, out_basename)
# Don't copy the file if it exists, to prevent any race conditions between parallel processes.
if not os.path.isfile(out_file):
shutil.copyfile(in_file, out_file)

return out_file
9 changes: 5 additions & 4 deletions xcp_d/utils/bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,12 @@ def collect_data(
# This probably works well for resolution (1 typically means 1x1x1,
# 2 typically means 2x2x2, etc.), but probably doesn't work well for density.
resolutions = layout.get_res(**queries["bold"])
densities = layout.get_den(**queries["bold"])
if len(resolutions) > 1:
queries["bold"]["resolution"] = resolutions[0]
if len(resolutions) >= 1:
# This will also select res-* when there are both res-* and native-resolution files.
queries["bold"]["res"] = resolutions[0]

if len(densities) > 1:
densities = layout.get_den(**queries["bold"])
if len(densities) >= 1:
queries["bold"]["den"] = densities[0]

# Check for anatomical images, and determine if T2w xfms must be used.
Expand Down
Loading

0 comments on commit 9fe6603

Please sign in to comment.