diff --git a/function.py b/function.py index 0f7f4c3..10ebed0 100755 --- a/function.py +++ b/function.py @@ -3,6 +3,8 @@ import numpy as np import scipy.ndimage as snd from torch.autograd import Variable +from torchvision.transforms import ToPILImage, ToTensor +import torchvision.transforms.functional as PIL from dataset import VolumeDataset, BlockDataset from torch.utils.data import DataLoader from model import UNet2d @@ -22,7 +24,27 @@ def write_nifti(data, aff, shape, out_path): img=nib.Nifti1Image(data, aff) img.to_filename(out_path) -def rotate_volume(vol): pass +def rotate_volume(vol): + tp_trans=ToPILImage() + tt_trans=ToTensor() + + angle=np.array([1, 1, 1]) + for i in range(3): + if i==0: + old_vol=vol + dim=old_vol.shape[i] + for j in range(dim): + if i==0: + one_slice=old_vol[j, :, :] + elif i==1: + one_slice=old_vol[:, j, :] + else: # i==2 + one_slice=old_vol[:, :, j] + one_slice_pil=tp_trans(one_slice) + one_slice_pil=PIL.rotate(one_slice_pil, angle[i], + resample=PIL.Image.BILINEAR, expand=True) + one_slice=tt_trans(one_slice_pil) + if j==0: pass # Create New Vol def estimate_dice(gt_msk, prt_msk): intersection=gt_msk*prt_msk