Skip to content

Commit

Permalink
Add mdr keyword to T1() class
Browse files Browse the repository at this point in the history
  • Loading branch information
plaresmedima committed Oct 1, 2024
1 parent 8ea0c18 commit e8a27ed
Showing 1 changed file with 73 additions and 6 deletions.
79 changes: 73 additions & 6 deletions ukat/mapping/t1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import warnings

import mdreg

from . import fitting


Expand Down Expand Up @@ -148,8 +150,11 @@ class T1:
apart from TI
"""

#@Alex: suggestion: make affine a keyword parameter with default np.eye(4).
# As it is the last argument in the list this will not break existing code
# And it means the user is not forced to provide a dummy affine when it plays no role.
def __init__(self, pixel_array, inversion_list, affine, tss=0, tss_axis=-2,
mask=None, parameters=2, molli=False, multithread=True):
mask=None, parameters=2, molli=False, multithread=True, mdr=False):
"""Initialise a T1 class instance.
Parameters
Expand Down Expand Up @@ -201,12 +206,71 @@ def __init__(self, pixel_array, inversion_list, affine, tss=0, tss_axis=-2,
increase in speed distributing the calculation would generate.
'auto' attempts to apply multithreading where appropriate based
on the number of voxels being fit.
mdr : bool, optional
Default 'False`
If True, this performs a motion correction with model-driven
registration before performing the final fit to the model function.
"""

assert multithread is True \
or multithread is False \
or multithread == 'auto', f'multithreaded must be True,' \
f'False or auto. You entered ' \
f'{multithread}'

# @Alex: I have moved this up so multithreading
# settings can be reused in mdreg, which requires a True or False
# value. In this case (elastix) it is actually unnecessary as
# parallelization is not a real option anyway. But it would be relevant if
# wanted to run with skimage, for instamce.
if multithread == 'auto':
npixels = np.prod(pixel_array.shape[:-1])
if npixels > 20:
multithread = True
else:
multithread = False

if mdr:
pixel_array, deform, _, _ = mdreg.fit(
pixel_array,
fit_image = {
'func': _T1_fit,
'inversion_list': inversion_list,
'affine': affine,
'tss': tss,
'tss_axis': tss_axis,
'mask': mask,
'parameters': parameters,
'molli': molli,
'multithread': multithread,
},
# @Alex: These coreg settings are default so technically speaking do not have
# to be specified here. I am leaving it in nevertheless as a template in case we want
# to explore alternative coregistration settings later. The fit_coreg
# dictionary could also be exposed as a keyword argument in T1.__init__()
# in case we want to give the user the option to modify the detail.
# As it stands the user has no way of modifying the way mdr runs, eg.
# change verbosity level, stopping criteria or coreg options.
# I didn't change it as that may exactly be the intention of ukat?
fit_coreg = {
'package': 'elastix',
'parallel': False, # elastix is not parallelizable
}
)
# @Alex: At the moment the order of the dimensions in the deformation field returned by mdreg is awkward.
# Current dimensions are (x,y,2,t) for 2-dimensional pixel_arrays,
# and (x,y,z,3,t) for 3 dimensional. Better would be (x,y,t,2) and (x,y,z,t,3)
# - i.e add a new dimension at the end.
# We will probable change this in mdreg at some point so I put in an ad-hoc
# reordering at this stage, We can just take it out later if mdreg is updated.
self.deformation_field = np.swapaxes(deform, -2, -1)
# @Alex: Hack to avoid magnitude corrected model being selected
# This needs a better solution as this prevents a magnitude corrected mode
pixel_array = np.abs(pixel_array)
else:
# @Alex: is this the expected default?
self.deformation_field = None


self.pixel_array = pixel_array
self.shape = pixel_array.shape[:-1]
Expand All @@ -230,11 +294,6 @@ def __init__(self, pixel_array, inversion_list, affine, tss=0, tss_axis=-2,
self.tss = 0
self.parameters = parameters
self.molli = molli
if multithread == 'auto':
if self.n_vox > 20:
multithread = True
else:
multithread = False
self.multithread = multithread

# Some sanity checks
Expand Down Expand Up @@ -556,3 +615,11 @@ def magnitude_correct(pixel_array):
sign = -(phase_offset / np.abs(phase_offset))
corrected_array = sign * np.abs(pixel_array)
return corrected_array


# Private wrapper for use by mdreg
def _T1_fit(pixel_array, inversion_list=None, affine=None, **kwargs):
# Alex: added the abs() here to avoid the bug with model selection
# This needs a better solution as this prevents a magnitude corrected mode
map = T1(np.abs(pixel_array), inversion_list, affine, **kwargs)
return map.get_fit_signal(), None

0 comments on commit e8a27ed

Please sign in to comment.