Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft spectral analysis-based interpolation #1138

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"sentry-sdk ~= 2.15.0", # for usage reports
"templateflow ~= 24.2.0",
"toml",
"tqdm",
]
dynamic = ["version"]

Expand Down
11 changes: 11 additions & 0 deletions xcp_d/data/boilerplate.bib
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,17 @@ @article{lindquist2019modular
doi={10.1002/hbm.24528}
}

@article{mathias2004algorithms,
title={Algorithms for spectral analysis of irregularly sampled time series},
author={Mathias, Adolf and Grond, Florian and Guardans, Ramon and Seese, Detlef and Canela, Miguel and Diebner, Hans H},
journal={Journal of Statistical Software},
volume={11},
pages={1--27},
year={2004},
url={https://doi.org/10.18637/jss.v011.i02},
doi={10.18637/jss.v011.i02}
}

@article{hermosillo2022precision,
title={A precision functional atlas of network probabilities and individual-specific network topography},
author={Hermosillo, Robert JM and Moore, Lucille A and Fezcko, Eric and Dworetsky, Ally and Pines, Adam and Conan, Gregory and Mooney, Michael A and Randolph, Anita and Adeyemo, Babatunde and Earl, Eric and others},
Expand Down
177 changes: 145 additions & 32 deletions xcp_d/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def denoise_with_nilearn(


def _interpolate(*, arr, sample_mask, TR):
"""Replace high-motion volumes with cubic-spline interpolated values.
"""Replace high-motion volumes with a Lomb-Scargle periodogram-based interpolation method.

This function applies Nilearn's :func:`~nilearn.signal._interpolate_volumes` function,
followed by an extra step that replaces extrapolated, high-motion values at the beginning and
Expand All @@ -515,46 +515,159 @@ def _interpolate(*, arr, sample_mask, TR):
Notes
-----
This function won't work if sample_mask is all zeros, but that should never happen.

The function uses the least squares spectral analysis method described in
:footcite:t:`power_fd_dvars`, which in turn is based on the method described in
:footcite:t:`mathias2004algorithms`.

References
----------
.. footbibliography::
"""
from nilearn import signal
from tqdm import trange

outlier_idx = list(np.where(~sample_mask)[0])
n_volumes = arr.shape[0]
from xcp_d.utils.utils import get_transform

interpolated_arr = signal._interpolate_volumes(
arr,
sample_mask=sample_mask,
t_r=TR,
extrapolate=True,
)
# Replace any high-motion volumes at the beginning or end of the run with the closest
# low-motion volume's data.
# Use https://stackoverflow.com/a/48106843/2589328 to group consecutive blocks of outliers.
gaps = [[start, end] for start, end in zip(outlier_idx, outlier_idx[1:]) if start + 1 < end]
edges = iter(outlier_idx[:1] + sum(gaps, []) + outlier_idx[-1:])
consecutive_outliers_idx = list(zip(edges, edges))
first_outliers = consecutive_outliers_idx[0]
last_outliers = consecutive_outliers_idx[-1]

# Replace outliers at beginning of run
if first_outliers[0] == 0:
LOGGER.warning(
f"Outlier volumes at beginning of run ({first_outliers[0]}-{first_outliers[1]}) "
"will be replaced with first non-outlier volume's values."
n_volumes = arr.shape[0]
time = np.arange(0, n_volumes * TR, TR)
censored_time = time[sample_mask]
n_voxels = arr.shape[1]

interpolated_arr = arr.copy()
for i_voxel in trange(n_voxels, desc="Interpolating high-motion volumes"):
voxel_data = arr[:, i_voxel]
interpolated_voxel_data = get_transform(
censored_time=censored_time,
arr=voxel_data[sample_mask],
uncensored_time=time,
oversampling_factor=4,
TR=TR,
)
interpolated_arr[: first_outliers[1] + 1, :] = interpolated_arr[first_outliers[1] + 1, :]

# Replace outliers at end of run
if last_outliers[1] == n_volumes - 1:
LOGGER.warning(
f"Outlier volumes at end of run ({last_outliers[0]}-{last_outliers[1]}) "
"will be replaced with last non-outlier volume's values."
)
interpolated_arr[last_outliers[0] :, :] = interpolated_arr[last_outliers[0] - 1, :]
# Replace high-motion volumes in interpolated array with the modified data
interpolated_arr[~sample_mask, i_voxel] = interpolated_voxel_data[~sample_mask, 0]

return interpolated_arr


def get_transform(*, censored_time, arr, uncensored_time, oversampling_factor, TR):
    """Interpolate high-motion volumes in a time series using least squares spectral analysis.

    Parameters
    ----------
    censored_time : ndarray of shape (C,)
        Time points for which observations are present.
        C = number of low-motion time points
    arr : ndarray of shape (C,)
        Observations at the retained (low-motion) time points.
        C = number of low-motion time points
    uncensored_time : ndarray of shape (T,)
        Time points for which to reconstruct the original time series.
        T = total number of time points
    oversampling_factor : int
        Oversampling frequency, generally >= 4.
    TR : float
        Sampling interval (repetition time) of the time series, in seconds.

    Returns
    -------
    reconstructed_arr : ndarray of shape (T, 1)
        The reconstructed time series, as a single-column array.
        T = number of time points in uncensored_time

    Raises
    ------
    ValueError
        If the input time series has zero variance, in which case the
        reconstruction cannot be normalized.

    Notes
    -----
    This function is translated from Anish Mitra's MATLAB function ``getTransform``,
    available at https://www.jonathanpower.net/2014-ni-motion-2.html.

    The function implements the least squares spectral analysis method described in
    :footcite:t:`power_fd_dvars`, which in turn is based on the method described in
    :footcite:t:`mathias2004algorithms`.

    References
    ----------
    .. footbibliography::
    """
    import numpy as np

    assert arr.ndim == 1
    assert censored_time.ndim == 1
    assert uncensored_time.ndim == 1
    assert arr.shape[0] == censored_time.shape[0]
    assert uncensored_time.shape[0] > censored_time.shape[0]

    # Work with a column vector so the frequency-by-time-by-voxel broadcasting below
    # generalizes, even though only one voxel's time series is passed in.
    arr = arr[:, None]
    fs = 1 / TR
    n_volumes = arr.shape[0]  # number of retained (low-motion) time points
    n_voxels = arr.shape[1]  # always 1 given the 1D input; kept for clarity
    time_span = np.max(censored_time) - np.min(censored_time)  # total time span
    n_oversampled_timepoints = int((time_span / TR) * oversampling_factor)

    # Candidate frequencies for the fit, from one cycle per (oversampled) span
    # up to the Nyquist frequency.
    max_freq = 0.5 * fs
    frequencies_hz = np.linspace(
        1 / (time_span * oversampling_factor),
        max_freq,
        n_oversampled_timepoints,
    )

    # Angular frequencies and per-frequency phase offsets (the Lomb-Scargle "tau"),
    # which decouple the sine and cosine terms over the irregular sample times.
    frequencies_angular = 2 * np.pi * frequencies_hz
    freq_by_time = frequencies_angular[:, None] * censored_time[None, :]
    offsets = np.arctan2(
        np.sum(np.sin(2 * freq_by_time), axis=1),
        np.sum(np.cos(2 * freq_by_time), axis=1),
    ) / (2 * frequencies_angular)

    # Phase-shifted arguments, then the sine and cosine basis terms.
    spectral_power = freq_by_time - (frequencies_angular[:, None] * offsets[:, None])
    cterm = np.cos(spectral_power)
    sterm = np.sin(spectral_power)

    D = arr.reshape((1, n_volumes, n_voxels))

    # Least squares amplitude for the cosine basis at each frequency:
    # sum(cos * data) / sum(cos ** 2). Computed separately as numerator,
    # denominator, and the division.
    cos_mult = cterm[:, :, None] * D
    numerator = np.sum(cos_mult, axis=1)
    denominator = np.sum(cterm**2, axis=1)[:, None]
    power_cos = numerator / denominator

    # Repeat the above for the sine term.
    sin_mult = sterm[:, :, None] * D
    numerator = np.sum(sin_mult, axis=1)
    denominator = np.sum(sterm**2, axis=1)[:, None]
    power_sin = numerator / denominator

    # The inverse transform: reconstruct the series at the requested (uncensored)
    # time points by summing the fitted sinusoids across all frequencies.
    prod = uncensored_time[None, :, None] * frequencies_angular[:, None, None]
    sin_t = np.sin(prod)
    cos_t = np.cos(prod)
    sw_p = sin_t * power_sin[:, None, :]
    cw_p = cos_t * power_cos[:, None, :]
    S = np.sum(sw_p, axis=0)
    C = np.sum(cw_p, axis=0)
    reconstructed_arr = C + S
    reconstructed_arr = reconstructed_arr.reshape((len(uncensored_time), arr.shape[1]))

    # Normalize the reconstructed spectrum; needed when oversampling_factor > 1.
    # NOTE: the MATLAB port previously tried `with warnings.filterwarnings("error")`,
    # but `filterwarnings` returns None and is not a context manager, so that code
    # raised AttributeError at runtime. Check the denominator explicitly instead.
    std_reconstructed = np.std(reconstructed_arr, axis=0)
    std_original = np.std(arr, axis=0)
    if np.any(std_original == 0):
        raise ValueError("Cannot normalize reconstruction: input time series has zero variance.")

    norm_fac = std_reconstructed / std_original
    reconstructed_arr = reconstructed_arr / norm_fac[None, :]

    return reconstructed_arr


def _select_first(lst):
"""Select the first element in a list."""
return lst[0]
Expand Down
Loading