diff --git a/pyproject.toml b/pyproject.toml
index e11d98077..9cdbdce0c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
     "sentry-sdk ~= 2.15.0",  # for usage reports
     "templateflow ~= 24.2.0",
     "toml",
+    "tqdm",
 ]
 dynamic = ["version"]
diff --git a/xcp_d/data/boilerplate.bib b/xcp_d/data/boilerplate.bib
index 40d1e8b12..d446bfb53 100644
--- a/xcp_d/data/boilerplate.bib
+++ b/xcp_d/data/boilerplate.bib
@@ -810,6 +810,17 @@ @article{lindquist2019modular
   doi={10.1002/hbm.24528}
 }
 
+@article{mathias2004algorithms,
+  title={Algorithms for spectral analysis of irregularly sampled time series},
+  author={Mathias, Adolf and Grond, Florian and Guardans, Ramon and Seese, Detlef and Canela, Miguel and Diebner, Hans H},
+  journal={Journal of Statistical Software},
+  volume={11},
+  pages={1--27},
+  year={2004},
+  url={https://doi.org/10.18637/jss.v011.i02},
+  doi={10.18637/jss.v011.i02}
+}
+
 @article{hermosillo2022precision,
   title={A precision functional atlas of network probabilities and individual-specific network topography},
   author={Hermosillo, Robert JM and Moore, Lucille A and Fezcko, Eric and Dworetsky, Ally and Pines, Adam and Conan, Gregory and Mooney, Michael A and Randolph, Anita and Adeyemo, Babatunde and Earl, Eric and others},
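Before the `xcp_d/utils/utils.py` changes below, a quick orientation to the math they implement. For each angular frequency on an oversampled grid, `get_transform` fits a phase offset and a pair of cosine/sine amplitudes to the retained time points of the censored data, then sums the fitted components to reconstruct the series at every time point. A sketch in my own notation (these symbols are mine, not the patch's):

```latex
% Least squares spectral analysis as implemented in get_transform below.
% tau_omega -> offsets; a_omega, b_omega -> power_cos, power_sin; dhat -> reconstructed_arr
\begin{align}
  \tau_\omega &= \frac{1}{2\omega}
    \arctan\!\left(\frac{\sum_i \sin(2\omega t_i)}{\sum_i \cos(2\omega t_i)}\right) \\
  a_\omega &= \frac{\sum_i d(t_i)\,\cos\!\big(\omega(t_i - \tau_\omega)\big)}
                   {\sum_i \cos^2\!\big(\omega(t_i - \tau_\omega)\big)}, \qquad
  b_\omega  = \frac{\sum_i d(t_i)\,\sin\!\big(\omega(t_i - \tau_\omega)\big)}
                   {\sum_i \sin^2\!\big(\omega(t_i - \tau_\omega)\big)} \\
  \hat{d}(t) &= \sum_\omega \big[ a_\omega \cos(\omega t) + b_\omega \sin(\omega t) \big]
\end{align}
```

The reconstruction is finally rescaled by the ratio of the input and output standard deviations (`norm_fac` in the code). Note that the phase offset enters the fitted amplitudes but not the inverse transform, mirroring the MATLAB code the function is translated from.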
diff --git a/xcp_d/utils/utils.py b/xcp_d/utils/utils.py
index 0446a7a9c..07e1e131c 100644
--- a/xcp_d/utils/utils.py
+++ b/xcp_d/utils/utils.py
@@ -492,7 +492,7 @@ def denoise_with_nilearn(
 
 def _interpolate(*, arr, sample_mask, TR):
-    """Replace high-motion volumes with cubic-spline interpolated values.
+    """Replace high-motion volumes with Lomb-Scargle periodogram-interpolated values.
 
-    This function applies Nilearn's :func:`~nilearn.signal._interpolate_volumes` function,
-    followed by an extra step that replaces extrapolated, high-motion values at the beginning and
+    This function fits a least squares spectral model to each voxel's low-motion time points
+    and reconstructs values for the high-motion time points from the fitted spectrum.
@@ -515,46 +515,159 @@ def _interpolate(*, arr, sample_mask, TR):
     Notes
     -----
     This function won't work if sample_mask is all zeros, but that should never happen.
+
+    The function uses the least squares spectral analysis method described in
+    :footcite:t:`power_fd_dvars`, which in turn is based on the method described in
+    :footcite:t:`mathias2004algorithms`.
+
+    References
+    ----------
+    .. footbibliography::
     """
-    from nilearn import signal
+    from tqdm import trange
 
-    outlier_idx = list(np.where(~sample_mask)[0])
-    n_volumes = arr.shape[0]
+    from xcp_d.utils.utils import get_transform
 
-    interpolated_arr = signal._interpolate_volumes(
-        arr,
-        sample_mask=sample_mask,
-        t_r=TR,
-        extrapolate=True,
-    )
-    # Replace any high-motion volumes at the beginning or end of the run with the closest
-    # low-motion volume's data.
-    # Use https://stackoverflow.com/a/48106843/2589328 to group consecutive blocks of outliers.
-    gaps = [[start, end] for start, end in zip(outlier_idx, outlier_idx[1:]) if start + 1 < end]
-    edges = iter(outlier_idx[:1] + sum(gaps, []) + outlier_idx[-1:])
-    consecutive_outliers_idx = list(zip(edges, edges))
-    first_outliers = consecutive_outliers_idx[0]
-    last_outliers = consecutive_outliers_idx[-1]
-
-    # Replace outliers at beginning of run
-    if first_outliers[0] == 0:
-        LOGGER.warning(
-            f"Outlier volumes at beginning of run ({first_outliers[0]}-{first_outliers[1]}) "
-            "will be replaced with first non-outlier volume's values."
+    n_volumes = arr.shape[0]
+    time = np.arange(0, n_volumes * TR, TR)
+    censored_time = time[sample_mask]
+    n_voxels = arr.shape[1]
+
+    interpolated_arr = arr.copy()
+    for i_voxel in trange(n_voxels, desc="Interpolating high-motion volumes"):
+        voxel_data = arr[:, i_voxel]
+        interpolated_voxel_data = get_transform(
+            censored_time=censored_time,
+            arr=voxel_data[sample_mask],
+            uncensored_time=time,
+            oversampling_factor=4,
+            TR=TR,
         )
-        interpolated_arr[: first_outliers[1] + 1, :] = interpolated_arr[first_outliers[1] + 1, :]
 
-    # Replace outliers at end of run
-    if last_outliers[1] == n_volumes - 1:
-        LOGGER.warning(
-            f"Outlier volumes at end of run ({last_outliers[0]}-{last_outliers[1]}) "
-            "will be replaced with last non-outlier volume's values."
-        )
-        interpolated_arr[last_outliers[0] :, :] = interpolated_arr[last_outliers[0] - 1, :]
+        # Replace the high-motion volumes in the output array with the reconstructed values
+        interpolated_arr[~sample_mask, i_voxel] = interpolated_voxel_data[~sample_mask, 0]
 
     return interpolated_arr
 
 
+def get_transform(*, censored_time, arr, uncensored_time, oversampling_factor, TR):
+    """Reconstruct a censored time series using least squares spectral analysis.
+
+    Parameters
+    ----------
+    censored_time : ndarray of shape (C,)
+        Time points for which observations are present.
+        C = number of low-motion time points
+    arr : ndarray of shape (C,)
+        Observations at the low-motion time points.
+        C = number of low-motion time points
+    uncensored_time : ndarray of shape (T,)
+        Time points for which to reconstruct the original time series.
+        T = total number of time points
+    oversampling_factor : int
+        Oversampling factor for the frequency grid, generally >= 4.
+    TR : float
+        Repetition time of the time series, in seconds.
+
+    Returns
+    -------
+    reconstructed_arr : ndarray of shape (T, 1)
+        The reconstructed time series, as a column vector.
+        T = number of time points in uncensored_time
+
+    Notes
+    -----
+    This function is translated from Anish Mitra's MATLAB function ``getTransform``,
+    available at https://www.jonathanpower.net/2014-ni-motion-2.html.
+
+    The function implements the least squares spectral analysis method described in
+    :footcite:t:`power_fd_dvars`, which in turn is based on the method described in
+    :footcite:t:`mathias2004algorithms`.
+
+    References
+    ----------
+    .. footbibliography::
+    """
+    import warnings
+
+    import numpy as np
+
+    assert arr.ndim == 1
+    assert censored_time.ndim == 1
+    assert uncensored_time.ndim == 1
+    assert arr.shape[0] == censored_time.shape[0]
+    assert uncensored_time.shape[0] > censored_time.shape[0]
+
+    arr = arr[:, None]
+    fs = 1 / TR  # Sampling frequency in Hz
+    n_volumes = arr.shape[0]  # Number of time points in the censored array
+    n_voxels = arr.shape[1]  # Number of voxels (always 1 after the reshape above)
+    time_span = np.max(censored_time) - np.min(censored_time)  # Total time span in seconds
+    n_oversampled_timepoints = int((time_span / TR) * oversampling_factor)
+
+    # Calculate the oversampled frequency grid, up to the Nyquist frequency
+    max_freq = 0.5 * fs
+    frequencies_hz = np.linspace(
+        1 / (time_span * oversampling_factor),
+        max_freq,
+        n_oversampled_timepoints,
+    )
+
+    # Angular frequencies and constant phase offsets (tau in the Lomb-Scargle formulation)
+    frequencies_angular = 2 * np.pi * frequencies_hz
+    offsets = np.arctan2(
+        np.sum(np.sin(2 * np.dot(frequencies_angular[:, None], censored_time[None, :])), axis=1),
+        np.sum(np.cos(2 * np.dot(frequencies_angular[:, None], censored_time[None, :])), axis=1),
+    ) / (2 * frequencies_angular)
+
+    # Spectral power sine and cosine terms
+    spectral_power = np.dot(frequencies_angular[:, None], censored_time[None, :]) - (
+        frequencies_angular[:, None] * offsets[:, None]
+    )
+    cterm = np.cos(spectral_power)
+    sterm = np.sin(spectral_power)
+
+    D = arr.copy()
+    D = D.reshape((1, n_volumes, n_voxels))
+
+    # This calculation is done separately for the numerator and denominator, then combined
+    cos_mult = cterm[:, :, None] * D
+    numerator = np.sum(cos_mult, axis=1)
+    denominator = np.sum(cterm**2, axis=1)[:, None]
+    power_cos = numerator / denominator
+
+    # Repeat the above for the sine term
+    sin_mult = sterm[:, :, None] * D
+    numerator = np.sum(sin_mult, axis=1)
+    denominator = np.sum(sterm**2, axis=1)[:, None]
+    power_sin = numerator / denominator
+
+    # The inverse function, to reconstruct the original time series
+    T_rep = np.repeat(uncensored_time[None, :], repeats=len(frequencies_hz), axis=0)[:, :, None]
+    prod = T_rep * frequencies_angular[:, None, None]
+    sin_t = np.sin(prod)
+    cos_t = np.cos(prod)
+    sw_p = sin_t * power_sin[:, None, :]
+    cw_p = cos_t * power_cos[:, None, :]
+    S = np.sum(sw_p, axis=0)
+    C = np.sum(cw_p, axis=0)
+    reconstructed_arr = C + S
+    reconstructed_arr = reconstructed_arr.reshape((len(uncensored_time), arr.shape[1]))
+
+    # Normalize the reconstructed spectrum. This is needed when oversampling_factor > 1.
+    Std_H = np.std(reconstructed_arr, axis=0)
+    Std_h = np.std(arr, axis=0)
+    with warnings.catch_warnings():
+        # A zero standard deviation in the censored data would make the normalization
+        # factor undefined, so escalate the divide-by-zero warning to an error
+        warnings.simplefilter("error", category=RuntimeWarning)
+        try:
+            norm_fac = Std_H / Std_h
+        except RuntimeWarning as exc:
+            raise ValueError("The censored time series has zero standard deviation.") from exc
+
+    reconstructed_arr = reconstructed_arr / norm_fac[None, :]
+
+    return reconstructed_arr
+
+
 def _select_first(lst):
     """Select the first element in a list."""
     return lst[0]
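For anyone who wants to exercise the new path end to end, a minimal sketch along these lines runs against a checkout with this patch applied. It is illustration only: the TR, signal frequency, noise level, and censoring pattern are arbitrary choices, not values used by xcp_d.

```python
import numpy as np

from xcp_d.utils.utils import _interpolate  # rewritten by this patch

TR = 2.0
n_volumes = 120
time = np.arange(n_volumes) * TR
rng = np.random.default_rng(42)

# One "voxel": a slow 0.02 Hz oscillation plus noise, shaped (n_volumes, n_voxels)
true_signal = np.sin(2 * np.pi * 0.02 * time)
arr = true_signal[:, None] + 0.05 * rng.standard_normal((n_volumes, 1))

# Censor ~20% of interior volumes (False = high-motion, to be interpolated)
sample_mask = np.ones(n_volumes, dtype=bool)
sample_mask[rng.choice(np.arange(1, n_volumes - 1), size=24, replace=False)] = False

interpolated = _interpolate(arr=arr, sample_mask=sample_mask, TR=TR)

# Low-motion volumes pass through unchanged; only the censored rows are overwritten
assert np.array_equal(interpolated[sample_mask], arr[sample_mask])

# The reconstructed values at the censored volumes should track the true oscillation
r = np.corrcoef(interpolated[~sample_mask, 0], true_signal[~sample_mask])[0, 1]
print(f"Correlation at censored volumes: {r:.3f}")
```

Keeping the first and last volumes uncensored here sidesteps run-boundary effects; with this patch there is no longer any special edge handling, so boundary volumes are reconstructed from the same spectral fit as interior ones.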