From e8a27ed0d184dc3b8cbaff752fa8cf9351856144 Mon Sep 17 00:00:00 2001 From: plaresmedima Date: Tue, 1 Oct 2024 22:45:46 +0100 Subject: [PATCH] Add mdr keyword to T1() class --- ukat/mapping/t1.py | 79 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 6 deletions(-) diff --git a/ukat/mapping/t1.py b/ukat/mapping/t1.py index 0fb63f4..0bf018c 100644 --- a/ukat/mapping/t1.py +++ b/ukat/mapping/t1.py @@ -3,6 +3,8 @@ import os import warnings +import mdreg + from . import fitting @@ -148,8 +150,11 @@ class T1: apart from TI """ + #@Alex: suggestion: make affine a keyword parameter with default np.eye(4). + # As it is the last argument in the list this will not break existing code + # And it means the user is not forced to provide a dummy affine when it plays no role. def __init__(self, pixel_array, inversion_list, affine, tss=0, tss_axis=-2, - mask=None, parameters=2, molli=False, multithread=True): + mask=None, parameters=2, molli=False, multithread=True, mdr=False): """Initialise a T1 class instance. Parameters @@ -201,12 +206,71 @@ def __init__(self, pixel_array, inversion_list, affine, tss=0, tss_axis=-2, increase in speed distributing the calculation would generate. 'auto' attempts to apply multithreading where appropriate based on the number of voxels being fit. + mdr : bool, optional + Default 'False` + If True, this performs a motion correction with model-driven + registration before performing the final fit to the model function. """ + assert multithread is True \ or multithread is False \ or multithread == 'auto', f'multithreaded must be True,' \ f'False or auto. You entered ' \ f'{multithread}' + + # @Alex: I have moved this up so multithreading + # settings can be reused in mdreg, which requires a True or False + # value. In this case (elastix) it is actually unnecessary as + # parallelization is not a real option anyway. But it would be relevant if + # wanted to run with skimage, for instamce. + if multithread == 'auto': + npixels = np.prod(pixel_array.shape[:-1]) + if npixels > 20: + multithread = True + else: + multithread = False + + if mdr: + pixel_array, deform, _, _ = mdreg.fit( + pixel_array, + fit_image = { + 'func': _T1_fit, + 'inversion_list': inversion_list, + 'affine': affine, + 'tss': tss, + 'tss_axis': tss_axis, + 'mask': mask, + 'parameters': parameters, + 'molli': molli, + 'multithread': multithread, + }, + # @Alex: These coreg settings are default so technically speaking do not have + # to be specified here. I am leaving it in nevertheless as a template in case we want + # to explore alternative coregistration settings later. The fit_coreg + # dictionary could also be exposed as a keyword argument in T1.__init__() + # in case we want to give the user the option to modify the detail. + # As it stands the user has no way of modifying the way mdr runs, eg. + # change verbosity level, stopping criteria or coreg options. + # I didn't change it as that may exactly be the intention of ukat? + fit_coreg = { + 'package': 'elastix', + 'parallel': False, # elastix is not parallelizable + } + ) + # @Alex: At the moment the order of the dimensions in the deformation field returned by mdreg is awkward. + # Current dimensions are (x,y,2,t) for 2-dimensional pixel_arrays, + # and (x,y,z,3,t) for 3 dimensional. Better would be (x,y,t,2) and (x,y,z,t,3) + # - i.e add a new dimension at the end. + # We will probable change this in mdreg at some point so I put in an ad-hoc + # reordering at this stage, We can just take it out later if mdreg is updated. + self.deformation_field = np.swapaxes(deform, -2, -1) + # @Alex: Hack to avoid magnitude corrected model being selected + # This needs a better solution as this prevents a magnitude corrected mode + pixel_array = np.abs(pixel_array) + else: + # @Alex: is this the expected default? + self.deformation_field = None + self.pixel_array = pixel_array self.shape = pixel_array.shape[:-1] @@ -230,11 +294,6 @@ def __init__(self, pixel_array, inversion_list, affine, tss=0, tss_axis=-2, self.tss = 0 self.parameters = parameters self.molli = molli - if multithread == 'auto': - if self.n_vox > 20: - multithread = True - else: - multithread = False self.multithread = multithread # Some sanity checks @@ -556,3 +615,11 @@ def magnitude_correct(pixel_array): sign = -(phase_offset / np.abs(phase_offset)) corrected_array = sign * np.abs(pixel_array) return corrected_array + + +# Private wrapper for use by mdreg +def _T1_fit(pixel_array, inversion_list=None, affine=None, **kwargs): + # Alex: added the abs() here to avoid the bug with model selection + # This needs a better solution as this prevents a magnitude corrected mode + map = T1(np.abs(pixel_array), inversion_list, affine, **kwargs) + return map.get_fit_signal(), None