Skip to content

Commit

Permalink
Merge pull request #57 from pnlbwh/fit-shm-ref
Browse files Browse the repository at this point in the history
add function for fitting shm on ref site
  • Loading branch information
tashrifbillah authored Feb 27, 2020
2 parents 3187882 + 2b1f345 commit 8dfb005
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 56 deletions.
67 changes: 37 additions & 30 deletions lib/harmonization.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ def createTemplate(self):
# createTemplate steps -----------------------------------------------------------------------------------------

# read image lists
refImgs, refMasks= common_processing(self.ref_csv)
refImgs, refMasks= common_processing(self.ref_unproc_csv)
if not self.ref_csv.endswith('.modified'):
self.ref_csv += '.modified'
# debug: use the following line to omit processing again
# refImgs, refMasks = read_imgs_masks(self.ref_csv)

targetImgs, targetMasks= common_processing(self.target_csv)
targetImgs, targetMasks= common_processing(self.tar_unproc_csv)
if not self.target_csv.endswith('.modified'):
self.target_csv += '.modified'
# debug: use the following line to omit processing again
Expand Down Expand Up @@ -274,8 +274,8 @@ def createTemplate(self):

def harmonizeData(self):

from reconstSignal import reconst
from preprocess import dti_harm
from reconstSignal import reconst, approx
from preprocess import dti_harm, common_processing, preprocessing

# check the templatePath
if not exists(self.templatePath):
Expand All @@ -284,56 +284,64 @@ def harmonizeData(self):
if not listdir(self.templatePath):
raise ValueError(f'{self.templatePath} is empty')

# go through each file listed in csv, check their existence, create dti and harm directories
check_csv(self.target_csv, self.force)


if self.debug:
            # calculate diffusion measures of target site before any processing so we are able to compare
# with the ones after harmonization
imgs, masks= read_imgs_masks(self.tar_unproc_csv)
# fit spherical harmonics on reference site
if self.debug and self.ref_csv:
check_csv(self.ref_unproc_csv, self.force)
refImgs, refMasks= read_imgs_masks(self.ref_unproc_csv)
res= []
pool = multiprocessing.Pool(self.N_proc)
for imgPath, maskPath in zip(imgs, masks):
imgPath= convertedPath(imgPath)
maskPath= convertedPath(maskPath)
pool.apply_async(func= dti_harm, args= (imgPath,maskPath,))
for imgPath, maskPath in zip(refImgs, refMasks):
res.append(pool.apply_async(func=preprocessing, args=(imgPath, maskPath)))

attributes = [r.get() for r in res]

pool.close()
pool.join()

for i in range(len(refImgs)):
refImgs[i] = attributes[i][0]
refMasks[i] = attributes[i][1]

pool = multiprocessing.Pool(self.N_proc)
for imgPath, maskPath in zip(refImgs, refMasks):
pool.apply_async(func= approx, args=(imgPath,maskPath,))

pool.close()
pool.join()



# go through each file listed in csv, check their existence, create dti and harm directories
check_csv(self.target_csv, self.force)
targetImgs, targetMasks= common_processing(self.tar_unproc_csv)


# reconstSignal steps ------------------------------------------------------------------------------------------

# read target image list
moving= pjoin(self.templatePath, f'Mean_{self.target}_FA.nii.gz')
imgs, masks= read_imgs_masks(self.tar_unproc_csv)


fm= None

if not self.target_csv.endswith('.modified'):
self.target_csv += '.modified'
fm = open(self.target_csv, 'w')


self.harm_csv= self.target_csv+'.harmonized'
fh= open(self.harm_csv, 'w')
pool = multiprocessing.Pool(self.N_proc)
res= []
for imgPath, maskPath in zip(imgs, masks):
for imgPath, maskPath in zip(targetImgs, targetMasks):
res.append(pool.apply_async(func= reconst, args= (imgPath, maskPath, moving, self.templatePath,)))

for r in res:
imgPath, maskPath, harmImg, harmMask= r.get()

if isinstance(fm, io.IOBase):
fm.write(imgPath + ',' + maskPath + '\n')
harmImg, harmMask= r.get()
fh.write(harmImg + ',' + harmMask + '\n')


pool.close()
pool.join()

if isinstance(fm, io.IOBase):
fm.close()
fh.close()


Expand Down Expand Up @@ -444,10 +452,9 @@ def main(self):
if self.N_proc==-1:
self.N_proc= N_CPU

if self.target_csv.endswith('.modified'):
self.tar_unproc_csv= str(self.target_csv).split('.modified')[0]
else:
self.tar_unproc_csv= str(self.target_csv)
if self.ref_csv:
self.ref_unproc_csv= self.ref_csv.strip('.modified')
self.tar_unproc_csv= self.target_csv.strip('.modified')

if not self.stats:
# check appropriateness of N_shm
Expand Down
50 changes: 27 additions & 23 deletions lib/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def nrrd2nifti(imgPath):
return imgPath

nifti_write(imgPath, niftiImgPrefix)

return niftiImgPrefix+'.nii.gz'


Expand Down Expand Up @@ -114,10 +115,6 @@ def preprocessing(imgPath, maskPath):
copyfile(inPrefix + '.bval', outPrefix + '.bval')

maskPath= maskPath

if debug:
dti_harm(outPrefix+'.nii.gz', maskPath)

imgPath= outPrefix+'.nii.gz'


Expand All @@ -135,15 +132,15 @@ def preprocessing(imgPath, maskPath):
write_bvals(outPrefix + '.bval', bvals)

maskPath= maskPath

if debug:
dti_harm(outPrefix+'.nii.gz', maskPath)

imgPath= outPrefix+'.nii.gz'


try:
sp_high = np.array([float(i) for i in resample.split('x')])
except:
sp_high = lowResImgHdr['pixdim'][1:4]

# modifies data, mask, and headers
sp_high = np.array([float(i) for i in resample.split('x')])
if resample and (abs(sp_high-lowResImgHdr['pixdim'][1:4])>10e-3).any():
inPrefix = imgPath.split('.nii')[0]
outPrefix = inPrefix + '_resampled'
Expand All @@ -157,20 +154,26 @@ def preprocessing(imgPath, maskPath):
else:
maskPath= maskPath.split('.nii')[0]+ '_resampled.nii.gz'

if debug:
dti_harm(outPrefix+'.nii.gz', maskPath)

imgPath= outPrefix+'.nii.gz'


return (imgPath, maskPath)



def common_processing(caselist):

imgs, masks = read_caselist(caselist)

# to avoid MemoryError, decouple preprocessing (spm_bspline) and dti_harm (rish)

# compute dti_harm of unprocessed data
pool = multiprocessing.Pool(N_proc)
for imgPath,maskPath in zip(imgs,masks):
pool.apply_async(func= dti_harm, args= (imgPath,maskPath))
pool.close()
pool.join()


# preprocess data
res=[]
pool = multiprocessing.Pool(N_proc)
for imgPath,maskPath in zip(imgs,masks):
Expand All @@ -181,25 +184,26 @@ def common_processing(caselist):
pool.close()
pool.join()


f = open(caselist + '.modified', 'w')
for i in range(len(imgs)):
imgs[i] = attributes[i][0]
masks[i] = attributes[i][1]
f.write(f'{imgs[i]},{masks[i]}\n')
f.close()


# the following imgs, masks is for diagnosing MemoryError i.e. computing rish w/o preprocessing
# to diagnose, comment all the above and uncomment the following
# imgs, masks = read_caselist(caselist+'.modified')

# experimentally found ncpu=4 to be memory optimal
pool = multiprocessing.Pool(4)


# compute dti_harm of preprocessed data
pool = multiprocessing.Pool(N_proc)
for imgPath,maskPath in zip(imgs,masks):
pool.apply_async(func= dti_harm, args= (imgPath,maskPath))
pool.close()
pool.join()


if debug:
#TODO compute dti_harm for all intermediate data _denoised, _denoised_bmapped, _bmapped
pass


return (imgs, masks)
Expand Down
43 changes: 40 additions & 3 deletions lib/reconstSignal.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,45 @@ def findLargestConnectMask(img, mask):
return mask


def approx(imgPath, maskPath):
    '''Fit spherical harmonics to a DWI volume and write out the
    SH-approximated signal.

    Side effects (all in the directory of imgPath):
      * rish() outputs under the harm/ subdirectory
      * {prefix}_mapped_cs.nii.gz  — normalized, clipped SH reconstruction
      * reconstructed_{prefix}.nii.gz (+ copied .bvec/.bval) — un-normalized
        signal with b0 volumes stacked back in, written only when `force`
        is set or the file does not already exist.
    '''

    print(f'Fitting spherical harmonics on {imgPath} ...')

    imgDir = dirname(imgPath)
    basePrefix = imgPath.split('.nii')[0]
    caseId = psplit(basePrefix)[-1]
    harmPrefix = pjoin(imgDir, 'harm', caseId)

    # fit the SH model; gives the b0 volume, SH coefficients, and the model
    b0Vol, coeffs, model = rish(imgPath, maskPath, basePrefix, harmPrefix, N_shm)

    srcImg = load(imgPath)
    srcHdr = srcImg.header
    srcAffine = srcImg.affine

    # project coefficients back through the SH basis
    approxSig = np.dot(coeffs, model.B.T)
    # keep only upper half of the reconstructed signal
    halfLen = int(approxSig.shape[3] / 2)
    approxSig = approxSig[..., :halfLen]
    np.nan_to_num(approxSig).clip(min=0., max=1., out=approxSig)

    # affine= templateAffine for all Scale_L{i}
    mappedFile = pjoin(imgDir, f'{caseId}_mapped_cs.nii.gz')
    save_nifti(mappedFile, approxSig, srcAffine, srcHdr)

    # un-normalize approximated data
    # (applymask is deliberately used with the nonbinary mask b0)
    dwiSig = applymask(approxSig, b0Vol)

    # place b0s in proper indices
    fullSig = stack_b0(model.gtab.b0s_mask, dwiSig, b0Vol)

    # save approximated data
    harmImg = pjoin(imgDir, f'reconstructed_{caseId}.nii.gz')
    if force or not isfile(harmImg):
        save_nifti(harmImg, fullSig, srcAffine, srcHdr)
        copyfile(basePrefix + '.bvec', harmImg.split('.nii')[0] + '.bvec')
        copyfile(basePrefix + '.bval', harmImg.split('.nii')[0] + '.bval')


def ring_masking(directory, prefix, maskPath, shm_coeff, b0, qb_model, hdr):

B = qb_model.B
Expand Down Expand Up @@ -164,8 +203,6 @@ def ring_masking(directory, prefix, maskPath, shm_coeff, b0, qb_model, hdr):

def reconst(imgPath, maskPath, moving, templatePath):

imgPath, maskPath = preprocessing(imgPath, maskPath)

img = load(imgPath)

directory = dirname(imgPath)
Expand All @@ -187,7 +224,7 @@ def reconst(imgPath, maskPath, moving, templatePath):
copyfile(inPrefix + '.bvec', harmImg.split('.nii')[0] + '.bvec')
copyfile(inPrefix + '.bval', harmImg.split('.nii')[0] + '.bval')

return (imgPath, maskPath, harmImg, harmMask)
return (harmImg, harmMask)


def stack_b0(b0s_mask, dwi, b0):
Expand Down
37 changes: 37 additions & 0 deletions lib/tests/pipeline_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,55 @@ export TEMPLATE_CONSTRUCT_CORES=6
--nproc -1 \
--create --debug --force || EXIT 'harmonization.py with --create --debug --force failed'

../../harmonization.py \
--bvalMap 1000 \
--resample 1.5x1.5x1.5 \
--template ./template/ \
--ref_list connectom.txt \
--tar_list prisma.txt \
--ref_name CONNECTOM \
--tar_name PRISMA \
--travelHeads \
--nproc -1 \
--create --debug || EXIT 'harmonization.py with --create --debug failed'

# --process and --debug block
../../harmonization.py \
--bvalMap 1000 \
--resample 1.5x1.5x1.5 \
--template ./template/ \
--tar_list prisma.txt \
--tar_name PRISMA \
--ref_list connectom.txt \
--nproc -1 \
--process --debug --force || EXIT 'harmonization.py with --process --debug --force failed'

../../harmonization.py \
--bvalMap 1000 \
--resample 1.5x1.5x1.5 \
--template ./template/ \
--tar_list prisma.txt \
--tar_name PRISMA \
--ref_list connectom.txt \
--nproc -1 \
--process --debug || EXIT 'harmonization.py with --process --debug failed'

# ===============================================================================================================

# same bvalue, resolution block
cp connectom.txt.modified connectom_same.txt
cp prisma.txt.modified prisma_same.txt
../../harmonization.py \
--template ./template/ \
--ref_list connectom_same.txt \
--tar_list prisma_same.txt \
--ref_name CONNECTOM \
--tar_name PRISMA \
--nproc -1 \
--create --process || EXIT 'harmonization.py for same bvalue, resolution with --create --process failed'


# ===============================================================================================================
# compute statistics
../fa_skeleton_test.py -i connectom.txt.modified \
-s CONNECTOM -t template/ || EXIT 'fa_skeleton_test.py failed for modified reference'
Expand Down

0 comments on commit 8dfb005

Please sign in to comment.