Skip to content

Commit

Permalink
Merge pull request #39 from pnlbwh/avert-mem-err
Browse files Browse the repository at this point in the history
Avert mem err
  • Loading branch information
suheyla2 authored Dec 9, 2019
2 parents 606f0df + 38b4915 commit 51a0760
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 59 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,12 @@ is advisable to leave out at least two cores for other processes to run smoothly

**NOTE** See [Caveats/Issues](#caveatsissues) that may occur while using many processors in parallel.

Furthermore, you can define the environment variable `TEMPLATE_CONSTRUCT_CORES` to use a different number of processors
for `antsMultivariateTemplateConstruction2.sh` independent of `--nproc` used for rest of the processes in *dMRIharmonization*:

export TEMPLATE_CONSTRUCT_CORES=32


# Order of spherical harmonics

RISH features are derived from spherical harmonic coefficients. The order of spherical harmonic coefficients you can use
Expand Down
13 changes: 7 additions & 6 deletions lib/buildTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
eps= 2.2204e-16
SCRIPTDIR= os.path.dirname(__file__)
config = configparser.ConfigParser()
# config.read(os.path.join(SCRIPTDIR,'config.ini'))
config.read(f'/tmp/harm_config_{os.getpid()}.ini')
N_shm = int(config['DEFAULT']['N_shm'])
N_proc = int(config['DEFAULT']['N_proc'])
Expand Down Expand Up @@ -83,15 +82,16 @@ def createAntsCaselist(imgs, file):


def antsMult(caselist, outPrefix):


N_core=os.getenv('TEMPLATE_CONSTRUCT_CORES')
check_call((' ').join([os.path.join(SCRIPTDIR, 'antsMultivariateTemplateConstruction2_fixed_random_seed.sh'),
'-d', '3',
'-g', '0.2',
'-k', '2',
'-t', "BSplineSyN[0.1,26,0]",
'-r', '1',
'-c', '2',
'-j', str(N_proc),
'-j', str(N_core) if N_core else str(N_proc),
'-f', '8x4x2x1',
'-o', outPrefix,
caselist]), shell= True)
Expand Down Expand Up @@ -183,7 +183,6 @@ def stat_calc(ref, target, mask):
np.nan_to_num(per_diff).clip(max=100., min=-100., out= per_diff)
per_diff_smooth= smooth(per_diff)
scale= ref/(target+eps)
scale.clip(max=10., min= 0., out= scale)

return (delta, per_diff, per_diff_smooth, scale)

Expand Down Expand Up @@ -211,6 +210,7 @@ def difference_calc(refSite, targetSite, refImgs, targetImgs,
per_diff_smooth= []
scale= []
if travelHeads:
print('Using travelHeads for computing templates of',dm)
for refImg, targetImg in zip(refImgs, targetImgs):
prefix = os.path.basename(refImg).split('.')[0]
ref= load_nifti(os.path.join(templatePath, f'{prefix}_Warped{dm}.nii.gz'))[0]
Expand Down Expand Up @@ -246,8 +246,9 @@ def difference_calc(refSite, targetSite, refImgs, targetImgs,
save_nifti(os.path.join(templatePath, f'PercentageDiff_{dm}smooth.nii.gz'),
np.mean(per_diff_smooth, axis= 0), templateAffine, templateHdr)

save_nifti(os.path.join(templatePath, f'Scale_{dm}.nii.gz'),
np.sqrt(np.mean(scale, axis= 0)), templateAffine, templateHdr)
if 'L' in dm:
save_nifti(os.path.join(templatePath, f'Scale_{dm}.nii.gz'),
np.sqrt(np.mean(scale, axis= 0)), templateAffine, templateHdr)



Expand Down
2 changes: 1 addition & 1 deletion lib/dti.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def dti(imgPath, maskPath, inPrefix, outPrefix, tool='FSL'):
] & FG


gfa_vol = gfa(masked_vol)
gfa_vol = np.nan_to_num(gfa(masked_vol))
save_nifti(outPrefix + '_GFA.nii.gz', gfa_vol, vol.affine, vol.header)


Expand Down
36 changes: 25 additions & 11 deletions lib/harmonization.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def createTemplate(self):
# ATTN: antsMultivariateTemplateConstruction2.sh requires absolute path for caselist
antsMult(os.path.abspath(antsMultCaselist), self.templatePath)

# # load templateHdr
# load templateHdr
templateHdr= load(os.path.join(self.templatePath, 'template0.nii.gz')).header


Expand All @@ -244,26 +244,26 @@ def createTemplate(self):
pool.close()
pool.join()

print('dti statistics: mean, std(FA, MD) calculation of reference site')
print('calculating dti statistics i.e. mean, std for reference site')
refMaskPath= dti_stat(self.reference, refImgs, refMasks, self.templatePath, templateHdr)
print('dti statistics: mean, std(FA, MD) calculation of target site')
print('calculating dti statistics i.e. mean, std for target site')
targetMaskPath= dti_stat(self.target, targetImgs, targetMasks, self.templatePath, templateHdr)

print('masking dti statistics of reference site')
_= template_masking(refMaskPath, targetMaskPath, self.templatePath, self.reference)
print('masking dti statistics of target site')
templateMask= template_masking(refMaskPath, targetMaskPath, self.templatePath, self.target)

print('rish_statistics mean, std(L{i}) calculation of reference site')
print('calculating rish_statistics i.e. mean, std calculation for reference site')
rish_stat(self.reference, refImgs, self.templatePath, templateHdr)
print('rish_statistics mean, std(L{i}) calculation of target site')
print('calculating rish_statistics i.e. mean, std calculation for target site')
rish_stat(self.target, targetImgs, self.templatePath, templateHdr)

print('calculating scale map for diffusionMeasures')
print('calculating templates for diffusionMeasures')
difference_calc(self.reference, self.target, refImgs, targetImgs, self.templatePath, templateHdr,
templateMask, self.diffusionMeasures)

print('calculating scale map for rishFeatures')
print('calculating templates for rishFeatures')
difference_calc(self.reference, self.target, refImgs, targetImgs, self.templatePath, templateHdr,
templateMask, [f'L{i}' for i in range(0, self.N_shm+1, 2)])

Expand Down Expand Up @@ -342,9 +342,6 @@ def post_debug(self):

from debug_fa import sub2tmp2mni

refImgs, _ = read_imgs_masks(self.ref_csv)
targetImgs, _= read_imgs_masks(self.target_csv)

print('\n\n Reference site')
sub2tmp2mni(self.templatePath, self.reference, self.ref_csv, ref= True)

Expand All @@ -361,9 +358,17 @@ def post_debug(self):
def showStat(self):

from debug_fa import analyzeStat
from datetime import datetime

print('\n\nPrinting statistics :\n\n')


# save statistics for future
statFile= os.path.join(self.templatePath, 'meanFAstat.txt')
f= open(statFile,'a')
stdout= sys.stdout
sys.stdout= f

print(datetime.now().strftime('%c'),'\n')
print(f'{self.reference} site: ')
ref_mean = analyzeStat(self.ref_csv, self.templatePath)
printStat(ref_mean, self.ref_csv)
Expand All @@ -375,6 +380,15 @@ def showStat(self):
print(f'{self.target} site after harmonization: ')
target_mean_after = analyzeStat(self.harm_csv, self.templatePath)
printStat(target_mean_after, self.harm_csv)

f.close()
sys.stdout= stdout

# print statistics on console
with open(statFile) as f:
print(f.read())

print('\nThe statistics are also saved in ', statFile)


def sanityCheck(self):
Expand Down
53 changes: 30 additions & 23 deletions lib/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

SCRIPTDIR= os.path.dirname(__file__)
config = configparser.ConfigParser()
# config.read(os.path.join(SCRIPTDIR,'config.ini'))
config.read(f'/tmp/harm_config_{os.getpid()}.ini')

N_shm = int(config['DEFAULT']['N_shm'])
Expand Down Expand Up @@ -62,22 +61,11 @@ def dti_harm(imgPath, maskPath):
prefix = os.path.split(inPrefix)[-1]

outPrefix = os.path.join(directory, 'dti', prefix)

# if the dti output exists with the same prefix, don't dtifit again
if not os.path.exists(outPrefix+'_FA.nii.gz'):
dti(imgPath, maskPath, inPrefix, outPrefix)
dti(imgPath, maskPath, inPrefix, outPrefix)

outPrefix = os.path.join(directory, 'harm', prefix)
b0, shm_coeff, qb_model= rish(imgPath, maskPath, inPrefix, outPrefix, N_shm)

return (b0, shm_coeff, qb_model)


# def pre_dti_harm(imgPath, maskPath):
def pre_dti_harm(itr):
imgPath, maskPath = preprocessing(itr[0], itr[1])
dti_harm(imgPath, maskPath)
return (imgPath, maskPath)
rish(imgPath, maskPath, inPrefix, outPrefix, N_shm)


# convert NRRD to NIFTI on the fly
def nrrd2nifti(imgPath):
Expand Down Expand Up @@ -163,21 +151,40 @@ def preprocessing(imgPath, maskPath):


def common_processing(caselist):
imgs, masks = read_caselist(caselist)
f = open(caselist + '.modified', 'w')

pool = multiprocessing.Pool(N_proc) # Use all available cores, otherwise specify the number you want as an argument
imgs, masks = read_caselist(caselist)

# to avoid MemoryError, decouple preprocessing (spm_bspline) and dti_harm (rish)
res=[]
pool = multiprocessing.Pool(N_proc)
for imgPath,maskPath in zip(imgs,masks):
res.append(pool.apply_async(func= preprocessing, args= (imgPath,maskPath)))

attributes= [r.get() for r in res]

pool.close()
pool.join()

res = pool.map_async(pre_dti_harm, np.hstack((np.reshape(imgs, (len(imgs), 1)), np.reshape(masks, (len(masks), 1)))))
attributes = res.get()

f = open(caselist + '.modified', 'w')
for i in range(len(imgs)):
imgs[i] = attributes[i][0]
masks[i] = attributes[i][1]
f.write(f'{imgs[i]},{masks[i]}\n')

f.close()


# the following imgs, masks is for diagnosing MemoryError i.e. computing rish w/o preprocessing
# to diagnose, comment all the above and uncomment the following
# imgs, masks = read_caselist(caselist+'.modified')

# experimentally found ncpu=4 to be memroy optimal
pool = multiprocessing.Pool(4)
for imgPath,maskPath in zip(imgs,masks):
pool.apply_async(func= dti_harm, args= (imgPath,maskPath))
pool.close()
pool.join()


f.close()
return (imgs, masks)

return (imgs, masks)
16 changes: 9 additions & 7 deletions lib/reconstSignal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from buildTemplate import applyXform
from local_med_filter import local_med_filter
from preprocess import dti_harm, preprocessing
from rish import rish

eps= 2.2204e-16
SCRIPTDIR= os.path.dirname(__file__)
config = configparser.ConfigParser()
# config.read(os.path.join(SCRIPTDIR,'config.ini'))
config.read(f'/tmp/harm_config_{os.getpid()}.ini')
N_shm = int(config['DEFAULT']['N_shm'])
N_proc = int(config['DEFAULT']['N_proc'])
Expand All @@ -32,20 +32,22 @@
def antsReg(img, mask, mov, outPrefix):

if mask:
check_call((' ').join(['antsRegistrationSyNQuick.sh',
p= Popen((' ').join(['antsRegistrationSyNQuick.sh',
'-d', '3',
'-f', img,
'-x', mask,
'-m', mov,
'-o', outPrefix,
'-e', '123456']), shell= True)
p.wait()
else:
check_call((' ').join(['antsRegistrationSyNQuick.sh',
p= Popen((' ').join(['antsRegistrationSyNQuick.sh',
'-d', '3',
'-f', img,
'-m', mov,
'-o', outPrefix,
'-e', '123456']), shell= True)
p.wait()

def antsApply(templatePath, directory, prefix):

Expand Down Expand Up @@ -163,14 +165,14 @@ def reconst(imgPath, maskPath, moving, templatePath, preFlag):
if preFlag:
imgPath, maskPath = preprocessing(imgPath, maskPath)

# provide full sampled shm_coeff, qb_model.B
# provide imgPath header
img = load(imgPath)
b0, shm_coeff, qb_model = dti_harm(imgPath, maskPath)

directory = os.path.dirname(imgPath)
inPrefix = imgPath.split('.')[0]
prefix = os.path.split(inPrefix)[-1]
prefix = os.path.split(inPrefix)[-1]
outPrefix = os.path.join(directory, 'harm', prefix)
b0, shm_coeff, qb_model = rish(imgPath, maskPath, inPrefix, outPrefix, N_shm)


print(f'Registering template FA to {imgPath} space ...')
outPrefix = os.path.join(directory, 'harm', 'ToSubjectSpace_' + prefix)
Expand Down
7 changes: 5 additions & 2 deletions lib/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def resize_spm(lowResImg, inPrefix):
savemat(dataFile, {'lowResImg': lowResImg})

# call MATLAB_Runtime based spm bspline interpolation
check_call([os.path.join(FILEDIR,'spm_bspline_exec', 'bspline')+' '+inPrefix], shell= True)
p= Popen((' ').join([os.path.join(FILEDIR,'spm_bspline_exec', 'bspline'), inPrefix]), shell=True)
p.wait()

highResImg= np.nan_to_num(loadmat(inPrefix+'_resampled.mat')['highResImg'])

Expand Down Expand Up @@ -133,7 +134,9 @@ def resampling(lowResImgPath, lowResMaskPath, lowResImg, lowResImgHdr, lowResMas

# unring the b0
highResB0Path = lowResImgPath.split('.')[0] + '_resampled_bse.nii.gz'
check_call(['unring.a64', highResB0PathTmp, highResB0Path])
p= Popen((' ').join(['unring.a64', highResB0PathTmp, highResB0Path]), shell= True)
p.wait()

check_call(['rm', highResB0PathTmp])
b0_gibs = load(highResB0Path).get_data()
np.nan_to_num(b0_gibs).clip(min= 0., out= b0_gibs) # using min= 1. is unnecessary
Expand Down
Loading

0 comments on commit 51a0760

Please sign in to comment.