-
Notifications
You must be signed in to change notification settings - Fork 38
/
dataset.py
34 lines (26 loc) · 1.17 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.utils.data
import numpy as np, h5py
import random
def CreateDatasetSynthesis(phase, input_path, contrast1 = 'T1', contrast2 = 'T2'):
target_file = input_path + "/data_{}_{}.mat".format(phase, contrast1)
data_fs_s1=LoadDataSet(target_file)
target_file = input_path + "/data_{}_{}.mat".format(phase, contrast2)
data_fs_s2=LoadDataSet(target_file)
dataset=torch.utils.data.TensorDataset(torch.from_numpy(data_fs_s1),torch.from_numpy(data_fs_s2))
return dataset
#Dataset loading from load_dir and converintg to 256x256
def LoadDataSet(load_dir, variable = 'data_fs', padding = True, Norm = True):
f = h5py.File(load_dir,'r')
if np.array(f[variable]).ndim==3:
data=np.expand_dims(np.transpose(np.array(f[variable]),(0,2,1)),axis=1)
else:
data=np.transpose(np.array(f[variable]),(1,0,3,2))
data=data.astype(np.float32)
if padding:
pad_x=int((256-data.shape[2])/2)
pad_y=int((256-data.shape[3])/2)
print('padding in x-y with:'+str(pad_x)+'-'+str(pad_y))
data=np.pad(data,((0,0),(0,0),(pad_x,pad_x),(pad_y,pad_y)))
if Norm:
data=(data-0.5)/0.5
return data