diff --git a/CCNet/README.md b/CCNet/README.md new file mode 100644 index 00000000..0bbca494 --- /dev/null +++ b/CCNet/README.md @@ -0,0 +1,291 @@ +# Overview + +This directory contains all information needed to run inference with the readily trained FastSurferVINN or train it from scratch. CCNet is capable of whole brain segmentation into 95 classes in under 1 minute, mimicking FreeSurfer's anatomical segmentation and cortical parcellation (DKTatlas). The network architecture incorporates local and global competition via competitive dense blocks and competitive skip pathways, as well as multi-slice information aggregation that specifically tailor network performance towards accurate segmentation of both cortical and sub-cortical structures. +![](/images/detailed_network.png) + +The network was trained with conformed images (UCHAR, 1-0.7 mm voxels and standard slice orientation). These specifications are checked in the run_prediction.py script and the image is automatically conformed if it does not comply. + +# 1. Inference + +The *CCNet* directory contains all the source code and modules needed to run the scripts. A list of python libraries used within the code can be found in __requirements.txt__. The main script is called __run_prediction.py__ within which certain options can be selected and set via the command line: + +#### General +* `--in_dir`: Path to the input volume directory (e.g /your/path/to/ADNI/fs60) or +* `--csv_file`: Path to csv-file listing input volume directories +* `--t1`: name of the T1-weighted MRI_volume (like mri_volume.mgz, __default: orig.mgz__) +* `--conformed_name`: name of the conformed MRI_volume (the input volume is always first conformed, if not already, and the result is saved under the given name, __default: orig.mgz__) +* `--t`: search tag limits processing to subjects matching the pattern (e.g. sub-* or 1030*...) +* `--sd`: Path to output directory (where should predictions be saved). Will be created if it does not already exist. +* `--seg_log`: name of log-file (information about processing is stored here; If not set, logs will not be saved). Saved in the same directory as the predictions. +* `--strip`: strip suffix from path definition of input file to yield correct subject name. (Optional, if full path is defined for `--t1`) +* `--lut`: FreeSurfer-style Color Lookup Table with labels to use in final prediction. Default: ./config/FastSurfer_ColorLUT.tsv +* `--seg`: Name of intermediate DL-based segmentation file (similar to aparc+aseg). +* `--cfg_cor`: Path to the coronal config file +* `--cfg_sag`: Path to the axial config file +* `--cfg_ax`: Path to the sagittal config file + +#### Checkpoints +* `--ckpt_sag`: path to sagittal network checkpoint +* `--ckpt_cor`: path to coronal network checkpoint +* `--ckpt_ax`: path to axial network checkpoint + +#### Optional commands +* `--clean`: clean up segmentation after running it (optional) +* `--device `:Device for processing (_auto_, _cpu_, _cuda_, _cuda:_), where cuda means Nvidia GPU; you can select which one e.g. "cuda:1". Default: "auto", check GPU and then CPU +* `--viewagg_device `: Define where the view aggregation should be run on. + Can be _auto_ or a device (see --device). + By default (_auto_), the program checks if you have enough memory to run the view aggregation on the gpu. + The total memory is considered for this decision. + If this fails, or you actively overwrote the check with setting `--viewagg_device cpu`, view agg is run on the cpu. + Equivalently, if you define `--viewagg_device gpu`, view agg will be run on the gpu (no memory check will be done). +* `--batch_size`: Batch size for inference. Default=1 + + +### Example Command Evaluation Single Subject +To run the network on MRI-volumes of subjectX in ./data (specified by `--t1` flag; e.g. ./data/subjectX/t1-weighted.nii.gz), change into the *CCNet* directory and run the following commands: + +``` +python3 run_prediction.py --t1 ../data/subjectX/t1-weighted.nii.gz \ +--sd ../output \ +--t subjectX \ +--seg_log ../output/temp_Competitive.log \ +``` + +The output will be stored in: + +- ../output/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz (large segmentation) +- ../output/subjectX/mri/mask.mgz (brain mask) +- ../output/subjectX/mri/aseg_noCC.mgz (reduced segmentation) + +Here the logfile "temp_Competitive.log" will include the logfiles of all subjects. If left out, the logs will be written to stdout + + +### Example Command Evaluation whole directory +To run the network on all subjects MRI-volumes in ./data, change into the *CCNet* directory and run the following command: + +``` +python3 run_prediction.py --in_dir ../data \ +--sd ../output \ +--seg_log ../output/temp_Competitive.log \ +``` + +The output will be stored in: + +- ../output/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz (large segmentation) +- ../output/subjectX/mri/mask.mgz (brain mask) +- ../output/subjectX/mri/aseg_noCC.mgz (reduced segmentation) +- and the log in ../output/temp_Competitive.log + + +# 2. Hdf5-Trainingset Generation + +The *CCNet* directory contains all the source code and modules needed to create a hdf5-file from given MRI volumes. Here, we use the orig.mgz output from freesurfer as the input image and the aparc.DKTatlas+aseg.mgz as the ground truth. The mapping functions are set-up accordingly as well and need to be changed if you use a different segmentation as ground truth. +A list of python libraries used within the code can be found in __requirements.txt__. The main script is called __generate_hdf5.py__ within which certain options can be selected and set via the command line: + +#### General +* `--hdf5_name`: Path and name of the to-be-created hdf5-file. Default: ../data/hdf5_set/Multires_coronal.hdf5 +* `--data_dir`: Directory with images to load. Default: /data +* `--pattern`: Pattern to match only certain files in the directory +* `--csv_file`: Csv-file listing subjects to load (can be used instead of data_dir; one complete path per line (up to the subject directory)) + Example: You have a directory called **dataset** with three different datasets (**D1**, **D2** and **D3**). You want to include subject1, subject10 and subject20 from D1 and D2. Your csv-file would then look like this: + + /dataset/D1/subject1 + /dataset/D1/subject10 + /dataset/D1/subject20 + /dataset/D2/subject1 + /dataset/D2/subject10 + /dataset/D2/subject20 +* --lut: FreeSurfer-style Color Lookup Table with labels to use in final prediction. Default: ./config/FastSurfer_ColorLUT.tsv + +The actual filename and segmentation ground truth name is specified via `--image_name` and `--gt_name` (e.g. the actual file could be sth. like /dataset/D1/subject1/mri_volume.mgz and /dataset/D1/subject1/segmentation.mgz) + +#### Image Names +* `--image_name`: Default name of original images. FreeSurfer orig.mgz is default (mri/orig.mgz) +* `--gt_name`: Default name for ground truth segmentations. Default: mri/aparc.DKTatlas+aseg.mgz. +* `--gt_nocc`: Segmentation without corpus callosum (used to mask this segmentation in ground truth). For a normal FreeSurfer input, use mri/aseg.auto_noCCseg.mgz. + +#### Image specific options +* `--plane`: Which anatomical plane to use for slicing (axial, coronal or sagittal) +* `--thickness`: Number of pre- and succeeding slices (we use 3 --> total of 7 slices is fed to the network; default: 3) +* `--combi`: Suffixes of labels names to combine. Default: Left- and Right- +* `--sag_mask`: Suffixes of labels names to mask for final sagittal labels. Default: Left- and ctx-rh +* `--max_w`: Overall max weight for any voxel in weight mask. Default: 5 +* `--hires_w`: Weight for high resolution elements (sulci, WM strands, cortex border) in weight mask. Default: None +* `--no_grad`: Turn on to only use median weight frequency (no gradient). Default: False +* `--gm`: Turn on to add cortex mask for hires-processing. Default: False +* `--processing`: Use aseg, aparc or no specific mapping processing. Default: aparc +* `--sizes`: Resolutions of images in the dataset. Default: 256 +* `--edge_w`: Weight for edges in weight mask. Default=5 + +#### Example Command Axial (Single Resolution) +``` +python3 generate_hdf5.py \ +--hdf5_name ../data/training_set_axial.hdf5 \ +--csv_file ../training_set_subjects_dirs.csv \ +--thickness 3 \ +--plane axial \ +--image_name mri/orig.mgz \ +--gt_name mri/aparc.DKTatlas+aseg.mgz \ +--gt_nocc mri/aseg.auto_noCCseg.mgz +--max_w 5 \ +--edge_w 4 \ +--hires_w 4 \ +--sizes 256 + +``` + +#### Example Command Coronal (Single Resolution) +``` +python3 generate_hdf5.py \ +--hdf5_name ../data/training_set_coronal.hdf5 \ +--csv_file ../training_set_subjects_dirs.csv \ +--plane coronal \ +--image_name mri/orig.mgz \ +--gt_name mri/aparc.DKTatlas+aseg.mgz \ +--gt_nocc mri/aseg.auto_noCCseg.mgz +--max_w 5 \ +--edge_w 4 \ +--hires_w 4 \ +--sizes 256 + +``` + +#### Example Command Sagittal (Multiple Resolutions) +``` +python3 generate_hdf5.py \ +--hdf5_name ../data/training_set_sagittal.hdf5 \ +--csv_file ../training_set_subjects_dirs.csv \ +--plane sagittal \ +--image_name mri/orig.mgz \ +--gt_name mri/aparc.DKTatlas+aseg.mgz \ +--gt_nocc mri/aseg.auto_noCCseg.mgz +--max_w 5 \ +--edge_w 4 \ +--hires_w 4 \ +--sizes 256 311 320 + +``` + +#### Example Command Sagittal using --data_dir instead of --csv_file +`--data_dir` specifies the path in which the data is located, with `--pattern` we can select subjects from the specified path. By default the pattern is "*" meaning all subjects will be selected. +As an example, imagine you have 19 FreeSurfer processed subjects labeled subject1 to subject19 in the ../data directory: + +``` +/home/user/FastSurfer/data +├── subject1 +├── subject2 +├── subject3 +… +│ +├── subject19 +    ├── mri +    │   ├── aparc.DKTatlas+aseg.mgz +    │   ├── aseg.auto_noCCseg.mgz +    │   ├── orig.mgz +    │   ├── … +    │   … +    ├── scripts +    ├── stats +    ├── surf +    ├── tmp +    ├── touch +    └── trash +``` + +Setting `--pattern` "*" will select all 19 subjects (subject1, ..., subject19). +Now, if only a subset should be used for the hdf5-file (e.g. subject 10 till subject19), this can be done by changing the `--pattern` flag to "subject1[0-9]": + +``` +python3 generate_hdf5.py \ +--hdf5_name ../data/training_set_axial.hdf5 \ +--data_dir ../data \ +--pattern "subject1[0-9]" \ +--plane sagittal \ +--image_name mri/orig.mgz \ +--gt_name mri/aparc.DKTatlas+aseg.mgz \ +--gt_nocc mri/aseg.auto_noCCseg.mgz + +``` + +# 3. Training + +The *CCNet* directory contains all the source code and modules needed to run the scripts. A list of python libraries used within the code can be found in __requirements.txt__. The main training script is called __run_model.py__ whose options can be set through a configuration file and command line arguments: +* `--cfg`: Path to the configuration file. Default: config/FastSurferVINN.yaml +* `--aug`: List of augmentations to use. Default: None. +* `--opt`: List of class options to use. + +The `--cfg` file configures the model to be trained. See config/FastSurferVINN.yaml for an example and config/defaults.py for all options and default values. + +The configuration options include: + +#### Model options +* MODEL_NAME: Name of model [CCNet, FastSurferVINN]. Default: FastSurferVINN +* NUM_CLASSES: Number of classes to predict including background. Axial and coronal: 79 (default), Sagittal: 51. +* NUM_FILTERS: Filter dimensions for Networks (all layers same). Default: 71 +* NUM_CHANNELS: Number of input channels (slice thickness). Default: 7 +* KERNEL_H: Height of Kernel. Default: 3 +* KERNEL_W: Width of Kernel. Default: 3 +* STRIDE_CONV: Stride during convolution. Default: 1 +* STRIDE_POOL: Stride during pooling. Default: 2 +* POOL: Size of pooling filter. Default: 2 +* BASE_RES: Base resolution of the segmentation model (after interpolation layer). Default: 1 + +#### Optimizer options + +* BASE_LR: Base learning rate. Default: 0.01 +* OPTIMIZING_METHOD: Optimization method [sgd, adam, adamW]. Default: adamW +* MOMENTUM: Momentum for optimizer. Default: 0.9 +* NESTEROV: Enables Nesterov for optimizer. Default: True +* LR_SCHEDULER: Learning rate scheduler [step_lr, cosineWarmRestarts, reduceLROnPlateau]. Default: cosineWarmRestarts + + +#### Data options + +* PATH_HDF5_TRAIN: Path to training hdf5-dataset +* PATH_HDF5_VAL: Path to validation hdf5-dataset +* PLANE: Plane to load [axial, coronal, sagittal]. Default: coronal + +#### Training options + +* BATCH_SIZE: Input batch size for training. Default: 16 +* NUM_EPOCHS: Number of epochs to train. Default: 30 +* SIZES: Available image sizes for the multi-scale dataloader. Default: [256, 311 and 320] +* AUG: Augmentations. Default: ["Scaling", "Translation"] + +#### Misc. Options + +* LOG_DIR: Log directory for run +* NUM_GPUS: Number of GPUs to use. Default: 1 +* RNG_SEED: Select random seed. Default: 1 + + +Any option can alternatively be set through the command-line by specifying the option name (as defined in config/defaults.py) followed by a value, such as: `MODEL.NUM_CLASSES 51`. + +To train the network on a given hdf5-set, change into the *CCNet* directory and run +`run_model.py` as in the following examples: + +### Example Command: Training Default FastSurferVINN +Trains FastSurferVINN on multi-resolution images in the coronal plane: +``` +python3 run_model.py \ +--cfg ./config/FastSurferVINN.yaml +``` + +### Example Command: Training FastSurferVINN (Single Resolution) +Trains FastSurferVINN on single-resolution images in the sagittal plane by overriding the NUM_CLASSES, SIZES, PATH_HDF5_TRAIN, and PATH_HDF5_VAL options: +``` +python3 run_model.py \ +--cfg ./config/FastSurferVINN.yaml \ +MODEL.NUM_CLASSES 51 \ +DATA.SIZES 256 \ +DATA.PATH_HDF5_TRAIN ./hdf5_sets/training_sagittal_single_resolution.hdf5 \ +DATA.PATH_HDF5_VAL ./hdf5_sets/validation_sagittal_single_resolution.hdf5 \ +``` + +### Example Command: Training CCNet +Trains CCNet using a provided configuration file and specifying no augmentations: +``` +python3 run_model.py \ +--cfg custom_configs/CCNet.yaml \ +--aug None +``` diff --git a/CCNet/__init__.py b/CCNet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CCNet/config/CCNet_axial.yaml b/CCNet/config/CCNet_axial.yaml new file mode 100644 index 00000000..401ef5dc --- /dev/null +++ b/CCNet/config/CCNet_axial.yaml @@ -0,0 +1,70 @@ +MODEL: + MODEL_NAME: "FastSurferVINN" + LOSS_FUNC: "combined" + NUM_FILTERS: 71 + NUM_FILTERS_INTERPOL: 32 + NUM_CHANNELS: 7 + NUM_CLASSES: 3 + KERNEL_H: 3 + KERNEL_W: 3 + KERNEL_C: 1 + STRIDE_CONV: 1 + STRIDE_POOL: 2 + POOL: 2 + HEIGHT: 256 + WIDTH: 256 + BASE_RES: 1.0 + INTERPOLATION_MODE: "bilinear" + CROP_POSITION: "top_left" + OUT_TENSOR_WIDTH: 429 + OUT_TENSOR_HEIGHT: 429 + +TRAIN: + LOG_INTERVAL: 50 + RESUME: false + RESUME_EXPR_NUM: "Default" + NUM_STEPS: 10 + FINE_TUNE: false + CHECKPOINT_PERIOD: 4 + NUM_EPOCHS: 2000 + BATCH_SIZE: 32 + DEBUG: false + +TEST: + BATCH_SIZE: 16 + +DATA: + SIZES: None + PADDED_SIZE: [429, 9, 7] # [axis1, axis2, thickness] + PLANE: "axial" + PATH_HDF5_TRAIN: "/data/cropped-nocom-loc/axial_train.hdf5" + PATH_HDF5_VAL: "/data/cropped-nocom-loc/axial_val.hdf5" + AUG: [] # "Hemisphere" "Flip" + LUT: "/home/CCNet/config/CC_ColorLUT_nocom.tsv" + #AUG_LIKELYHOOD: [0.8, 0.4] # Gaussian does not support probability + +DATA_LOADER: + NUM_WORKERS: 0 + PIN_MEMORY: true + #PREFETCH_FACTOR: 3 + +OPTIMIZER: + BASE_LR: 0.01 + LR_SCHEDULER: "cosineWarmRestarts" + GAMMA: 0.3 + STEP_SIZE: 5 + ETA_MIN: 0.0001 + T_ZERO: 10 + T_MULT: 2 + MOMENTUM: 0.9 + DAMPENING: 0.0 + NESTEROV: true + WEIGHT_DECAY: 0.0001 + OPTIMIZING_METHOD: "adamW" + +NUM_GPUS: 4 +#INPAINT_LOSS_WEIGHT: 0.5 # Weight of intensity image in loss function +RNG_SEED: 1 +LOG_DIR: "/home/experiments/cc_axial_pad/" +#LOG_LEVEL: 10 +EXPR_NUM: "cc_axial_pad_02" diff --git a/CCNet/config/CCNet_coronal.yaml b/CCNet/config/CCNet_coronal.yaml new file mode 100644 index 00000000..760b2a12 --- /dev/null +++ b/CCNet/config/CCNet_coronal.yaml @@ -0,0 +1,70 @@ +MODEL: + MODEL_NAME: "FastSurferVINN" + LOSS_FUNC: "combined" + NUM_FILTERS: 71 + NUM_FILTERS_INTERPOL: 32 + NUM_CHANNELS: 7 + NUM_CLASSES: 3 + KERNEL_H: 3 + KERNEL_W: 3 + KERNEL_C: 1 + STRIDE_CONV: 1 + STRIDE_POOL: 2 + POOL: 2 + HEIGHT: 256 + WIDTH: 256 + BASE_RES: 1.0 + INTERPOLATION_MODE: "bilinear" + CROP_POSITION: "top_left" + OUT_TENSOR_WIDTH: 429 + OUT_TENSOR_HEIGHT: 429 + +TRAIN: + LOG_INTERVAL: 50 + RESUME: false + RESUME_EXPR_NUM: "Default" + NUM_STEPS: 10 + FINE_TUNE: false + CHECKPOINT_PERIOD: 4 + NUM_EPOCHS: 2000 + BATCH_SIZE: 32 + DEBUG: false + +TEST: + BATCH_SIZE: 16 + +DATA: + SIZES: None + PADDED_SIZE: [9, 429, 7] # [axis1, axis2, thickness] + PLANE: "coronal" + PATH_HDF5_TRAIN: "/data/cropped-nocom-loc/coronal_train.hdf5" + PATH_HDF5_VAL: "/data/cropped-nocom-loc/coronal_val.hdf5" + AUG: [] # "Hemisphere" "Flip" + LUT: "/home/CCNet/config/CC_ColorLUT_nocom.tsv" + #AUG_LIKELYHOOD: [0.8, 0.4] # Gaussian does not support probability + +DATA_LOADER: + NUM_WORKERS: 0 + PIN_MEMORY: true + #PREFETCH_FACTOR: 3 + +OPTIMIZER: + BASE_LR: 0.01 + LR_SCHEDULER: "cosineWarmRestarts" + GAMMA: 0.3 + STEP_SIZE: 5 + ETA_MIN: 0.0001 + T_ZERO: 10 + T_MULT: 2 + MOMENTUM: 0.9 + DAMPENING: 0.0 + NESTEROV: true + WEIGHT_DECAY: 0.0001 + OPTIMIZING_METHOD: "adamW" + +NUM_GPUS: 4 +#INPAINT_LOSS_WEIGHT: 0.5 # Weight of intensity image in loss function +RNG_SEED: 1 +LOG_DIR: "/home/experiments/cc_coronal_pad/" +#LOG_LEVEL: 10 +EXPR_NUM: "cc_coronal_pad_02" \ No newline at end of file diff --git a/CCNet/config/CCNet_sagittal.yaml b/CCNet/config/CCNet_sagittal.yaml new file mode 100644 index 00000000..f020ebc4 --- /dev/null +++ b/CCNet/config/CCNet_sagittal.yaml @@ -0,0 +1,74 @@ +MODEL: + MODEL_NAME: "FastSurferLocalisation" + LOSS_FUNC: "localisation" + NUM_FILTERS: 71 + NUM_FILTERS_INTERPOL: 32 + NUM_CHANNELS: 7 + NUM_CLASSES: 3 + KERNEL_H: 3 + KERNEL_W: 3 + KERNEL_C: 1 + STRIDE_CONV: 1 + STRIDE_POOL: 2 + POOL: 2 + HEIGHT: 256 + WIDTH: 256 + BASE_RES: 1.0 + INTERPOLATION_MODE: "bilinear" + CROP_POSITION: "top_left" + OUT_TENSOR_WIDTH: 429 + OUT_TENSOR_HEIGHT: 429 + WEIGHT_SEG: 1e-5 + WEIGHT_LOC: 1.0 + WEIGHT_DIST: 1e-3 + +TRAIN: + LOG_INTERVAL: 50 + RESUME: false + RESUME_EXPR_NUM: "Default" + NUM_STEPS: 10 + FINE_TUNE: false + CHECKPOINT_PERIOD: 4 + NUM_EPOCHS: 20 + BATCH_SIZE: 4 + DEBUG: True + +TEST: + BATCH_SIZE: 4 + +DATA: + SIZES: [256, 320, 311, 429] + PADDED_SIZE: [429, 429, 7] # [axis1, axis2, thickness] + PLANE: "sagittal" + PATH_HDF5_TRAIN: "../data/sagittal_train.hdf5" + PATH_HDF5_VAL: "../data/sagittal_val.hdf5" + AUG: [] # "Hemisphere" "Flip" + LUT: "../CCNet/config/CC_ColorLUT_nocom.tsv" + #AUG_LIKELYHOOD: [0.8, 0.4] # Gaussian does not support probability + +DATA_LOADER: + NUM_WORKERS: 0 + PIN_MEMORY: true + #PREFETCH_FACTOR: 3 + +OPTIMIZER: + BASE_LR: 0.01 + LR_SCHEDULER: "NoScheduler" + GAMMA: 0.3 + STEP_SIZE: 5 + ETA_MIN: 0.0001 + T_ZERO: 10 + T_MULT: 2 + MOMENTUM: 0.9 + DAMPENING: 0.0 + NESTEROV: true + WEIGHT_DECAY: 0.0001 + OPTIMIZING_METHOD: "adamW" + +NUM_GPUS: 4 +#INPAINT_LOSS_WEIGHT: 0.5 # Weight of intensity image in loss function +RNG_SEED: 1 +LOG_DIR: "/home/experiments/cc_com_sagittal_pad" +#LOG_LEVEL: 10 +EXPR_NUM: "cc_com_sagittal_pad_02" + diff --git a/CCNet/config/CC_ColorLUT.tsv b/CCNet/config/CC_ColorLUT.tsv new file mode 100644 index 00000000..b7965fcb --- /dev/null +++ b/CCNet/config/CC_ColorLUT.tsv @@ -0,0 +1,4 @@ +ID LabelName R G B A +0 Background 0 0 0 0 +192 Corpus_Callosum 0 0 160 0 +250 Fornix 255 0 0 0 \ No newline at end of file diff --git a/CCNet/config/checkpoint_paths.yaml b/CCNet/config/checkpoint_paths.yaml new file mode 100644 index 00000000..582d3d8d --- /dev/null +++ b/CCNet/config/checkpoint_paths.yaml @@ -0,0 +1,7 @@ +url: +- "" + +checkpoint: + axial: "checkpoints/CCNet_axial_v0.1.0.pkl" + coronal: "checkpoints/CCNet_coronal_v0.1.0.pkl" + sagittal: "checkpoints/CCNet_sagittal_v0.1.0.pkl" \ No newline at end of file diff --git a/CCNet/config/debug/CCNet_axial_debug.yaml b/CCNet/config/debug/CCNet_axial_debug.yaml new file mode 100644 index 00000000..250eef3c --- /dev/null +++ b/CCNet/config/debug/CCNet_axial_debug.yaml @@ -0,0 +1,70 @@ +MODEL: + MODEL_NAME: "FastSurferVINN" + LOSS_FUNC: "combined" + NUM_FILTERS: 71 + NUM_FILTERS_INTERPOL: 32 + NUM_CHANNELS: 7 + NUM_CLASSES: 3 + KERNEL_H: 3 + KERNEL_W: 3 + KERNEL_C: 1 + STRIDE_CONV: 1 + STRIDE_POOL: 2 + POOL: 2 + HEIGHT: 256 + WIDTH: 256 + BASE_RES: 1.0 + INTERPOLATION_MODE: "bilinear" + CROP_POSITION: "top_left" + OUT_TENSOR_WIDTH: 429 + OUT_TENSOR_HEIGHT: 429 + +TRAIN: + LOG_INTERVAL: 50 + RESUME: false + RESUME_EXPR_NUM: "Default" + NUM_STEPS: 10 + FINE_TUNE: false + CHECKPOINT_PERIOD: 1 + NUM_EPOCHS: 5 + BATCH_SIZE: 1 + DEBUG: true + +TEST: + BATCH_SIZE: 1 + +DATA: + SIZES: None + PADDED_SIZE: [429, 9, 7] # [axis1, axis2, thickness] + PLANE: "axial" + PATH_HDF5_TRAIN: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/data/axial_train.hdf5" + PATH_HDF5_VAL: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/data/axial_val.hdf5" + AUG: [] # "Hemisphere" "Flip" + LUT: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/CCNet/config/CC_ColorLUT.tsv" + #AUG_LIKELYHOOD: [0.8, 0.4] # Gaussian does not support probability + +DATA_LOADER: + NUM_WORKERS: 0 + PIN_MEMORY: true + #PREFETCH_FACTOR: 3 + +OPTIMIZER: + BASE_LR: 0.01 + LR_SCHEDULER: "cosineWarmRestarts" + GAMMA: 0.3 + STEP_SIZE: 5 + ETA_MIN: 0.0001 + T_ZERO: 10 + T_MULT: 2 + MOMENTUM: 0.9 + DAMPENING: 0.0 + NESTEROV: true + WEIGHT_DECAY: 0.0001 + OPTIMIZING_METHOD: "adamW" + +NUM_GPUS: 1 +#INPAINT_LOSS_WEIGHT: 0.5 # Weight of intensity image in loss function +RNG_SEED: 1 +LOG_DIR: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/CCNet/experiments/debug/" +#LOG_LEVEL: 10 +EXPR_NUM: "cc_axial_debug" diff --git a/CCNet/config/debug/CCNet_coronal_debug.yaml b/CCNet/config/debug/CCNet_coronal_debug.yaml new file mode 100644 index 00000000..2fb1cb26 --- /dev/null +++ b/CCNet/config/debug/CCNet_coronal_debug.yaml @@ -0,0 +1,70 @@ +MODEL: + MODEL_NAME: "FastSurferVINN" + LOSS_FUNC: "combined" + NUM_FILTERS: 71 + NUM_FILTERS_INTERPOL: 32 + NUM_CHANNELS: 7 + NUM_CLASSES: 3 + KERNEL_H: 3 + KERNEL_W: 3 + KERNEL_C: 1 + STRIDE_CONV: 1 + STRIDE_POOL: 2 + POOL: 2 + HEIGHT: 256 + WIDTH: 256 + BASE_RES: 1.0 + INTERPOLATION_MODE: "bilinear" + CROP_POSITION: "top_left" + OUT_TENSOR_WIDTH: 429 + OUT_TENSOR_HEIGHT: 429 + +TRAIN: + LOG_INTERVAL: 50 + RESUME: false + RESUME_EXPR_NUM: "Default" + NUM_STEPS: 10 + FINE_TUNE: false + CHECKPOINT_PERIOD: 1 + NUM_EPOCHS: 5 + BATCH_SIZE: 1 + DEBUG: true + +TEST: + BATCH_SIZE: 1 + +DATA: + SIZES: None + PADDED_SIZE: [9, 429, 7] # [axis1, axis2, thickness] + PLANE: "coronal" + PATH_HDF5_TRAIN: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/data/coronal_train.hdf5" + PATH_HDF5_VAL: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/data/coronal_val.hdf5" + AUG: [] # "Hemisphere" "Flip" + LUT: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/CCNet/config/CC_ColorLUT.tsv" + #AUG_LIKELYHOOD: [0.8, 0.4] # Gaussian does not support probability + +DATA_LOADER: + NUM_WORKERS: 0 + PIN_MEMORY: true + #PREFETCH_FACTOR: 3 + +OPTIMIZER: + BASE_LR: 0.01 + LR_SCHEDULER: "cosineWarmRestarts" + GAMMA: 0.3 + STEP_SIZE: 5 + ETA_MIN: 0.0001 + T_ZERO: 10 + T_MULT: 2 + MOMENTUM: 0.9 + DAMPENING: 0.0 + NESTEROV: true + WEIGHT_DECAY: 0.0001 + OPTIMIZING_METHOD: "adamW" + +NUM_GPUS: 1 +#INPAINT_LOSS_WEIGHT: 0.5 # Weight of intensity image in loss function +RNG_SEED: 1 +LOG_DIR: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/CCNet/experiments/debug/" +#LOG_LEVEL: 10 +EXPR_NUM: "cc_coronal_debug" \ No newline at end of file diff --git a/CCNet/config/debug/CCNet_sagittal_debug.yaml b/CCNet/config/debug/CCNet_sagittal_debug.yaml new file mode 100644 index 00000000..ec0099d1 --- /dev/null +++ b/CCNet/config/debug/CCNet_sagittal_debug.yaml @@ -0,0 +1,74 @@ +MODEL: + MODEL_NAME: "FastSurferLocalisation" + LOSS_FUNC: "localisation" + NUM_FILTERS: 71 + NUM_FILTERS_INTERPOL: 32 + NUM_CHANNELS: 7 + NUM_CLASSES: 3 + KERNEL_H: 3 + KERNEL_W: 3 + KERNEL_C: 1 + STRIDE_CONV: 1 + STRIDE_POOL: 2 + POOL: 2 + HEIGHT: 256 + WIDTH: 256 + BASE_RES: 1.0 + INTERPOLATION_MODE: "bilinear" + CROP_POSITION: "top_left" + OUT_TENSOR_WIDTH: 429 + OUT_TENSOR_HEIGHT: 429 + WEIGHT_SEG: 1e-5 + WEIGHT_LOC: 1.0 + WEIGHT_DIST: 1e-3 + +TRAIN: + LOG_INTERVAL: 50 + RESUME: false + RESUME_EXPR_NUM: "Default" + NUM_STEPS: 10 + FINE_TUNE: false + CHECKPOINT_PERIOD: 1 + NUM_EPOCHS: 5 + BATCH_SIZE: 1 + DEBUG: true + +TEST: + BATCH_SIZE: 1 + +DATA: + SIZES: [256, 320, 311, 429] + PADDED_SIZE: [429, 429, 7] # [axis1, axis2, thickness] + PLANE: "sagittal" + PATH_HDF5_TRAIN: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/data/sagittal_train.hdf5" + PATH_HDF5_VAL: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/data/sagittal_val.hdf5" + AUG: [] # "Hemisphere" "Flip" + LUT: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/CCNet/config/CC_ColorLUT.tsv" + #AUG_LIKELYHOOD: [0.8, 0.4] # Gaussian does not support probability + +DATA_LOADER: + NUM_WORKERS: 0 + PIN_MEMORY: true + #PREFETCH_FACTOR: 3 + +OPTIMIZER: + BASE_LR: 0.01 + LR_SCHEDULER: "NoScheduler" + GAMMA: 0.3 + STEP_SIZE: 5 + ETA_MIN: 0.0001 + T_ZERO: 10 + T_MULT: 2 + MOMENTUM: 0.9 + DAMPENING: 0.0 + NESTEROV: true + WEIGHT_DECAY: 0.0001 + OPTIMIZING_METHOD: "adamW" + +NUM_GPUS: 1 +#INPAINT_LOSS_WEIGHT: 0.5 # Weight of intensity image in loss function +RNG_SEED: 1 +LOG_DIR: "/groups/ag-reuter/projects/corpus_callosum_fornix/FastSurfer/CCNet/experiments/debug" +#LOG_LEVEL: 10 +EXPR_NUM: "cc_sagittal_debug" + diff --git a/CCNet/config/defaults.py b/CCNet/config/defaults.py new file mode 100644 index 00000000..d49d6f7d --- /dev/null +++ b/CCNet/config/defaults.py @@ -0,0 +1,257 @@ +from yacs.config import CfgNode as CN + +_C = CN() + +# ---------------------------------------------------------------------------- # +# Model options +# ---------------------------------------------------------------------------- # +_C.MODEL = CN() + +# Name of model +_C.MODEL.MODEL_NAME = "FastSurferVINN" + +# Loss function, combined = dice loss + cross entropy, combined2 = dice loss + boundary loss +_C.MODEL.LOSS_FUNC = "combined" + +_C.MODEL.WEIGHT_SEG = 1e-5 + +_C.MODEL.WEIGHT_LOC = 1.0 + +_C.MODEL.WEIGHT_DIST = 0.0 + +# Filter dimensions for DenseNet (all layers same) +_C.MODEL.NUM_FILTERS = 71 + +# Filter dimensions for Input Interpolation block (currently all the same) +_C.MODEL.NUM_FILTERS_INTERPOL = 32 + +# Number of UNet layers in Basenetwork (including bottleneck layer!) +_C.MODEL.NUM_BLOCKS = 5 + +# Number of input channels (slice thickness) +_C.MODEL.NUM_CHANNELS = 7 + +# Height of convolution kernels +_C.MODEL.KERNEL_H = 3 + +# Width of convolution kernels +_C.MODEL.KERNEL_W = 3 + +# size of Classifier kernel +_C.MODEL.KERNEL_C = 1 + +# Stride during convolution +_C.MODEL.STRIDE_CONV = 1 + +# Stride during pooling +_C.MODEL.STRIDE_POOL = 2 + +# Size of pooling filter +_C.MODEL.POOL = 2 + +# The height of segmentation model (after interpolation layer) +_C.MODEL.HEIGHT = 256 + +# The width of segmentation model +_C.MODEL.WIDTH = 256 + +# The base resolution of the segmentation model (after interpolation layer) +_C.MODEL.BASE_RES = 1.0 + +# Interpolation mode for up/downsampling in Flex networks +_C.MODEL.INTERPOLATION_MODE = "bilinear" + +# Crop positions for up/downsampling in Flex networks +_C.MODEL.CROP_POSITION = "top_left" + +# Out Tensor dimensions for interpolation layer +_C.MODEL.OUT_TENSOR_WIDTH = 320 +_C.MODEL.OUT_TENSOR_HEIGHT = 320 + +# ---------------------------------------------------------------------------- # +# Training options +# ---------------------------------------------------------------------------- # +_C.TRAIN = CN() + +# input batch size for training +_C.TRAIN.BATCH_SIZE = 16 + +# how many batches to wait before logging training status +_C.TRAIN.LOG_INTERVAL = 50 + +# Resume training from the latest checkpoint in the output directory. +_C.TRAIN.RESUME = False + +# The experiment number to resume from +_C.TRAIN.RESUME_EXPR_NUM = "Default" + +# number of epochs to train +_C.TRAIN.NUM_EPOCHS = 30 + +# number of steps (iteration) which depends on dataset +_C.TRAIN.NUM_STEPS = 10 + +# To fine tune model or not +_C.TRAIN.FINE_TUNE = False + +# checkpoint period +_C.TRAIN.CHECKPOINT_PERIOD = 2 + +# Flag to disable or enable Early Stopping +_C.TRAIN.EARLY_STOPPING = True + +# Mode for early stopping (min = stop when metric is no longer decreasing, max = stop when mwtric is no longer increasing) +_C.TRAIN.EARLY_STOPPING_MODE = "min" + +# Patience = Number of epochs to wait before stopping +_C.TRAIN.EARLY_STOPPING_PATIENCE = 10 + +# Wait = NUmber of epochs before starting early stopping check +_C.TRAIN.EARLY_STOPPING_WAIT = 10 + +# Delta = change below which early stopping starts (previous - current < delta = stop) +_C.TRAIN.EARLY_STOPPING_DELTA = 0.00001 + +# Flag to enable debugging run (smaller dataset, less epochs, etc.) +_C.TRAIN.DEBUG = False + +# ---------------------------------------------------------------------------- # +# Testing options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() + +# input batch size for testing +_C.TEST.BATCH_SIZE = 16 + +# ---------------------------------------------------------------------------- # +# Data options +# ---------------------------------------------------------------------------- # + +_C.DATA = CN() + +# path to training hdf5-dataset +_C.DATA.PATH_HDF5_TRAIN = "" + +# path to validation hdf5-dataset +_C.DATA.PATH_HDF5_VAL = "" +_C.DATA.PATH_HDF5_VAL2 = "" + +# The plane to load ['axial', 'coronal', 'sagittal'] +_C.DATA.PLANE = "coronal" + +# Which classes to use +_C.DATA.CLASS_OPTIONS = ["aseg", "aparc"] + +# Number of classes to predict, including background +_C.MODEL.NUM_CLASSES = 0 + +# Available size for dataloader +# This for the multi-scale dataloader +_C.DATA.SIZES = None + +# the size that all inputs are padded to +_C.DATA.PADDED_SIZE = [320, 320, 320] + +# Augmentations +_C.DATA.AUG = ["Scaling", "Translation"] + +# Augmentation probability +_C.DATA.AUG_LIKELYHOOD = [0.3, 0.3] + +_C.DATA.LUT = "" + +# ---------------------------------------------------------------------------- # +# DataLoader options (common for test and train) +# ---------------------------------------------------------------------------- # +_C.DATA_LOADER = CN() + +# Number of data loader workers +_C.DATA_LOADER.NUM_WORKERS = 8 + +# Load data to pinned host memory. +_C.DATA_LOADER.PIN_MEMORY = True + +# How many batches to prefetch (maximum) +_C.DATA_LOADER.PREFETCH_FACTOR = 2 + +# ---------------------------------------------------------------------------- # +# Optimizer options +# ---------------------------------------------------------------------------- # +_C.OPTIMIZER = CN() + +# Base learning rate. +_C.OPTIMIZER.BASE_LR = 0.01 + +# Learning rate scheduler, step_lr, cosineWarmRestarts, reduceLROnPlateau +_C.OPTIMIZER.LR_SCHEDULER = "cosineWarmRestarts" + +# Multiplicative factor of learning rate decay in step_lr +_C.OPTIMIZER.GAMMA = 0.3 + +# Period of learning rate decay in step_lr +_C.OPTIMIZER.STEP_SIZE = 5 + +# minimum learning in cosine lr policy and reduceLROnPlateau +_C.OPTIMIZER.ETA_MIN = 0.0001 + +# number of iterations for the first restart in cosineWarmRestarts +_C.OPTIMIZER.T_ZERO = 10 + +# A factor increases T_i after a restart in cosineWarmRestarts +_C.OPTIMIZER.T_MULT = 2 + +# factor by which learning rate will be reduce (new_lr = lr*factor, default=0.1) +_C.OPTIMIZER.FACTOR = 0.1 + +# number of epochs to wait before lowering lr (default=5) +_C.OPTIMIZER.PATIENCE = 5 + +# Threshold for measuring new optimum (default=1e-4) +_C.OPTIMIZER.THRESH = 0.0001 + +# Number of epochs to wait before resuming normal operation (default=0) +_C.OPTIMIZER.COOLDOWN=0 + +# Momentum +_C.OPTIMIZER.MOMENTUM = 0.9 + +# Momentum dampening +_C.OPTIMIZER.DAMPENING = 0.0 + +# Nesterov momentum +_C.OPTIMIZER.NESTEROV = True + +# L2 regularization +_C.OPTIMIZER.WEIGHT_DECAY = 1e-4 + +# Optimization method [sgd, adam, adamW] +_C.OPTIMIZER.OPTIMIZING_METHOD = "adamW" + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # + +# Number of GPUs to use +_C.NUM_GPUS = 1 + +# log directory for run +_C.LOG_DIR = "./experiments" + +_C.LOG_LEVEL = 20 + +# experiment number +_C.EXPR_NUM = "Default" + +# Note that non-determinism may still be present due to non-deterministic +# operator implementations in GPU operator libraries. +_C.RNG_SEED = 1 + +# Predict healthy tissue for inpainting / anomaly detection +_C.INPAINT_LOSS_WEIGHT = 0.5 # Weight of intensity image in loss function + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() \ No newline at end of file diff --git a/CCNet/data_loader/__init__.py b/CCNet/data_loader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CCNet/data_loader/augmentation.py b/CCNet/data_loader/augmentation.py new file mode 100644 index 00000000..6e9e7049 --- /dev/null +++ b/CCNet/data_loader/augmentation.py @@ -0,0 +1,1160 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS +from __future__ import annotations +import itertools +from numbers import Number +import random +#import monai + +import numpy as np +import torch +import nibabel as nib +import h5py + +from typing import Any, List +from typing import Sequence +from typing import Tuple +from typing import TypeVar + +import numpy as np +import torch +import torchvision + +import torchio as tio +from torchio.transforms.augmentation import RandomTransform +from torchio.transforms import IntensityTransform, RandomFlip +from torchio.data.subject import Subject +from torchio.typing import TypeTripletInt +from torchio.typing import TypeTuple +from torchio.utils import to_tuple + +from scipy.spatial import ConvexHull, Delaunay + +TypeLocations = Sequence[Tuple[TypeTripletInt, TypeTripletInt]] +TensorArray = TypeVar('TensorArray', np.ndarray, torch.Tensor) + +## +# Transformations for evaluation +## +class ToTensorTest(object): + """ + Convert np.ndarrays in sample to Tensors. #TODO: Thats not what is happening here + """ + + def __init__(self, include=['image', 'label', 'weight', 'cutout_mask']) -> None: + self.include = include + + def __call__(self, img): + + if isinstance(img, dict): + for key in self.include: + img[key] = torch.from_numpy(self._clip_and_transpose(img[key])) + elif isinstance(img, np.ndarray) or isinstance(img, torch.Tensor): + img = torch.from_numpy(self.x_clip_and_transpose(img)) + + return img + + @staticmethod + def _clip_and_transpose(img: np.ndarray) -> np.ndarray: + img = img.astype(np.float32) + + # Normalize and clamp between 0 and 1 + img = np.clip(img / 255.0, a_min=0.0, a_max=1.0) + + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = img.transpose((2, 0, 1)) + return img + + + +class ZeroPad2DTest(object): + def __init__(self, output_size, pos='top_left'): + """ + Pad the input with zeros to get output size + :param output_size: + :param pos: position to put the input + """ + if isinstance(output_size, Number): + output_size = (int(output_size), ) * 2 + self.output_size = output_size + self.pos = pos + + def _pad(self, image): + if len(image.shape) == 2: + h, w = image.shape + padded_img = np.zeros(self.output_size, dtype=image.dtype) + else: + h, w, c = image.shape + padded_img = np.zeros(self.output_size + (c,), dtype=image.dtype) + + if self.pos == 'top_left': + padded_img[0: h, 0: w] = image + + return padded_img + + def __call__(self, img): + + if isinstance(img, dict): + for key in ['image', 'label', 'weight']: + img[key] = self._pad(img[key]) + elif isinstance(img, np.ndarray) or isinstance(img, torch.Tensor): + img = self._pad(img) + + return img + + +## +# Transformations for training +## +class ToTensor(object): + """ + Convert ndarrays in sample to Tensors. + """ + + def __init__(self, keys=None): + self.keys = keys + + def __call__(self, sample): + return_dict = dict(**sample) + #img, label, weight = sample['image'], sample['label'], sample['weight'] + + if self.keys == None or 'image' in self.keys: + img = sample['image'] + + #img = img.astype(np.float32) + + # Normalize image and clamp between 0 and 1 + img = np.clip(img / 255.0, a_min=0.0, a_max=1.0) + + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + # if isinstance(img, np.ndarray): + # img = img.transpose((2, 0, 1)) + # elif isinstance(img, torch.Tensor): + # return_dict['image'] = img.permute((2, 0, 1)) + + if self.keys == None: + keys = list(sample.keys()) + else: + keys = self.keys + + for key in keys: + if isinstance(sample[key], np.ndarray): + return_dict[key] = torch.from_numpy(sample[key]) + + return return_dict + + +class ZeroPad2D(object): + def __init__(self, output_size, pos='top_left'): + """ + Pad the input with zeros to get output size + :param output_size: + :param pos: position to put the input + """ + if isinstance(output_size, Number): + output_size = (int(output_size), ) * 2 + self.output_size = output_size + self.pos = pos + + def _pad(self, image): + if len(image.shape) == 2: + h, w = image.shape + padded_img = np.zeros(self.output_size, dtype=image.dtype) + + if self.pos == 'top_left': + padded_img[0: h, 0: w] = image + else: + assert(len(image.shape) == 3) + d, h, w = image.shape + padded_img = np.zeros((d,) + self.output_size, dtype=image.dtype) + + if self.pos == 'top_left': + padded_img[:, 0: h, 0: w] = image + + return padded_img + + def __call__(self, sample): + return_dict = dict(**sample) + img, label, weight = sample['image'], sample['label'], sample['weight'] + + return_dict['image'] = self._pad(img) + return_dict['label'] = self._pad(label) + return_dict['weight'] = self._pad(weight) + + return return_dict #{'image': img, 'label': label, 'weight': weight} + + + + + +# TODO: maybe we should reduce weights for removed areas +class RandomAugmentation(object): + + def __init__(self, p=1.0): + self.probability = p + + def decision(self, prob=None): + if prob is None: + return random.random() < self.probability + else: + return random.random() < prob + + def downweight(self, weight, mask): + weight = weight.copy() + weight[mask] = weight[mask] * 0.5 + return weight + + +class RandomizeScaleFactor(RandomAugmentation): + def __init__(self, mean=0, std=0.1, p=1.0): + super().__init__(p) + self.std = std + self.mean = mean + + def __call__(self, sample): + + if self.decision(): + return_dict = dict(**sample) + #img, label, weight, sf = sample['image'], sample['label'], sample['weight'], sample['scale_factor'] + sf = sample['scale_factor'] + # change 1 to sf.size() for isotropic scale factors (now same noise change added to both dims) + sf = sf + torch.randn(1) * self.std + self.mean + + return_dict['scale_factor'] = sf + + # TODO: check that this really changes randomly + return return_dict + else: + return sample + + + + +class RandomCutout(RandomTransform, IntensityTransform): + r"""Randomly set patches to zero within an image. + + Args: + patch_size: Tuple of integers :math:`(w, h, d)` to swap patches + of size :math:`w \times h \times d`. + If a single number :math:`n` is provided, :math:`w = h = d = n`. + num_cuts: Number patches that will be cut + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + """ + def __init__( + self, + patch_size: TypeTuple = None, + num_cuts: int = 1, + downweighting_factor: float = 0.5, + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = np.array(to_tuple(patch_size)) + self.num_cuts = self._parse_num_iterations(num_cuts) + self.downweighting_factor = downweighting_factor + + @staticmethod + def _parse_num_iterations(num_iterations): + if not isinstance(num_iterations, int): + raise TypeError( + 'num_iterations must be an int,' + f'not {num_iterations}', + ) + if num_iterations < 0: + raise ValueError( + 'num_iterations must be positive,' + f'not {num_iterations}', + ) + return num_iterations + + def get_params(self, + tensor: torch.Tensor, + patch_size: np.ndarray, + num_cuts: int, + subject: Subject + ) -> List[Tuple[TypeTripletInt, TypeTripletInt]]: + si, sj, sk = tensor.shape[-3:] + spatial_shape = si, sj, sk # for mypy + if patch_size[0] is None: + patch_size = np.random.randint([0,0,0], spatial_shape, (3,)) + locations = [] + for _ in range(num_cuts): + first_ini, first_fin = self.get_random_indices_from_shape( + spatial_shape, + patch_size.tolist(), + subject + ) + + locations.append((tuple(first_ini), tuple(first_fin))) + return locations # type: ignore[return-value] + + def apply_transform(self, subject: Subject) -> Subject: + THICKSLICE_DIMENSION = 0 + orig_image = self.get_images_dict(subject)['image'] + if self.downweighting_factor > 0: + wgt = subject['weight'].data.clone() + + midslice_no = orig_image.data.shape[THICKSLICE_DIMENSION]//2 + orig_slice= orig_image.data[midslice_no] + + locations = self.get_params( + orig_image, + self.patch_size, + self.num_cuts, + subject + ) + for name, image in self.get_images_dict(subject).items(): + + if name != 'image': + assert(orig_slice.shape == image.shape), f'expected slice dimension, but got {image.shape}, for {name}' + + + img = image.data.clone() + for location in locations: + if name != 'image': # apply to slice if midslice is included in image (applies cutout to weights and segmentation) - experimental + if location[0][THICKSLICE_DIMENSION] <= midslice_no and location[1][THICKSLICE_DIMENSION] >= midslice_no: + continue + else: + location[0][THICKSLICE_DIMENSION] = 0 + location[1][THICKSLICE_DIMENSION] = 0 + + img = self._set_constant(img, location[0], location[1], constant=0) # type: ignore[arg-type] # noqa: E501 + if self.downweighting_factor > 0: + wgt = self._multiply(wgt, location[0], location[1], constant=self.downweighting_factor) + image.set_data(img) + if self.downweighting_factor > 0: + subject['weight'].set_data(wgt) + + if not 'cutout_mask' in subject.keys(): + subject.add_image(tio.LabelMap(tensor=torch.zeros(self.get_images_dict(subject)['label'].data.size(), dtype=bool)), 'cutout_mask') + subject['cutout_mask'] = self._set_constant(subject['cutout_mask'].data, location[0], location[1], constant=1) + + return subject + + @staticmethod + def _set_constant( + image: TensorArray, + index_ini: np.ndarray, + index_fin: np.ndarray, + constant: float = 0 + ) -> TensorArray: + i_ini, j_ini, k_ini = index_ini + i_fin, j_fin, k_fin = index_fin + image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = constant + return image + + @staticmethod + def _multiply( + image: TensorArray, + index_ini: np.ndarray, + index_fin: np.ndarray, + constant: float = 0 + ) -> TensorArray: + i_ini, j_ini, k_ini = index_ini + i_fin, j_fin, k_fin = index_fin + #image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] //= int(1/constant) # throws deprecated warning + image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin].div_(int(1/constant), rounding_mode='floor') # inplace integer division + return image + + def get_random_indices_from_shape(self, + spatial_shape: Sequence[int], + patch_size: Sequence[int], + subject: Subject=None) -> Tuple[np.ndarray, np.ndarray]: # subject used in child class + assert len(spatial_shape) == 3 + assert len(patch_size) in (1, 3) + shape_array = np.array(spatial_shape) + patch_size_array = np.array(patch_size) + max_index_ini_unchecked = shape_array - patch_size_array + if (max_index_ini_unchecked < 0).any(): + message = ( + f'Patch size {patch_size} cannot be' + f' larger than image spatial shape {spatial_shape}' + ) + raise ValueError(message) + max_index_ini = max_index_ini_unchecked.astype(np.uint16) + coordinates = [] + for max_coordinate in max_index_ini.tolist(): + if max_coordinate == 0: + coordinate = 0 + else: + coordinate = int(torch.randint(max_coordinate, size=(1,)).item()) + coordinates.append(coordinate) + index_ini = np.array(coordinates, np.uint16) + index_fin = index_ini + patch_size_array + return index_ini, index_fin + +class SmartRandomCutout(RandomCutout): + """ + Ensures that at least three of the four corners of the cutout block are within the brain + and adjusts probability for partial cutout in the slice thickness dimension + """ + + def __init__(self, brain_div_patch_size: float = 2, num_cuts: int = 1, full_random_probability=0.1, **kwargs): + super().__init__(None, num_cuts, **kwargs) + self.brain_div_patch_size = brain_div_patch_size + self.full_random_probability = full_random_probability + + def get_brain_mask(self, subject: Subject) -> np.ndarray: + if 'aux_label' in subject.keys(): + brain_mask = subject['aux_label'].data > 0 # brainmask from hdf5 + else: + print(f'WARNING: no brain mask found, using convex hull of aseg instead') + try: # this can fail if the input mask is of unexpected shape (e.g.: not enough points(1) to construct initial simplex (need 4)) + _, brain_mask = getConvexHull(subject) # TODO: make static in pre-processing + except: + return None #super().get_random_indices_from_shape(spatial_shape, patch_size) + return brain_mask + + @staticmethod + def _get_rectangle_around_mask(mask): + + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + #return np.round(np.mean(np.array(mask.nonzero())[[1,2]], axis=1)).astype(int) # get center of tumor mask in slice + + #print('mask shape:', mask.shape) + + # Find indices where we have mass + _, _, x, y = torch.where(mask) + # mass_x and mass_y are the list of x indices and y indices of mass pixels + + min_xy = torch.min(x), torch.min(y) + max_xy = torch.max(x), torch.max(y) + + return min_xy, max_xy + + def get_random_indices_from_shape(self, + spatial_shape: Sequence[int], + patch_size: Sequence[int], + subject: Subject) -> Tuple[np.ndarray, np.ndarray]: + # get slice thickness dimension + MIDSLICE = spatial_shape[0]//2 + SLICE_THICKNESS_DIM = np.argmin(spatial_shape) + SLICE_DIMENSIONS = list(range(3)) + SLICE_DIMENSIONS.remove(SLICE_THICKNESS_DIM) + + brain_mask = self.get_brain_mask(subject) + + if brain_mask is None: + return super().get_random_indices_from_shape(spatial_shape, patch_size) + + min_coord, max_coord = self._get_rectangle_around_mask(brain_mask) + + brain_spatial_shape = np.zeros(3) + brain_spatial_shape[SLICE_DIMENSIONS] = (max_coord[0]-min_coord[0], max_coord[1]-min_coord[1]) + brain_spatial_shape[SLICE_THICKNESS_DIM] = spatial_shape[SLICE_THICKNESS_DIM] + + max_coutout_size = brain_spatial_shape//self.brain_div_patch_size + + if (max_coutout_size[SLICE_THICKNESS_DIM] <= 0).any() or (max_coutout_size[SLICE_DIMENSIONS] <= 0).any(): # if the visible brain is too small, just do a random cutout + return super().get_random_indices_from_shape(spatial_shape, patch_size) + + index_ini, index_fin = super().get_random_indices_from_shape(brain_spatial_shape, np.random.randint(0, max_coutout_size)) + index_ini = index_ini.astype(int) + index_fin = index_fin.astype(int) + + index_ini[SLICE_DIMENSIONS] += np.array(min_coord) + index_fin[SLICE_DIMENSIONS] += np.array(min_coord) + + + # chance of completely random cutout, otherwise full cutout in slice thickness dimension and ensure that at least three of the four corners are within the brain + if np.random.rand() > self.full_random_probability: + index_ini[SLICE_THICKNESS_DIM] = 0 + index_fin[SLICE_THICKNESS_DIM] = spatial_shape[SLICE_THICKNESS_DIM]-1 #TODO: check for off by one errors + + for _ in range(3): + pts_in_brain = 0 + corner_points = np.array(list(itertools.product(*zip(index_ini[SLICE_DIMENSIONS],index_fin[SLICE_DIMENSIONS])))) + + + for point in corner_points: + if brain_mask[0,MIDSLICE,point[0], point[1]]: + pts_in_brain += 1 + + + + if pts_in_brain >= 3: + return index_ini, index_fin + else: # if the cutout is not within the brain, try again, but constrain the cutout to the brain + index_ini, index_fin = super().get_random_indices_from_shape(brain_spatial_shape, np.random.randint(0, max_coutout_size)) + index_ini[SLICE_THICKNESS_DIM] = 0 + index_fin[SLICE_THICKNESS_DIM] = spatial_shape[SLICE_THICKNESS_DIM]-1 #TODO: check for off by one errors + + index_ini = index_ini.astype(int) + index_fin = index_fin.astype(int) + index_ini[SLICE_DIMENSIONS] += min_coord + index_fin[SLICE_DIMENSIONS] += min_coord + + #print('WARNING: could not find a cutout that was within the brain') + return index_ini, index_fin + else: + return index_ini, index_fin + + + +# class RandGridDistortiond(monai.transforms.RandGridDistortiond): + +# ONE_D_KEYS = ['label', 'cutout_mask', 'weight', 'unmodified_center_slice'] + + +# @property +# def p(self): +# return self.prob + +# @p.setter +# def p(self, value): +# self.prob = value + + +# def __call__(self, data): +# """ +# Args: +# spatial_size: spatial size of the grid. +# """ +# monai_dict = {} +# for key in self.keys: +# if key in self.ONE_D_KEYS: +# #print(key, data[key].data.dtype) +# #print(data[key].data.shape) +# monai_dict[key] = data[key].data[:,3, None] + +# # check if data is int +# #if data[key].data.dtype == torch.uint8: +# # monai_dict[key] = data[key].data.to(torch.float16) +# # print('casted', key, 'to float32') +# #else: + +# else: +# monai_dict[key] = data[key].data + +# out = super().__call__(monai_dict) + +# for key in self.keys: +# if key in self.ONE_D_KEYS: +# data[key].set_data(torch.nn.functional.pad(out[key], pad=(0,0,0,0,3,3), value=0)) +# #print(data[key].data.shape) +# else: +# data[key].set_data(out[key]) + +# return data + + + + +class RandGridDistortiond(): + + def __init__(self, **kwargs): + pass + + + @property + def p(self): + return self.prob + + @p.setter + def p(self, value): + self.prob = value + + + def __call__(self, data): + raise NotImplementedError('use monai version') + + return data + + +class CutoutBRATSTumor(SmartRandomCutout): + + def __init__(self, tumor_mask_hdf5: str, num_cuts: int = 1, cutout_value: int=0, random=False, **kwargs): + super().__init__(None, num_cuts, **kwargs) + self.cutout_value = cutout_value + self.random = random + + # load tumor masks + self.tumor_masks = [] + + #print('loading tumor masks for data augmentation') + # start_t = time.time() + #with h5py.File(tumor_mask_hdf5, 'r') as f: + f = h5py.File(tumor_mask_hdf5, 'r') + #for size in f.keys(): + #self.tumor_masks.extend(list(f[f'{size}']['mask_dataset'])) + sizes = list(f.keys()) + assert(len(sizes) == 1), 'expected only one size in tumor mask hdf5' + size = sizes[0] + self.tumor_masks = f[f'{size}']['mask_dataset'] + + #print(f'loading tumor masks took {time.time()-start_t} seconds') + + def get_desired_tumor_center(self, subject: Subject, midslice=3) -> np.ndarray: + """ + Get a random point within the brain mask as the center of the tumor + + :param subject: subject to get tumor center from + :return: tumor center + """ + + brain_mask = self.get_brain_mask(subject) + + assert(brain_mask is not None), 'brain mask not found - this can happen if convex hull fails; use precomputed brain mask instead' + + # get random point in brain + brain_points = brain_mask.nonzero()[:,[2,3]] + tumor_center = brain_points[np.random.randint(brain_points.shape[0])] + + #assert(brain_mask[0, midslice, tumor_center[0], tumor_center[1]]), 'random point not in brain' + + return tumor_center + + @staticmethod + def _set_random( + image: TensorArray, + mask: np.ndarray, + #constant: float = 0 + ) -> TensorArray: + #i_ini, j_ini, k_ini = index_ini + #i_fin, j_fin, k_fin = index_fin + #image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = np.random.rand(image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin].shape) + + image[mask] = np.random.rand(image.shape[0])[mask] + return image + + @staticmethod + def _set_constant( + image: TensorArray, + mask: np.ndarray, + constant: float = 0 + ) -> TensorArray: + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + + image[mask] = constant + return image + + @staticmethod + def _set_random( + image: TensorArray, + mask: np.ndarray, + ) -> TensorArray: + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + + image[mask] = torch.rand(*image.shape)[mask] + return image + + @staticmethod + def _multiply( + image: TensorArray, + mask: np.ndarray, + constant: float = 0 + ) -> TensorArray: + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + + image[mask] *= constant + return image + + + def get_random_tumor_mask(self, subject: Subject=None): + # TODO: maybe reorder dimensions + t_mask = self.tumor_masks[np.random.randint(len(self.tumor_masks))] + while t_mask.sum() == 0: + print('WARNING: tumor mask is empty, trying again') + t_mask = self.tumor_masks[np.random.randint(len(self.tumor_masks))] + return t_mask + + @staticmethod + def _apply_random_affine(mask: np.ndarray, mask_center: Tuple) -> np.ndarray: + """ + Apply random affine transformation to tumor mask + :param mask: tumor mask + :param subject: subject to get affine from + :return: transformed tumor mask + """ + rotation = (-180, 180) + translation = (0,0) + scaling = (0.5, 2) + shearing = (-0.1, 0.1) + + # get random parameters + rotation = np.random.randint(rotation[0], rotation[1]) + scaling = np.random.rand() * (scaling[1] - scaling[0]) + scaling[0] + shearing = np.random.rand() * (shearing[1] - shearing[0]) + shearing[0] + mask_center = np.array(mask_center) + np.random.randint(-5, 5, size=2) # add random offset to tumor center + + # get random affine transformation + return torchvision.transforms.functional.affine(torch.from_numpy(mask), angle=rotation, translate=translation, scale=scaling, shear=shearing, fill=0, center=list(mask_center)) + + @staticmethod + def _get_mask_center(mask): + + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + #return np.round(np.mean(np.array(mask.nonzero())[[1,2]], axis=1)).astype(int) # get center of tumor mask in slice + + # Find indices where we have mass + _, x, y = torch.where(mask) + # mass_x and mass_y are the list of x indices and y indices of mass pixels + + cent_x = torch.mean(x, dtype=float).round().int() + cent_y = torch.mean(y, dtype=float).round().int() + + return cent_x, cent_y + + #return torch.median(mask.nonzero(), axis=0)[[0,1]].round().int() # get center of tumor mask in slice + + + def apply_transform(self, subject: Subject) -> Subject: + THICKSLICE_DIMENSION = 0 + + orig_image = self.get_images_dict(subject)['image'] + + MIDSLICE = orig_image.data.shape[THICKSLICE_DIMENSION]//2 + # if self.downweighting_factor > 0: + # wgt = subject['weight'].data.clone() + + midslice_no = orig_image.data.shape[THICKSLICE_DIMENSION]//2 + orig_slice= orig_image.data[midslice_no] + + tumor_mask = self.get_random_tumor_mask(subject=subject) + # pad tumor mask to match image size + if tumor_mask.shape[0] < orig_image.data.shape[2] or tumor_mask.shape[1] < orig_image.data.shape[3]: + tumor_mask = np.pad(tumor_mask, ( + (0, orig_image.data.shape[2] - tumor_mask.shape[0]), + (0, orig_image.data.shape[3] - tumor_mask.shape[1]), + (0,0)), #, (0,0), + 'constant', constant_values=0) + elif tumor_mask.shape[0] > orig_image.data.shape[2] or tumor_mask.shape[1] > orig_image.data.shape[3]: + tumor_mask = tumor_mask[:orig_image.data.shape[2], :orig_image.data.shape[3], :] + elif tumor_mask.shape[0] == orig_image.data.shape[2] and tumor_mask.shape[1] == orig_image.data.shape[3]: + pass + else: + raise ValueError('unexpected tumor mask shape') + + # bring thickness dimension to front + tumor_mask = np.moveaxis(tumor_mask, -1, 0) + + tumor_mask = tumor_mask.astype(bool) + + tumor_mask_center_pre_affine = self._get_mask_center(tumor_mask) + + + #plt.savefig('../../tmp/tumor_mask.png') + + # apply random affine transformation to tumor mask + tumor_mask = self._apply_random_affine(tumor_mask, tumor_mask_center_pre_affine) + + tumor_mask_center = self._get_mask_center(tumor_mask) + desired_center = self.get_desired_tumor_center(subject, midslice=MIDSLICE) + + + # # plot tumor mask with center + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(1,7, figsize=(20,5)) + # for i,ax in enumerate(axs): + # ax.imshow(tumor_mask[i].T) + # ax.scatter(tumor_mask_center[0], tumor_mask_center[1], c='r', label='tumor mask center') + # ax.scatter(desired_center[0], desired_center[1], c='b', label='desired center') + # ax.legend() + # # plt.imshow(tumor_mask[MIDSLICE]) + # # plt.scatter(tumor_mask_center[0], tumor_mask_center[1], c='r', label='tumor mask center') + # # plt.legend() + + + + to_translate = desired_center - torch.tensor(tumor_mask_center) + # translate tumor mask to desired center + tumor_mask = torch.roll(tumor_mask, (to_translate[0], to_translate[1]), dims=(1,2)) + + # # plot transformed tumor mask with desired center + # plt.figure() + # plt.imshow(tumor_mask[MIDSLICE].T) + + # plt.scatter(tumor_mask_center[0], tumor_mask_center[1], c='r', label='tumor mask center') + # plt.scatter(desired_center[0], desired_center[1], c='b', label='desired center') + # plt.legend() + # #plt.savefig('../../tmp/tumor_mask_transformed.png') + + # plt.figure() + # plt.imshow(self.get_brain_mask(subject)[0,MIDSLICE].T) + # plt.scatter(desired_center[0], desired_center[1], c='b', label='desired center') + # plt.legend() + + + tumor_mask = tumor_mask[np.newaxis, ...] # add batch dimension + + # print('tumor mask shape:', tumor_mask.shape) + # print('tumor mask center', tumor_mask_center) + # print('desired center', desired_center) + # print('to translate', to_translate) + # print('sum tumor mask', tumor_mask.sum()) + + + + #for name, image in self.get_images_dict(subject).items(): + + #print('applying cutout for', name) + + # if name != 'image': + # assert(orig_slice.shape == image.shape), f'expected slice dimension, but got {image.shape}, for {name}' + + #print(name) + + img = subject['image'].data.clone() + + # for image consider thickslices + #if name == 'image': # apply to slice if midslice is included in image (applies cutout to weights and segmentation) - experimental + if self.random: + img = self._set_random(img, mask=tumor_mask) + else: + img = self._set_constant(img, mask=tumor_mask, constant=0) # type: ignore[arg-type] # noqa: E501 + + # elif self.downweighting_factor > 0 and name == 'weight': + # print('reducing weight for cutout in image', name) + # img = self._multiply(img, mask=tumor_mask[:, THICKSLICE_DIMENSION], constant=self.downweighting_factor) + # else: + # raise ValueError(f'unknown image type {name}') + #img = self._set_constant(img, mask=tumor_mask[:, THICKSLICE_DIMENSION], constant=0) # type: ignore[arg-type] # noqa: E501 + + subject['image'].set_data(img) + + if self.downweighting_factor > 0: + #print('reducing weight for cutout in image', name) + wgt = self._multiply(subject['weight'].data.clone(), mask=tumor_mask, constant=self.downweighting_factor) + subject['weight'].set_data(wgt) + + if not 'cutout_mask' in subject.keys(): + subject.add_image(tio.LabelMap(tensor=torch.zeros(self.get_images_dict(subject)['label'].data.size(), dtype=bool)), 'cutout_mask') + subject['cutout_mask'] = self._set_constant(subject['cutout_mask'].data, mask=tumor_mask, constant=1) + + return subject + + + +class CutoutBRATSTumorDeterministic(CutoutBRATSTumor): + + def get_random_tumor_mask(self, subject: Subject): + idx = int.from_bytes(subject["subject_id"], byteorder="big") % len(self.tumor_masks) + t_mask = self.tumor_masks[idx] + return t_mask + + def get_tumor_mask(self, index: int): + return self.tumor_masks[index] + + def get_desired_tumor_center(self, subject: Subject, midslice=3) -> np.ndarray: + """ + Get a point within the brain mask as the center of the tumor + + :param subject: subject to get tumor center from + :return: tumor center + """ + + brain_mask = self.get_brain_mask(subject) + + assert(brain_mask is not None), 'brain mask not found - this can happen if convex hull fails; use precomputed brain mask instead' + + # get random point in brain + brain_points = brain_mask.nonzero()[:,[2,3]] + tumor_center = brain_points[int.from_bytes(subject["subject_id"], byteorder="big") % brain_points.shape[0]] + + #assert(brain_mask[0, midslice, tumor_center[0], tumor_center[1]]), 'random point not in brain' + + return tumor_center + +class CutoutTumorMask(IntensityTransform): + + def apply_transform(self, subject: Subject) -> Subject: + + if 'cutout_mask' in subject.keys() and subject['cutout_mask'].data.sum() > 0: + #subject['image'].data[subject['cutout_mask'].data.bool()] = torch.rand(subject['image'].data.shape)[subject['cutout_mask'].data.bool()] + subject['image'].data[subject['cutout_mask'].data.bool()] = torch.ones(subject['image'].data.shape)[subject['cutout_mask'].data.bool()] + + + return subject + +class CutoutTumorMaskInference(object): + + def __call__(self, subject: dict) -> dict: + if 'cutout_mask' in subject.keys(): + #subject['image'][subject['cutout_mask'].bool()] = torch.rand(subject['image'].shape)[subject['cutout_mask'].bool()] + subject['image'][subject['cutout_mask'].bool()] = torch.ones(subject['image'].shape)[subject['cutout_mask'].bool()] + + return subject + + + + +class CutoutHemisphere(IntensityTransform, RandomTransform): + + """ + Sets half of the image to zero - on the medial plane, with a small random offset + + TODO: determine medial plane - e.g. run make_upright on all training images, then and add to hdf5 + """ + + def __init__(self, orientation: str, left: bool, weight_multiplier: int=0.5, **kwargs): + super().__init__(**kwargs) + self.left = left + self.weight_reduction = weight_multiplier + + self.LR_axis, right_oriented = self._get_LR_axis(orientation) + if not right_oriented: + self.left = not self.left + + #if self.LR_axis == -1: + + #print(f'CutoutHemisphere: left = {self.left}, LR_axis = {self.LR_axis}, orientation = {orientation}') + if self.LR_axis == 0: # slice thickness dimension, skip + print('WARNING: disabling CutoutHemisphere transform, as it is not applicable to the slice thickness dimension') + self.enabled = False + else: + self.enabled = True + + self.LR_axis += 1 # add batch dimension + + @staticmethod + def _convert_orientation(orientation: str) -> str: + if orientation == 'sagittal': + orientation = 'RSA' + elif orientation == 'axial': + orientation = 'SAR' + elif orientation == 'coronal': + orientation = 'ARS' + return orientation + + @staticmethod + def _get_LR_axis(orientation: str) -> int: + right_oriented = True + orientation = CutoutHemisphere._convert_orientation(orientation) # handle FastSurfer slicing strings + + if orientation == 'RSA' or orientation == 'SRA' or orientation == 'ARS': + LR_axis = orientation.index('R') + else: + print(f'WARNING - orientation not recognized, standard FastSurfer orientations are RSA [saggital], SRA [axial], ARS [coronal] - instead got {orientation}') + if 'R' in orientation and 'L' not in orientation: + LR_axis = orientation.index('R') + elif 'L' in orientation and 'R' not in orientation: + right_oriented = False + LR_axis = orientation.index('L') + else: + raise ValueError('orientation {orientation} not recognized') + + return LR_axis, right_oriented + + def apply_transform(self, subject: Subject) -> Subject: + if self.enabled: + image = subject['image'] + weight = subject['weight'] + sagittal_size = image.shape[self.LR_axis] # LR axis size + + if self.LR_axis != 0: + brain_center = getBrainCenter(subject) + mid_slice = brain_center[self.LR_axis - 1] # not actually the mid slice, but a crude approximation (not accounting for rotation) + + slc = [slice(None)] * len(image.data.shape) + if self.left: + slc[self.LR_axis] = slice(mid_slice, sagittal_size) + else: + slc[self.LR_axis] = slice(0, mid_slice) + image.data[slc] = 0 + weight.data[slc] *= self.weight_reduction + + + + # slc = [slice(None)] * len(image.data.shape) + # slc[1] = slice(image.data.shape[1]//2, image.data.shape[1]) + # image.data[slc] = torch.max(image.data) + + # slc = [slice(None)] * len(image.data.shape) + # slc[2] = slice(image.data.shape[2]//2, image.data.shape[2]) + # image.data[slc] = torch.max(image.data//2) # right + + # slc = [slice(None)] * len(image.data.shape) + # slc[3] = slice(image.data.shape[3]//2, image.data.shape[3]) + # image.data[slc] = 0 # inferior + return subject + + +class FlipLeftRight(RandomFlip): + """ + Flip depending on orientation + """ + def __init__(self, orientation, **kwargs): + axis, _ = CutoutHemisphere._get_LR_axis(orientation) + + # NOTE: flip probability is per axis, since we only flip one axis this is the same as setting p (probability) to 1.0 + super().__init__(axis, flip_probability=1.0, **kwargs) + + + + + +class CutoutRandomHemisphere(CutoutHemisphere): + + """ + Sets half of the image to zero - on the medial plane, with a small random offset + + TODO: determine medial plane + """ + + def __init__(self, orientation, weight_multiplier=0.5, **kwargs): + super().__init__(left=True, orientation=orientation, weight_multiplier=weight_multiplier, **kwargs) + + def __call__(self, subject: Subject) -> Subject: + self.left = bool(torch.randint(2, size=(1,)).item()) + return super().__call__(subject) + + + + + +class CutoutRandomTumor(RandomAugmentation): + """ + Load tumor mask from brats dataset and generate cutout from tumor mask + """ + + def __init__(self, tumor_seg_paths, p=1.0, pre_load=True, img_size=(256,256,256)): + self.probability = p + self.img_size = img_size + + if pre_load: + self.tumor_masks = [] + + for path in tumor_seg_paths: + + tumor_mask = nib.load(path).get_fdata() + + #tumor_mask.data = tumor_mask.get_fdata()[tumor_mask.get_fdata() > 0] = 1 + #tumor_mask = conform(tumor_mask, order=0, resample_only=True).get_fdata() # NN interpolation & no rescaling keeps intesities + tumor_mask[tumor_mask > 0] = 1 + tumor_mask = tumor_mask.astype(bool) + tumor_mask = self._padding(tumor_mask, *img_size) + + self.tumor_masks.append(tumor_mask) + else: + raise NotImplementedError('only pre-loading is available for now') + + @staticmethod + def _padding(array, xx, yy, zz): + """ + :param array: numpy array + :param xx: desired height + :param yy: desirex width + :return: padded array + """ + h = array.shape[0] + w = array.shape[1] + d = array.shape[2] + + a = (xx - h) // 2 + aa = xx - a - h + + b = (yy - w) // 2 + bb = yy - b - w + + c = (zz - d) // 2 + cc = zz - c - d + + return np.pad(array, pad_width=((a, aa), (b, bb), (c, cc)), mode='constant') + + @staticmethod + def random_small_offset(x, max=None): + x = x + random.randint(0,4) + x = 0 if x<0 else x + if max is not None: + x = max if x>max else x + return x + + + def bbox_3D(self, img): + r = np.any(img, axis=(1, 2)) + c = np.any(img, axis=(0, 2)) + z = np.any(img, axis=(0, 1)) + + rmin, rmax = np.where(r)[0][[0, -1]] + cmin, cmax = np.where(c)[0][[0, -1]] + zmin, zmax = np.where(z)[0][[0, -1]] + + # dilate by 5 mm + rmin -= 5 + rmax += 5 + cmin -= 5 + cmax += 5 + zmin -= 5 + zmax += 5 + + # add variance to tumor masks and check for out of bounds + rmin = self.random_small_offset(rmin, max=self.img_size[0]) + rmax = self.random_small_offset(rmax, max=self.img_size[0]) + + cmin = self.random_small_offset(cmin, max=self.img_size[1]) + cmax = self.random_small_offset(cmax, max=self.img_size[1]) + + zmin = self.random_small_offset(zmin, max=self.img_size[2]) + zmax = self.random_small_offset(zmax, max=self.img_size[2]) + + + bounding_mask = np.zeros(self.img_size, dtype=bool) + bounding_mask[rmin:rmax, cmin:cmax, zmin:zmax] = 1 + + return bounding_mask + + + def __call__(self, sample) -> dict: + if self.decision(): + sample['image'][self.bbox_3D(random.choice(self.tumor_masks))] = 0 + else: + pass + + return {**sample} + + + +# --------------------- general helper functions + +def getConvexHull(subject: Subject): + """ + Get the convex hull of the labels - approximation for brain mask + """ + brain_mask = subject['label'].numpy()[:,3].squeeze() > 0 + return Delaunay(np.array(brain_mask.nonzero()).T), brain_mask + +def in_hull(p, hull): + """ + Test if points in `p` are in `hull` + + `p` should be a `NxK` coordinates of `N` points in `K` dimensions + `hull` is either a scipy.spatial.Delaunay object or the `MxK` array of the + coordinates of `M` points in `K`dimensions for which Delaunay triangulation + will be computed + """ + if not isinstance(hull,Delaunay): + hull = Delaunay(hull) + + return hull.find_simplex(p)>=0 + + +def getBrainCenter(subject: Subject): + """ + get the center of the labels + """ + return torch.round(torch.mean(subject['label'].data[:,3].nonzero(),axis=0, dtype=torch.float32)).int() + diff --git a/CCNet/data_loader/data_utils.py b/CCNet/data_loader/data_utils.py new file mode 100644 index 00000000..3f63d5d8 --- /dev/null +++ b/CCNet/data_loader/data_utils.py @@ -0,0 +1,792 @@ +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS +import sys +from typing import Union + +import numpy as np +import torch +from skimage.measure import label, regionprops +from scipy.ndimage import binary_erosion, binary_closing, filters, uniform_filter, generate_binary_structure +import scipy.ndimage.morphology as morphology +import nibabel as nib +import pandas as pd + +from FastSurferCNN.utils import logging +from FastSurferCNN.data_loader.conform import is_conform, conform, check_affine_in_nifti + + +## +# Global Vars +## +SUPPORTED_OUTPUT_FILE_FORMATS = ['mgz', 'nii', 'nii.gz'] +LOGGER = logging.getLogger(__name__) + +## +# Helper Functions +## + + +# Conform an MRI brain image to UCHAR, RAS orientation, and 1mm or minimal isotropic voxels +def load_and_conform_image(img_filename, interpol=1, logger=LOGGER, conform_min = False): + """ + Function to load MRI image and conform it to UCHAR, RAS orientation and 1mm or minimum isotropic voxels size + (if it does not already have this format) + :param str img_filename: path and name of volume to read + :param int interpol: interpolation order for image conformation (0=nearest,1=linear(default),2=quadratic,3=cubic) + :param logger logger: Logger to write output to (default = STDOUT) + :param: boolean: conform_min: conform image to minimal voxel size (for high-res) (Default = False) + :return: nibabel.MGHImage header_info: header information of the conformed image + :return: np.ndarray affine_info: affine information of the conformed image + :return: nibabel.MGHImage orig: conformed image + """ + orig = nib.load(img_filename) + # is_conform and conform accept numeric values and the string 'min' instead of the bool value + _conform_vox_size = 'min' if conform_min else 1. + if not is_conform(orig, conform_vox_size=_conform_vox_size): + + logger.info('Conforming image to UCHAR, RAS orientation, and minimum isotropic voxels') + + if len(orig.shape) > 3 and orig.shape[3] != 1: + sys.exit('ERROR: Multiple input frames (' + format(orig.shape[3]) + ') not supported!') + + # Check affine if image is nifti image + if img_filename[-7:] == ".nii.gz" or img_filename[-4:] == ".nii": + if not check_affine_in_nifti(orig, logger=logger): + sys.exit("ERROR: inconsistency in nifti-header. Exiting now.\n") + + # conform + orig = conform(orig, interpol, conform_vox_size=_conform_vox_size) + + # Collect header and affine information + header_info = orig.header + affine_info = orig.affine + orig = np.asanyarray(orig.dataobj) + + return header_info, affine_info, orig + + +# Save image routine +def save_image(header_info, affine_info, img_array, save_as, dtype=None): + """ + Save an image (nibabel MGHImage), according to the desired output file format. + Supported formats are defined in supported_output_file_formats. + :param numpy.ndarray img_array: an array containing image data + :param numpy.ndarray affine_info: image affine information + :param nibabel.freesurfer.mghformat.MGHHeader header_info: image header information + :param str save_as: name under which to save prediction; this determines output file format + :param type dtype: image array type; if provided, the image object is explicitly set to match this type + :return None: saves predictions to save_as + """ + + assert any(save_as.endswith(file_ext) for file_ext in SUPPORTED_OUTPUT_FILE_FORMATS), \ + 'Output filename does not contain a supported file format (' + ', '.join( + file_ext for file_ext in SUPPORTED_OUTPUT_FILE_FORMATS) + ')!' + + mgh_img = None + if save_as.endswith('mgz'): + mgh_img = nib.MGHImage(img_array, affine_info, header_info) + elif any(save_as.endswith(file_ext) for file_ext in ['nii', 'nii.gz']): + mgh_img = nib.nifti1.Nifti1Pair(img_array, affine_info, header_info) + + if dtype is not None: + mgh_img.set_data_dtype(dtype) + + if any(save_as.endswith(file_ext) for file_ext in ['mgz', 'nii']): + nib.save(mgh_img, save_as) + elif save_as.endswith('nii.gz'): + # For correct outputs, nii.gz files should be saved using the nifti1 sub-module's save(): + nib.nifti1.save(mgh_img, save_as) + + +# Transformation for mapping +def transform_axial(vol, coronal2axial=True): + """ + Function to transform volume into Axial axis and back (RAS [coronal] to ASR [axial]) + :param np.ndarray vol: image volume to transform + :param bool coronal2axial: transform from coronal to axial = True (default), + transform from axial to coronal = False + :return: + """ + if coronal2axial: + return np.moveaxis(vol, [0, 1, 2], [1, 2, 0]) + else: + return np.moveaxis(vol, [0, 1, 2], [2, 0, 1]) + + +def transform_sagittal(vol, coronal2sagittal=True): + """ + Function to transform volume into Sagittal axis and back (RAS [coronal] to SAR [saggital]) + :param np.ndarray vol: image volume to transform + :param bool coronal2sagittal: transform from coronal to sagittal = True (default), + transform from sagittal to coronal = False + :return: + """ + if coronal2sagittal: + return np.moveaxis(vol, [0, 1, 2], [2, 1, 0]) + else: + return np.moveaxis(vol, [0, 1, 2], [2, 1, 0]) + + +# Thick slice generator (for eval) and blank slices filter (for training) +def get_thick_slices(img_data, slice_thickness : int = 3, pad : bool = True): + """ + Function to extract thick slices from the image + (feed slice_thickness preceeding and suceeding slices to network, + label only middle one) + :param np.ndarray img_data: 3D MRI image read in with nibabel + :param int slice_thickness: number of slices to stack on top and below slice of interest (default=3) + :return: + """ + + # Pad image if wanted + if pad: + img_data_pad = np.pad(img_data, ((0, 0), (0, 0), (slice_thickness, slice_thickness)), mode='edge') + else: + img_data_pad = img_data + + + h, w, d = img_data_pad.shape + img_data_pad = np.expand_dims(img_data_pad, axis=3) + img_data_thick = np.ndarray((h, w, 2*slice_thickness+1, 0), dtype=np.uint8) + + for slice_idx in range(slice_thickness, d-slice_thickness): + img_data_thick = np.append(img_data_thick, img_data_pad[:, :, slice_idx-slice_thickness:slice_idx+slice_thickness+1, :], axis=3) + + # (plane1, plane2, slice_thickness, no_thick_slices) -> (plane1, plane2, no_thick_slices, slice_thickness) + img_data_thick = np.transpose(img_data_thick, (0, 1, 3, 2)) + + return img_data_thick + + +def filter_blank_slices_thick(volume_list: list, label_vol: np.ndarray, threshold=50): + """ + Function to filter blank slices from the volume using the label volume + :param np.ndarray img_vol: orig image volume + :param np.ndarray label_vol: label images (ground truth) + :param np.ndarray weight_vol: weight corresponding to labels + :param int threshold: threshold for number of pixels needed to keep slice (below = dropped) + :return: + """ + # Get indices of all slices with more than threshold labels/pixels + select_slices = (np.sum((label_vol != 0)*1, axis=(0, 1)) > threshold) + + # Retain only slices with more than threshold labels/pixels + return_list = [] + for v in volume_list: + return_list.append(v[:, :, select_slices]) + + return return_list + + +# weight map generator +def create_weight_mask(mapped_aseg, max_weight=5, max_edge_weight=5, max_hires_weight=None, ctx_thresh=33, + mean_filter=False, cortex_mask=True, gradient=True): + """ + Function to create weighted mask - with median frequency balancing and edge-weighting + :param np.ndarray mapped_aseg: segmentation to create weight mask from + :param int max_weight: maximal weight on median weights (cap at this value) + :param int max_edge_weight: maximal weight on gradient weight (cap at this value) + :param int max_hires_weight: maximal weight on hires weight (cap at this value) + :param int ctx_thresh: label value of cortex (above = cortical parcels) + :param bool mean_filter: flag, set to add mean_filter mask (default = False) + :param bool cortex_mask: flag, set to create cortex weight mask (default=True) + :param bool gradient: flag, set to create gradient mask (default = True) + :return np.ndarray: weights + """ + labels, counts = np.unique(mapped_aseg, return_counts=True) + + for indx, label in enumerate(labels): + mapped_aseg[mapped_aseg==label] = indx + + # Median Frequency Balancing + class_wise_weights = np.median(counts) / counts + class_wise_weights[class_wise_weights > max_weight] = max_weight + (h, w, d) = mapped_aseg.shape + + weights_mask = np.reshape(class_wise_weights[mapped_aseg.ravel()], (h, w, d)) + + # Gradient Weighting + if gradient: + (gx, gy, gz) = np.gradient(mapped_aseg) + grad_weight = max_edge_weight * np.asarray( + np.power(np.power(gx, 2) + np.power(gy, 2) + np.power(gz, 2), 0.5) > 0, + dtype='float') + + weights_mask += grad_weight + + if max_hires_weight is not None: + # High-res Weighting + LOGGER.info(f"Adding hires weight mask deep sulci and WM with weight {max_hires_weight}") + mask1 = deep_sulci_and_wm_strand_mask(mapped_aseg, structure=np.ones((3, 3, 3)), ctx_thresh=ctx_thresh) + weights_mask += mask1 * max_hires_weight + + if cortex_mask: + LOGGER.info(f"Adding cortex mask with weight {max_hires_weight}") + mask2 = cortex_border_mask(mapped_aseg, structure=np.ones((3, 3, 3)), ctx_thresh=ctx_thresh) + weights_mask += mask2 * (max_hires_weight) // 2 + + if mean_filter: + weights_mask = uniform_filter(weights_mask, size=3) + + return weights_mask + + +def cortex_border_mask(label, structure, ctx_thresh=33): + """ + Function to erode the cortex of a given mri image to create + the inner gray matter mask (outer most cortex voxels) + :param np.ndarray label: ground truth labels. + :param np.ndarray structure: structuring element to erode with + :param int ctx_thresh: label value of cortex (above = cortical parcels) + :return: np.ndarray outer GM layer + """ + # create aseg brainmask, erode it and subtract from itself + bm = np.clip(label, a_max=1, a_min=0) + eroded = binary_erosion(bm, structure=structure) + diff_im = np.logical_xor(eroded, bm) + + # only keep values associated with the cortex + diff_im[(label <= ctx_thresh)] = 0 # > 33 (>19) = > 1002 in FS space (full (sag)), + print("Remaining voxels cortex border: ", np.unique(diff_im, return_counts=True)) + return diff_im + + +def deep_sulci_and_wm_strand_mask(volume, structure, iteration=1, ctx_thresh=33): + """ + Function to get a binary mask of deep sulci and small white matter strands + by using binary closing (erosion and dilation) + + :param np.ndarray volume: loaded image (aseg, label space) + :param np.ndarray structure: structuring element (e.g. np.ones((3, 3, 3))) + :param int iteration: number of times mask should be dilated + eroded (default=1) + :param int ctx_thresh: label value of cortex (above = cortical parcels) + :return np.ndarray: sulcus + wm mask + """ + # Binarize label image (cortex = 1, everything else = 0) + empty_im = np.zeros(shape=volume.shape) + empty_im[volume > ctx_thresh] = 1 # > 33 (>19) = >1002 in FS LUT (full (sag)) + + # Erode the image + eroded = binary_closing(empty_im, iterations=iteration, structure=structure) + + # Get difference between eroded and original image + diff_image = np.logical_xor(empty_im, eroded) + LOGGER.info(f"Remaining voxels sulci/wm strand: {np.unique(diff_image, return_counts=True)}") + return diff_image + + +# Label mapping functions (to aparc (eval) and to label (train)) +def read_classes_from_lut(lut_file): + """ + Function to read in FreeSurfer-like LUT table + :param str lut_file: path and name of FreeSurfer-style LUT file with classes of interest + Example entry: + ID LabelName R G B A + 0 Unknown 0 0 0 0 + 1 Left-Cerebral-Exterior 70 130 180 0 + :return pd.Dataframe: DataFrame with ids present, name of ids, color for plotting + """ + # Read in file + separator = {"tsv": "\t", "csv": ",", "txt": " "} + return pd.read_csv(lut_file, sep=separator[lut_file[-3:]]) + + +def map_label2aparc_aseg(mapped_aseg, labels): + """ + Function to perform look-up table mapping from sequential label space to LUT space + :param torch.Tensor mapped_aseg: label space segmentation (aparc.DKTatlas + aseg) + :param np.ndarray labels: list of labels defining LUT space + :return: + """ + if isinstance(labels, np.ndarray): + labels = torch.from_numpy(labels) + labels = labels.to(mapped_aseg.device) + return labels[mapped_aseg] + + +def clean_cortex_labels(aparc): + """ + Function to clean up aparc segmentations: + Map undetermined and optic chiasma to BKG + Map Hypointensity classes to one + Vessel to WM + 5th Ventricle to CSF + Remaining cortical labels to BKG + :param np.array aparc: + :return np.array: cleaned aparc + """ + aparc[aparc == 80] = 77 # Hypointensities Class + aparc[aparc == 85] = 0 # Optic Chiasma to BKG + aparc[aparc == 62] = 41 # Right Vessel to Right WM + aparc[aparc == 30] = 2 # Left Vessel to Left WM + aparc[aparc == 72] = 24 # 5th Ventricle to CSF + aparc[aparc == 29] = 0 # left-undetermined to 0 + aparc[aparc == 61] = 0 # right-undetermined to 0 + + aparc[aparc == 3] = 0 # Map Remaining Cortical labels to background + aparc[aparc == 42] = 0 + return aparc + + +def fill_unknown_labels_per_hemi(gt, unknown_label, cortex_stop): + """ + Function to replace label 1000 (lh unknown) and 2000 (rh unknown) with closest class for each voxel. + :param np.ndarray gt: ground truth segmentation with class unknown + :param int unknown_label: class label for unknown (lh: 1000, rh: 2000) + :param int cortex_stop: class label at which cortical labels of this hemi stop (lh: 2000, rh: 3000) + :return: + """ + # Define shape of image and dilation element + h, w, d = gt.shape + struct1 = generate_binary_structure(3, 2) + + # Get indices of unknown labels, dilate them to get closest sorrounding parcels + unknown = gt == unknown_label + unknown = (morphology.binary_dilation(unknown, struct1) ^ unknown) + list_parcels = np.unique(gt[unknown]) + + # Mask all subcortical structues (fill unknown with closest cortical parcels only) + mask = (list_parcels > unknown_label) & (list_parcels < cortex_stop) + list_parcels = list_parcels[mask] + + # For each closest parcel, blur label with gaussian filter (spread), append resulting blurred images + blur_vals = np.ndarray((h, w, d, 0), dtype=np.float) + for idx in range(len(list_parcels)): + aseg_blur = filters.gaussian_filter(1000 * np.asarray(gt == list_parcels[idx], dtype=np.float), sigma=5) + blur_vals = np.append(blur_vals, np.expand_dims(aseg_blur, axis=3), axis=3) + + # Get for each position parcel with maximum value after blurring (= closest parcel) + unknown = np.argmax(blur_vals, axis=3) + unknown = np.reshape(list_parcels[unknown.ravel()], (h, w, d)) + + # Assign the determined closest parcel to the unknown class (case-by-case basis) + mask = gt == unknown_label + gt[mask] = unknown[mask] + + return gt + + +def fuse_cortex_labels(aparc): + """ + Fuse cortical parcels on left/right hemisphere (reduce aparc classes) + :param np.ndarray aparc: anatomical segmentation with cortical parcels + :return: anatomical segmentation with reduced number of cortical parcels + """ + aparc_temp = aparc.copy() + + # Map undetermined classes + aparc = clean_cortex_labels(aparc) + + # Fill label unknown + if np.any(aparc == 1000): + aparc = fill_unknown_labels_per_hemi(aparc, 1000, 2000) + if np.any(aparc == 2000): + aparc = fill_unknown_labels_per_hemi(aparc, 2000, 3000) + + # De-lateralize parcels + cortical_label_mask = (aparc >= 2000) & (aparc <= 2999) + aparc[cortical_label_mask] = aparc[cortical_label_mask] - 1000 + + # Re-lateralize Cortical parcels in close proximity + aparc[aparc_temp == 2014] = 2014 + aparc[aparc_temp == 2028] = 2028 + aparc[aparc_temp == 2012] = 2012 + aparc[aparc_temp == 2016] = 2016 + aparc[aparc_temp == 2002] = 2002 + aparc[aparc_temp == 2023] = 2023 + aparc[aparc_temp == 2017] = 2017 + aparc[aparc_temp == 2024] = 2024 + aparc[aparc_temp == 2010] = 2010 + aparc[aparc_temp == 2013] = 2013 + aparc[aparc_temp == 2025] = 2025 + aparc[aparc_temp == 2022] = 2022 + aparc[aparc_temp == 2021] = 2021 + aparc[aparc_temp == 2005] = 2005 + + return aparc + + +def split_cortex_labels(aparc): + """ + Splot cortex labels to completely de-lateralize structures + :param np.ndarray aparc: anatomical segmentation and parcellation from network + :return np.ndarray: re-lateralized aparc + """ + # Post processing - Splitting classes + # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 1025 + rh_wm = get_largest_cc(aparc == 41) + lh_wm = get_largest_cc(aparc == 2) + rh_wm = regionprops(label(rh_wm, background=0)) + lh_wm = regionprops(label(lh_wm, background=0)) + centroid_rh = np.asarray(rh_wm[0].centroid) + centroid_lh = np.asarray(lh_wm[0].centroid) + + labels_list = np.array([1003, 1006, 1007, 1008, 1009, 1011, + 1015, 1018, 1019, 1020, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035]) + + for label_current in labels_list: + + label_img = label(aparc == label_current, connectivity=3, background=0) + + for region in regionprops(label_img): + + if region.label != 0: # To avoid background + + if np.linalg.norm(np.asarray(region.centroid) - centroid_rh) < np.linalg.norm( + np.asarray(region.centroid) - centroid_lh): + mask = label_img == region.label + aparc[mask] = label_current + 1000 + + # Quick Fixes for overlapping classes + aseg_lh = filters.gaussian_filter(1000 * np.asarray(aparc == 2, dtype=np.float32), sigma=3) + aseg_rh = filters.gaussian_filter(1000 * np.asarray(aparc == 41, dtype=np.float32), sigma=3) + + lh_rh_split = np.argmax(np.concatenate((np.expand_dims(aseg_lh, axis=3), np.expand_dims(aseg_rh, axis=3)), axis=3), + axis=3) + + # Problematic classes: 1026, 1011, 1029, 1019 + for prob_class_lh in [1011, 1019, 1026, 1029]: + prob_class_rh = prob_class_lh + 1000 + mask_prob_class = (aparc == prob_class_lh) | (aparc == prob_class_rh) + mask_lh = np.logical_and(mask_prob_class, lh_rh_split == 0) + mask_rh = np.logical_and(mask_prob_class, lh_rh_split == 1) + + aparc[mask_lh] = prob_class_lh + aparc[mask_rh] = prob_class_rh + + return aparc + + +def remove_suffix(input_string, suffix): + if suffix and input_string.endswith(suffix): + return input_string[:-len(suffix)] + return input_string + +def unify_lateralized_labels(lut, combi=("Left-", "Right-")): + """ + Function to generate lookup dictionary of left-right labels + :param str or pd.DataFrame lut: either lut-file string to load or pandas dataframe + Example entry: + ID LabelName R G B A + 0 Unknown 0 0 0 0 + 1 Left-Cerebral-Exterior 70 130 180 0 + :param list(str) combi: Prefix or labelnames to combine. Default: Left- and Right- + :return dict: dictionary mapping between left and right hemispheres + """ + if isinstance(lut, str): + lut = read_classes_from_lut(lut) + left = lut[["ID", "LabelName"]][lut["LabelName"].str.startswith(combi[0])] + right = lut[["ID", "LabelName"]][lut["LabelName"].str.startswith(combi[1])] + #left["LabelName"] = left["LabelName"].str.removeprefix(combi[0]) # only for python3.9 and above + #right["LabelName"] = right["LabelName"].str.removeprefix(combi[1]) + left["LabelName"] = left["LabelName"].str.replace(combi[0], "") + right["LabelName"] = right["LabelName"].str.replace(combi[1], "") + mapp = left.merge(right, on="LabelName") + return pd.Series(mapp.ID_y.values, index=mapp.ID_x).to_dict() + + +def get_labels_from_lut(lut, label_extract=("Left-", "ctx-rh")): + """ + Function to extract + :param str of pd.DataFrame lut: FreeSurfer like LookUp Table (either path to it + or already loaded as pandas DataFrame. + Example entry: + ID LabelName R G B A + 0 Unknown 0 0 0 0 + 1 Left-Cerebral-Exterior 70 130 180 0 + :param tuple(str) label_extract: suffix of label names to mask for sagittal labels + Default: "Left-" and "ctx-rh" + :return np.ndarray: full label list + :return np.ndarray: sagittal label list + """ + if isinstance(lut, str): + lut = read_classes_from_lut(lut) + mask = lut["LabelName"].str.startswith(label_extract) + return lut["ID"].values, lut["ID"][~mask].values + + +def map_aparc_aseg2label(aseg, labels, labels_sag, sagittal_lut_dict, aseg_nocc=None, processing="aparc"): + """ + Function to perform look-up table mapping of aparc.DKTatlas+aseg.mgz data to label space + :param np.ndarray aseg: ground truth aparc+aseg + :param np.ndarray labels: labels to use (extracted from LUT with get_labels_from_lut) + :param np.ndarray labels_sag: sagittal labels to use (extracted from LUT with + get_labels_from_lut) + :param dict(int) sagittal_lut_dict: left-right label mapping (can be extracted with + unify_lateralized_labels from LUT) + :param str processing: should be set to "aparc" or "aseg" for additional mappings (hard-coded) + :param None/np.ndarray aseg_nocc: ground truth aseg without corpus callosum segmentation + :return: + """ + # If corpus callosum is not removed yet, do it now + if aseg_nocc is not None: + cc_mask = (aseg >= 251) & (aseg <= 255) + aseg[cc_mask] = aseg_nocc[cc_mask] + + if processing == "aparc": + LOGGER.info("APARC PROCESSING") + aseg = fuse_cortex_labels(aseg) + + elif processing == "aseg": + LOGGER.info("ASEG PROCESSING") + aseg[aseg == 1000] = 3 # Map unknown to cortex + aseg[aseg == 2000] = 42 + aseg[aseg == 80] = 77 # Hypointensities Class + aseg[aseg == 85] = 0 # Optic Chiasma to BKG + aseg[aseg == 62] = 41 # Right Vessel to Right WM + aseg[aseg == 30] = 2 # Left Vessel to Left WM + aseg[aseg == 72] = 24 # 5th Ventricle to CSF + + assert not np.any(251 <= aseg), "Error: CC classes (251-255) still exist in aseg {}".format(np.unique(aseg)) + assert np.any(aseg == 3) and np.any(aseg == 42), "Error: no cortical marker detected {}".format(np.unique(aseg)) + elif processing == "cc": + LOGGER.info("CC PROCESSING") + + assert set(labels).issuperset( + np.unique(aseg)), "Error: segmentation image contains classes not listed in the labels: \n{}\n{}".format( + np.unique(aseg), labels) + + h, w, d = aseg.shape + lut_aseg = np.zeros(max(labels) + 1, dtype='uint16') + for idx, value in enumerate(labels): + lut_aseg[value] = idx + + # Remap Label Classes - Perform LUT Mapping - Coronal, Axial + mapped_aseg = lut_aseg.ravel()[aseg.ravel()] + mapped_aseg = mapped_aseg.reshape((h, w, d)) + + if processing == "aparc": + cortical_label_mask = (aseg >= 2000) & (aseg <= 2999) + aseg[cortical_label_mask] = aseg[cortical_label_mask] - 1000 + + # For sagittal, all Left hemispheres will be mapped to right, ctx the otherway round + # If you use your own LUT, make sure all per-hemi labels have the corresponding prefix + # Map Sagittal Labels + for left, right in sagittal_lut_dict.items(): + aseg[aseg == left] = right + + h, w, d = aseg.shape + lut_aseg = np.zeros(max(labels_sag) + 1, dtype='uint16') + for idx, value in enumerate(labels_sag): + lut_aseg[value] = idx + + # Remap Label Classes - Perform LUT Mapping - Coronal, Axial + mapped_aseg_sag = lut_aseg.ravel()[aseg.ravel()] + mapped_aseg_sag = mapped_aseg_sag.reshape((h, w, d)) + + return mapped_aseg, mapped_aseg_sag + + +def sagittal_coronal_remap_lookup(x): + """ + Dictionary mapping to convert left labels to corresponding right labels for aseg + :param int x: label to look up + :return: + """ + return { + 2: 41, + 3: 42, + 4: 43, + 5: 44, + 7: 46, + 8: 47, + 10: 49, + 11: 50, + 12: 51, + 13: 52, + 17: 53, + 18: 54, + 26: 58, + 28: 60, + 31: 63, + }[x] + + +def infer_mapping_from_lut(num_classes_full, lut): + labels, labels_sag = unify_lateralized_labels(lut) + idx_list = np.ndarray(shape=(num_classes_full,), dtype=np.int16) + for idx in range(len(labels)): + idx_in_sag = np.where(labels_sag == labels[idx])[0] + if idx_in_sag.size == 0: # Empty not subcortical + idx_in_sag = np.where(labels_sag == (labels[idx] - 1000))[0] + + if idx_in_sag.size == 0: + current_label_sag = sagittal_coronal_remap_lookup(labels[idx]) + idx_in_sag = np.where(labels_sag == current_label_sag)[0] + + idx_list[idx] = idx_in_sag + return idx_list + + +def map_prediction_sagittal2full(prediction_sag, num_classes=51, lut=None): + """ + Function to remap the prediction on the sagittal network to full label space used by coronal and axial networks + (full aparc.DKTatlas+aseg.mgz) + :param prediction_sag: sagittal prediction (labels) + :param int num_classes: number of SAGGITAL classes (96 for full classes, 51 for hemi split, 21 for aseg) + :param str/None lut: look-up table listing class labels + :return: Remapped prediction + """ + if num_classes == 96: + idx_list = np.asarray([0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1, 2, 3, 14, 15, 4, 16, + 17, 18, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], dtype=np.int16) + + elif num_classes == 51: + idx_list = np.asarray([0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1, 2, 3, 14, 15, 4, 16, + 17, 18, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 20, 22, 27, + 29, 30, 31, 33, 34, 38, 39, 40, 41, 42, 45], dtype=np.int16) + + elif num_classes == 21: + idx_list = np.asarray([0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1, 2, 3, 15, 16, 4, + 17, 18, 19, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20], dtype=np.int16) + elif num_classes == 1: + idx_list = np.asarray([0, 1, 1]) + elif num_classes == 3: + idx_list = np.asarray([0, 1, 2], dtype=np.int16) + elif num_classes == 2: + idx_list = np.asarray([0, 1], dtype=np.int16) + else: + assert lut is not None, 'lut is not defined!' + idx_list = infer_mapping_from_lut(num_classes, lut) + prediction_full = prediction_sag[:, idx_list, :, :] + return prediction_full + + +# Clean up and class separation +def bbox_3d(img): + """ + Function to extract the three-dimensional bounding box coordinates. + :param np.ndarray img: mri image + :return: + """ + + r = np.any(img, axis=(1, 2)) + c = np.any(img, axis=(0, 2)) + z = np.any(img, axis=(0, 1)) + + rmin, rmax = np.where(r)[0][[0, -1]] + cmin, cmax = np.where(c)[0][[0, -1]] + zmin, zmax = np.where(z)[0][[0, -1]] + + return rmin, rmax, cmin, cmax, zmin, zmax + + +def get_largest_cc(segmentation): + """ + Function to find largest connected component of segmentation. + :param np.ndarray segmentation: segmentation + :return: + """ + labels = label(segmentation, connectivity=3, background=0) + + bincount = np.bincount(labels.flat) + background = np.argmax(bincount) + bincount[background] = -1 + + largest_cc = labels == np.argmax(bincount) + + return largest_cc + + +def map_incrementing(mapped_aseg, lut): + """Map labels to an incrementing space""" + for idx, label in enumerate(lut['ID']): + mapped_aseg[mapped_aseg == label] = idx + return mapped_aseg + +def subset_volume_plane(volume: np.ndarray, plane: Union[int, float] = 128, thickness: int = 5, axis: int = 0): + """Return a subset of the volume around the plane + + Args: + volume (np.ndarray): volume to subset + plane (Union[int, float], optional): plane of the subset. Defaults to 127.5. + thickness (int, optional): Total thickness of the subset. Defaults to 7. If odd, the plane should be inbetween two indices. + axis (int, optional): axis of the plane. Defaults to 0. + + Returns: + np.ndarray: subset of the volume. + + Raises: + ValueError: If the thickness and plane combination is unknown. + ValueError: If the axis is unknown. + + Notes: + If plane is an integer, it is included in the subset, the subset will need to have an odd thickness. Otherwise, the subset will have an even thickness. + """ + assert 2*plane == int(2*plane), "Plane must be an integer or between two indices" + assert axis in [0, 1, 2], "axis must be 0, 1 or 2" + + # Calculate the lower and upper bounds of the subset + lower_bound, upper_bound = 0, 0 + if (thickness % 2 == 1) and plane == int(plane): + lower_bound = int(plane - thickness//2) + upper_bound = int(plane + thickness//2 + 1) + elif (thickness % 2 == 0) and plane != int(plane): + lower_bound = int(np.ceil(plane - thickness/2)) + upper_bound = int(np.floor(plane + thickness/2)) + else: + raise ValueError("Unknown thickness and plane combination.") + + + if axis == 0: + return volume[lower_bound: upper_bound, :, :] + elif axis == 1: + return volume[:, lower_bound: upper_bound, :] + elif axis == 2: + return volume[:, :, lower_bound: upper_bound] + else: + raise ValueError("Unknown axis") + +def pad_to_size(volume: np.ndarray, size: tuple, plane: Union[int, float] = 128, axis: int = 0): + """ Pad a volume to a given size. Functions as a reverse to subset_volume_plane + + Args: + volume (np.ndarray): volume to pad + size (int): size to pad to + plane (Union[int, float], optional): plane of the subset. Defaults to 128. + axis (int, optional): axis of the plane. Defaults to 0. + + """ + thickness = volume.shape[axis] + + pad_left = int(plane - thickness//2) + pad_right = int(size[axis] - pad_left - thickness) + + assert pad_left >= 0, "Volume is larger than the size" + assert pad_right >= 0, "Volume is larger than the size" + assert pad_left + pad_right + thickness == size[axis] , "Padding is not correct" + + if (size[2]-volume.shape[2]) != 0: + print("Padding volume") + + if (size[1]-volume.shape[1]) != 0: + print("Padding volume") + + # Calculate the lower and upper bounds of the subset + if axis == 0: + return np.pad(volume, ((pad_left, pad_right), (0, size[1]-volume.shape[1]), (0, size[2]-volume.shape[2])), mode='constant', constant_values=0) + elif axis == 1: + return np.pad(volume, ((0, size[0]-volume.shape[0]), (pad_left, pad_right), (0, size[2]-volume.shape[2])), mode='constant', constant_values=0) + elif axis == 2: + return np.pad(volume, ((0, size[0]-volume.shape[0]), (0, size[1]-volume.shape[1]), (pad_left, pad_right)), mode='constant', constant_values=0) diff --git a/CCNet/data_loader/dataset.py b/CCNet/data_loader/dataset.py new file mode 100644 index 00000000..267587f8 --- /dev/null +++ b/CCNet/data_loader/dataset.py @@ -0,0 +1,799 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import time + +import h5py +import numpy as np +import torch +from torch.utils.data import Dataset +import torchio as tio +import copy + +from CCNet.data_loader import data_utils as du +from FastSurferCNN.utils import logging + +logger = logging.getLogger(__name__) + +from abc import ABC, abstractmethod + +class VINNDataset(Dataset, ABC): + + @abstractmethod + def __init__(self): + self.images = [] + self.labels = [] + self.weights = [] + self.subjects = [] + self.zooms = [] + self.default_size = [] + self.com = [] + self.localisation = False + + raise NotImplementedError("InferenceDataset is an abstract class and should not be instantiated") + + def unify_imgs(self, input_list, padsize = None): + if padsize is None: + padsize = [self.default_size for _ in input_list] + + output = [] + for i in range(len(input_list)): + output.append(self._pad(input_list[i], padsize[i])) + return output + + def _pad(self, image, padsize = None): + if padsize is None: + padsize = self.default_size + + dif = lambda i : np.max([0, (padsize[i] - image.shape[i])]) + if len(image.shape) == 2: + padded_img = np.pad(image, pad_width=((0, dif(0)), (0, dif(1))), mode='constant', constant_values = 0) + else: + padded_img = np.pad(image, pad_width=((0, dif(0)), (0, dif(1)), (0, dif(2))), mode='constant', constant_values = 0) + + return padded_img + + def _get_scale_factor(self, img_zoom): + """ + Get scaling factor to match original resolution of input image to + final resolution of FastSurfer base network. Input resolution is + taken from voxel size in image header. + + TODO: This needs to be updated based on the plane we are looking at in case we + are dealing with non-isotropic images as inputs. + :param img_zoom: + :return np.ndarray(float32): scale factor along x and y dimension + """ + scale = self.base_res / img_zoom + return scale + + def load_subset(self, hf, size, num_slices=5): + """ + Same as load data, but only loads a few slices per resolution for faster debugging + + :param num_slices: number of slices to load per resolution + """ + #start = time.time() + #logger.info(f"Processing images of size {size}.") + + self.images.extend(list(hf[f'{size}']['orig_dataset'][:num_slices])) + #logger.info("Processed origs of size {} in {:.3f} seconds".format(size, time.time()-start)) + + if self.localisation: + # load comissures + self.com.extend(list(hf[f'{size}']['center_dataset'][:num_slices])) + #logger.info("Processed comissures of size {} in {:.3f} seconds".format(size, time.time() - start)) + + self.labels.extend(list(hf[f'{size}']['aseg_dataset'][:num_slices])) + #logger.info("Processed asegs of size {} in {:.3f} seconds".format(size, time.time()-start)) + + self.weights.extend(list(hf[f'{size}']['weight_dataset'][:num_slices])) + #logger.info("Processed weights of size {} in {:.3f} seconds".format(size, time.time() - start)) + + self.zooms.extend(list(hf[f'{size}']['zoom_dataset'][:num_slices])) + #logger.info("Processed zooms of size {} in {:.3f} seconds".format(size, time.time() - start)) + + self.subjects.extend(list(hf[f'{size}']['subject'][:num_slices])) + #logger.info("Processed subjects of size {} in {:.3f} seconds".format(size, time.time()-start)) + + assert len(self.images) == len(self.labels) == len(self.weights) == len(self.zooms) == len(self.subjects), "Number of images, labels, weights, zooms and subjects are not equal" + logger.info(f"New number of slices is {len(self.images)}") + + # load max size + max = [self.default_size] + max.append(np.max([image.shape for image in self.images], axis=0)) + self.pad_shape_images = np.max(max, axis=0) + + max = [self.default_size[:2]] + max.append(np.max([label.shape for label in self.labels], axis=0)) + self.pad_shape_labels = np.max(max, axis=0) + + max = [self.default_size[:2]] + max.append(np.max([weight.shape for weight in self.weights], axis=0)) + self.pad_shape_weights = np.max(max, axis=0) + + def load_data(self, hf, size): + """ + Load data from h5 file + :param hf: h5 file + :param size: size of the data + :return: + """ + start = time.time() + self.zooms.extend(list(hf[f'{size}']['zoom_dataset'])) + logger.info("Processed zooms of size {} in {:.3f} seconds".format(size, time.time() - start)) + + start = time.time() + logger.info(f"Processing images of size {size}.") + + self.images.extend(list(hf[f'{size}']['orig_dataset'])) + logger.info("Processed origs of size {} in {:.3f} seconds".format(size, time.time() - start)) + + if self.localisation: + # load comissures + self.com.extend(list(hf[f'{size}']['center_dataset'])) + logger.info("Processed comissures of size {} in {:.3f} seconds".format(size, time.time() - start)) + + self.labels.extend(list(hf[f'{size}']['aseg_dataset'])) + logger.info("Processed asegs of size {} in {:.3f} seconds".format(size, time.time() - start)) + + self.weights.extend(list(hf[f'{size}']['weight_dataset'])) + logger.info("Processed weights of size {} in {:.3f} seconds".format(size, time.time() - start)) + + self.subjects.extend(list(hf[f'{size}']['subject'])) + logger.info("Processed subjects of size {} in {:.3f} seconds".format(size, time.time() - start)) + + logger.info(f"Number of slices for size {size} is {len(self.images)}") + + # load max size + max = [self.default_size] + max.append(np.max([image.shape for image in self.images], axis=0)) + self.pad_shape_images = np.max(max, axis=0) + + max = [self.default_size[:2]] + max.append(np.max([label.shape for label in self.labels], axis=0)) + self.pad_shape_labels = np.max(max, axis=0) + + max = [self.default_size[:2]] + max.append(np.max([weight.shape for weight in self.weights], axis=0)) + self.pad_shape_weights = np.max(max, axis=0) + + def __len__(self): + return len(self.images) + + def get_subject_names(self): + return self.subjects + + def __getitem__(self, index) -> dict: + padded_img, padded_weight, padded_label = self.unify_imgs([self.images[index], self.weights[index], self.labels[index]], [self.pad_shape_images, self.pad_shape_weights, self.pad_shape_labels]) + + padded_img = padded_img.transpose((2, 0, 1)) # move slice thickness to first spatial dimension + padded_img = torch.clamp(torch.from_numpy(padded_img).float() / padded_img.max(), 0, 1) # TODO: if we keep the image as int the data augmentation it will be faster + padded_weight = torch.from_numpy(padded_weight).float() + padded_label = torch.from_numpy(padded_label) + + orig_img = padded_img[padded_img.shape[0]//2].detach().clone() + + scale_factor = self._get_scale_factor(torch.from_numpy(self.zooms[index])) + + return {'image': padded_img, 'label': padded_label, 'weight': padded_weight, + 'scale_factor': scale_factor, 'unmodified_center_slice': orig_img, + 'subject_id': self.subjects[index]} + + +# Operator to load images for inference +class MultiScaleOrigDataThickSlices(VINNDataset): + """ + Class to load MRI-Image and process it to correct format for network inference + """ + def __init__(self, img_filename, orig_data, orig_zoom, cfg, transforms=None, lesion_mask=None, pad = True): + assert orig_data.max() > 0.8, f"Multi Dataset - orig fail, max removed {orig_data.max()}" + self.img_filename = img_filename + self.plane = cfg.DATA.PLANE + self.slice_thickness = cfg.MODEL.NUM_CHANNELS//2 + self.base_res = cfg.MODEL.BASE_RES + self.default_size = cfg.DATA.PADDED_SIZE + + if self.plane == "sagittal": + orig_data = du.transform_sagittal(orig_data) + self.zoom = orig_zoom[::-1][:2] + logger.info("Loading Sagittal with input voxelsize {}".format(self.zoom)) + + if lesion_mask is not None: + lesion_mask = du.transform_sagittal(lesion_mask) + logger.info("Loading Sagittal lesion with input voxelsize {}".format(self.zoom)) + + elif self.plane == "axial": + orig_data = du.transform_axial(orig_data) + self.zoom = orig_zoom[::-1][:2] + logger.info("Loading Axial with input voxelsize {}".format(self.zoom)) + + if lesion_mask is not None: + lesion_mask = du.transform_axial(lesion_mask) + logger.info("Loading Axial lesion with input voxelsize {}".format(self.zoom)) + + else: + self.zoom = orig_zoom[:2] + logger.info("Loading Coronal with input voxelsize {}".format(self.zoom)) + + if lesion_mask is not None: + logger.info("Loading Coronal lesion with input voxelsize {}".format(self.zoom)) + + # Create thick slices + orig_thick = du.get_thick_slices(orig_data, self.slice_thickness, pad = pad) + orig_thick = np.transpose(orig_thick, (2, 0, 1, 3)) + self.images = orig_thick + + if lesion_mask is not None: + lesion_thick = du.get_thick_slices(lesion_mask, self.slice_thickness) + lesion_thick = np.transpose(lesion_thick, (2, 0, 1, 3)) + assert(lesion_thick.shape == orig_thick.shape), f"Lesion mask shape {lesion_thick.shape} does not match image shape {orig_thick.shape}" + self.lesion_mask = lesion_thick + + self.transforms = transforms + logger.info(f"Successfully loaded Image from {img_filename}") + + + def __getitem__(self, index): + img = self.images[index] + if hasattr(self, 'lesion_mask'): + lesion_mask = self.lesion_mask[index] + else: + lesion_mask = np.zeros(img.shape, dtype=bool) + + scale_factor = self._get_scale_factor(np.asarray(self.zoom)) + + output = {'image': img, 'cutout_mask': lesion_mask, 'scale_factor': scale_factor} + if self.transforms is not None: + output = self.transforms(output) + + assert(output['cutout_mask'].shape == output['image'].shape), f"Cutout mask shape {output['cutout_mask'].shape} does not match image shape {output['image'].shape}" + + return output + + +# Operator to load hdf5-file for training +class MultiScaleDatasetAux(VINNDataset): + """ + Class for loading aseg file with augmentations (transforms) + """ + def __init__(self, dataset_path, cfg, scale_aug=False, transforms=None): + self.default_size = cfg.DATA.PADDED_SIZE + self.base_res = cfg.MODEL.BASE_RES + self.scale_aug = scale_aug + + # Check if we are dealing with a localisation task + self.localisation = (cfg.MODEL.MODEL_NAME == "FastSurferLocalisation") + + # Load the h5 file and save it to the datase + self.images = [] + self.labels = [] + self.weights = [] + self.aux_labels = [] + self.subjects = [] + self.zooms = [] + + # Open file in reading mode + logger.info(f"Opening file {dataset_path} for reading...") + assert(h5py.is_hdf5(dataset_path)), f"File \"{dataset_path}\" is not a valid hdf5 file." + + start = time.time() + with h5py.File(dataset_path, "r") as hf: + for size in cfg.DATA.SIZES: + try: + if not cfg.TRAIN.DEBUG: + self.load_data(hf, size) + else: + self.load_subset(hf, size, num_slices=20) + + logger.info("Successfully loaded {} participant volumes from h5 file".format(len(self.subjects))) + + assert(len(self.images) == len(self.labels) == len(self.weights) == len(self.zooms) == len(self.aux_labels)), "Number of images, labels, weights, auxil and zooms are not equal." + + except KeyError as e: + print(f"KeyError: Unable to open object (object {size} does not exist)") + continue + + self.transforms = transforms + + logger.info("Successfully loaded {} slices from {} with plane {} in {:.3f} seconds".format(len(self.images), dataset_path, cfg.DATA.PLANE, time.time()-start)) + + # override this function to enable scale augmentation (could also be outside of this function) + def _get_scale_factor(self, img_zoom, scale_aug = torch.tensor(0.0)): + """ + Get scaling factor to match original resolution of input image to + final resolution of FastSurfer base network. Input resolution is + taken from voxel size in image header. + + TODO: This needs to be updated based on the plane we are looking at in case we + are dealing with non-isotropic images as inputs. + :param img_zoom: + :return np.ndarray(float32): scale factor along x and y dimension + """ + if torch.all(scale_aug > 0): + img_zoom *= (1 / scale_aug) + + scale = self.base_res / img_zoom + + if self.scale_aug: + scale += torch.randn(1) * 0.1 + 0 # needs to be changed to torch.tensor stuff + scale = torch.clamp(scale, min=0.1) + + return scale + + def apply_transforms(self, img, label, weight, aux_labels=None, cutout_mask=None, orig_slice=None, subject_id=None): + """ + apply transforms to the image and labels + :param img: image + :param label: label + :param weight: weight + :return: transformed image, label and weight, and composed torchio history + """ + img = img[None, ...] # add batch dimension for torchio + label = label[None, ...] + weight = weight[None, ...] + if orig_slice is not None: + orig_slice = orig_slice[None, ...] + if cutout_mask is not None: + cutout_mask = cutout_mask[None, ...] + if aux_labels is not None: + aux_labels = aux_labels[None, ...] + + # pad label and weight to match sptial dimensions of image + slice_thickness = img.shape[1] + label = np.pad(label, ((0, 0), (slice_thickness//2, slice_thickness//2), (0, 0), (0, 0)), 'constant', constant_values=0) + weight = np.pad(weight, ((0, 0), (slice_thickness//2, slice_thickness//2), (0, 0), (0, 0)), 'constant', constant_values=0) + #if cutout_mask is not None: + # cutout_mask = np.pad(cutout_mask, ((0, 0), (slice_thickness//2, slice_thickness//2), (0, 0), (0, 0)), 'constant', constant_values=0) + #if aux_labels is not None: + # aux_labels = np.pad(aux_labels, ((0, 0), (slice_thickness//2, slice_thickness//2), (0, 0), (0, 0)), 'constant', constant_values=0) + + + subject = tio.Subject({'image': tio.ScalarImage(tensor=img), + 'label': tio.LabelMap(tensor=label), + 'weight': tio.LabelMap(tensor=weight), + 'unmodified_center_slice': tio.ScalarImage(tensor=orig_slice) if orig_slice is not None else None, + 'cutout_mask': tio.LabelMap(tensor=cutout_mask), # TODO: this converts to UINT8, which is not what we want + 'aux_label': tio.LabelMap(tensor=aux_labels), + 'subject_id': subject_id + }) + tx_sample = self.transforms(subject) # this returns data as torch.tensors + + + + img = tx_sample['image'].data.float().squeeze(0) + label = tx_sample['label'].data.byte() + weight = tx_sample['weight'].data.float() + if aux_labels: + aux_labels = tx_sample['aux_label'].data.byte().squeeze(0) + #aux_labels = aux_labels[:, slice_thickness//2, :, :].squeeze(0) # retrieve middle slice + if orig_slice is not None: + orig_slice = tx_sample['unmodified_center_slice'].data.float().squeeze(0) + if 'cutout_mask' in tx_sample.keys(): + cutout_mask = tx_sample['cutout_mask'].data.bool().squeeze(0) + else: + cutout_mask = torch.zeros(orig_slice.size(), dtype=bool) + + label = label[:, slice_thickness//2, :, :].squeeze(0) # retrieve middle slice + weight = weight[:, slice_thickness//2, :, :].squeeze(0) + # if aux_labels is not None: + # aux_labels = aux_labels[:, slice_thickness//2, :, :].squeeze(0) + + return img, label, weight, aux_labels, orig_slice, cutout_mask, tx_sample.get_composed_history() + + + def __getitem__(self, index): + sample = super().__getitem__(index) + img = sample['image'] + label = sample['label'] + weight = sample['weight'] + subject_id = sample['subject_id'] + + if 'aux_data' in sample.keys(): + auxiliary_labels = sample['aux_data'] + else: + auxiliary_labels = None + + if 'cutout_mask' in sample.keys(): + cutout_mask = sample['cutout_mask'] + elif 'aux_data' in sample.keys(): + cutout_mask = auxiliary_labels == 3 + else: + cutout_mask = torch.zeros_like(img, dtype=bool) + + + #### DEBUG + + # if cutout_mask.sum() > 0: + # import matplotlib.pyplot as plt + # plt.figure() + # fig, ax = plt.subplots(1, 1) + # ax.imshow(img[img.shape[0]//2]) + # ax.imshow(cutout_mask[cutout_mask.shape[0]//2], alpha=0.5) + # plt.show() + + + if self.transforms is not None: + del(sample['unmodified_center_slice']) + orig_slice = img.detach().clone() # we need to get the full thickslice to replicate data augmentation on the original slice (i.e. rotation of the plane) + label = label[np.newaxis, :, :] # add slice thickness dimension of one + #auxiliary_labels = auxiliary_labels[np.newaxis, :, :] + + weight = weight[np.newaxis, :, :] + zoom_aug = torch.as_tensor([0., 0.]) + + if self.aux_labels: + img, label, weight, auxiliary_labels, orig_slice, cutout_mask, rep_tf = self.apply_transforms(img, label, weight, aux_labels=auxiliary_labels, orig_slice=orig_slice, cutout_mask=cutout_mask, subject_id=sample['subject_id']) + else: + img, label, weight, auxiliary_labels, orig_slice, cutout_mask, rep_tf = self.apply_transforms(img, label, weight, aux_labels=auxiliary_labels, orig_slice=orig_slice, cutout_mask=cutout_mask, subject_id=sample['subject_id']) + + + if rep_tf and 'scales' in rep_tf[0]._get_reproducing_arguments().keys(): # get updated scalefactor, incase of scaling + zoom_aug += torch.as_tensor(rep_tf[0]._get_reproducing_arguments()["scales"])[:-1] + + # Normalize image and clamp between 0 and 1 again after data augmentation + img = torch.clamp(img / orig_slice.max(), min=0.0, max=1.0) # use original slice to normalize + orig_slice = torch.clamp(orig_slice / orig_slice.max(), min=0.0, max=1.0) + orig_slice = orig_slice[orig_slice.shape[0]//2] + #cutout_mask = cutout_mask[cutout_mask.shape[0]//2] + + + #assert(torch.max(img) == 1.0 and torch.min(img) == 0.0), "Image is not normalized between 0 and 1, but between {} and {}".format(torch.min(img), torch.max(img)) + else: + orig_slice = sample['unmodified_center_slice'] + zoom_aug = torch.as_tensor([0., 0.]) + #cutout_mask = sample['cutout_mask'] # TODO: is not in sample yet + + scale_factor = self._get_scale_factor(torch.from_numpy(self.zooms[index]), scale_aug=zoom_aug) + + assert(cutout_mask.shape == img.shape), f"Cutout mask shape {cutout_mask.shape} does not match image shape {img.shape}" + + return {'image': img, 'label': label, 'weight': weight, 'aux_labels': auxiliary_labels, + 'scale_factor': scale_factor, 'unmodified_center_slice': orig_slice, 'cutout_mask': cutout_mask, 'subject_id': subject_id} + + +class MultiScaleDataset(VINNDataset): + """ + Class for loading aseg file with augmentations (transforms) + """ + + def __init__(self, dataset_path, cfg, scale_aug=False, transforms=None): + self.default_size= cfg.DATA.PADDED_SIZE + self.base_res = cfg.MODEL.BASE_RES + self.scale_aug = scale_aug + + self.localisation = cfg.MODEL.MODEL_NAME == "FastSurferLocalisation" + + # Load the h5 file and save it to the datase + self.images = [] + self.labels = [] + self.weights = [] + self.subjects = [] + self.zooms = [] + self.com = [] + + # Open file in reading mode + logger.info(f"Opening file {dataset_path} for reading...") + import os + cwd = os.path.realpath('.') + assert (h5py.is_hdf5(dataset_path)), f"File \"{dataset_path}\" is not a valid hdf5 file." + + start = time.time() + with h5py.File(dataset_path, "r") as hf: + sizes = hf.keys() if cfg.DATA.SIZES is None else cfg.DATA.SIZES + for size in hf.keys(): + try: + if not cfg.TRAIN.DEBUG: + self.load_data(hf, size) + else: + self.load_subset(hf, size, num_slices=10) + + logger.info("Successfully loaded {} participant volumes from h5 file".format(len(self.subjects))) + + assert (len(self.images) == len(self.labels) == len(self.weights) == len(self.zooms)), "Number of images, labels, weights and zooms are not equal." + #assert np.array([self.labels[0].shape == label.shape for label in self.labels]).all(), "Not all labels have the same shape" + + + except KeyError as e: + print(f"KeyError: Unable to open object (object {size} does not exist)") + continue + + + # TODO: CHECK + #self.transforms = transforms + self.transforms = None + + logger.info("Successfully loaded {} slices from {} with plane {} in {:.3f} seconds".format(len(self.images), + dataset_path, + cfg.DATA.PLANE, + time.time() - start)) + + # override this function to enable scale augmentation (could also be outside of this function) + def _get_scale_factor(self, img_zoom, scale_aug=torch.tensor(0.0)): + """ + Get scaling factor to match original resolution of input image to + final resolution of FastSurfer base network. Input resolution is + taken from voxel size in image header. + + TODO: This needs to be updated based on the plane we are looking at in case we + are dealing with non-isotropic images as inputs. + :param img_zoom: + :return np.ndarray(float32): scale factor along x and y dimension + """ + if torch.all(scale_aug > 0): + img_zoom *= (1 / scale_aug) + + scale = self.base_res / img_zoom + + if self.scale_aug: + scale += torch.randn(1) * 0.1 + 0 # needs to be changed to torch.tensor stuff + scale = torch.clamp(scale, min=0.1) + + return scale + + def apply_transforms(self, img, label, weight, cutout_mask=None, orig_slice=None, subject_id=None): + """ + apply transforms to the image and labels + :param img: image + :param label: label + :param weight: weight + :return: transformed image, label and weight, and composed torchio history + """ + img = img[None, ...] # add batch dimension for torchio + label = label[None, ...] + weight = weight[None, ...] + if orig_slice is not None: + orig_slice = orig_slice[None, ...] + if cutout_mask is not None: + cutout_mask = cutout_mask[None, ...] + + # pad label and weight to match sptial dimensions of image + slice_thickness = img.shape[1] + label = np.pad(label, ((0, 0), (slice_thickness // 2, slice_thickness // 2), (0, 0), (0, 0)), 'constant', + constant_values=0) + weight = np.pad(weight, ((0, 0), (slice_thickness // 2, slice_thickness // 2), (0, 0), (0, 0)), 'constant', + constant_values=0) + + subject = tio.Subject({'image': tio.ScalarImage(tensor=img), + 'label': tio.LabelMap(tensor=label), + 'weight': tio.LabelMap(tensor=weight), + 'unmodified_center_slice': tio.ScalarImage( + tensor=orig_slice) if orig_slice is not None else None, + 'cutout_mask': tio.LabelMap(tensor=cutout_mask), + # TODO: this converts to UINT8, which is not what we want + + 'subject_id': subject_id + }) + + tx_sample = self.transforms(subject) # this returns data as torch.tensors + + img = tx_sample['image'].data.float().squeeze(0) + label = tx_sample['label'].data.byte() + weight = tx_sample['weight'].data.float() + if orig_slice is not None: + orig_slice = tx_sample['unmodified_center_slice'].data.float().squeeze(0) + if 'cutout_mask' in tx_sample.keys(): + cutout_mask = tx_sample['cutout_mask'].data.bool().squeeze(0) + else: + cutout_mask = torch.zeros(orig_slice.size(), dtype=bool) + + label = label[:, slice_thickness // 2, :, :].squeeze(0) # retrieve middle slice + weight = weight[:, slice_thickness // 2, :, :].squeeze(0) + + return img, label, weight, orig_slice, cutout_mask, tx_sample.get_composed_history() + + def __getitem__(self, index): + sample = super().__getitem__(index) + img = sample['image'] + label = sample['label'] + weight = sample['weight'] + subject_id = sample['subject_id'] + + if 'cutout_mask' in sample.keys(): + cutout_mask = sample['cutout_mask'] + else: + cutout_mask = torch.zeros_like(img, dtype=bool) + + if self.localisation: + com = torch.as_tensor(self._pad(self.com[index], self.pad_shape_labels)) + label = torch.cat((label, com), dim=0) + + if self.transforms is not None: + del (sample['unmodified_center_slice']) + orig_slice = img.detach().clone() # we need to get the full thickslice to replicate data augmentation on the original slice (i.e. rotation of the plane) + label = label[np.newaxis, :, :] # add slice thickness dimension of one + + weight = weight[np.newaxis, :, :] + zoom_aug = torch.as_tensor([0., 0.]) + + img, label, weight, orig_slice, cutout_mask, rep_tf = self.apply_transforms(img, label, weight, orig_slice=orig_slice, subject_id=sample['subject_id'], cutout_mask=cutout_mask) + # ERROR: Labels loose weights + + if rep_tf and 'scales' in rep_tf[ + 0]._get_reproducing_arguments().keys(): # get updated scalefactor, incase of scaling + zoom_aug += torch.as_tensor(rep_tf[0]._get_reproducing_arguments()["scales"])[:-1] + + # Normalize image and clamp between 0 and 1 again after data augmentation + img = torch.clamp(img / orig_slice.max(), min=0.0, max=1.0) # use original slice to normalize + orig_slice = torch.clamp(orig_slice / orig_slice.max(), min=0.0, max=1.0) + orig_slice = orig_slice[orig_slice.shape[0] // 2] + + # assert(torch.max(img) == 1.0 and torch.min(img) == 0.0), "Image is not normalized between 0 and 1, but between {} and {}".format(torch.min(img), torch.max(img)) + else: + orig_slice = sample['unmodified_center_slice'] + zoom_aug = torch.as_tensor([0., 0.]) + + scale_factor = self._get_scale_factor(torch.from_numpy(self.zooms[index]), scale_aug=zoom_aug) + + assert(cutout_mask.shape == img.shape), f"Cutout mask shape {cutout_mask.shape} does not match image shape {img.shape}" + + + return {'image': img, 'label': label, 'weight': weight, + 'scale_factor': scale_factor, 'unmodified_center_slice': orig_slice, + 'cutout_mask': cutout_mask, 'subject_id': subject_id} + + +# # Operator to load hdf5-file for validation +class MultiScaleDatasetVal(VINNDataset): + """ + Class for loading aseg file with augmentations (transforms) + + only used for legacy FastSurferVINN + """ + #overrides + def __init__(self, dataset_path, cfg, transforms=None): + + self.default_size = cfg.DATA.PADDED_SIZE + self.base_res = cfg.MODEL.BASE_RES + + # Load the h5 file and save it to the dataset + self.images = [] + self.labels = [] + self.weights = [] + self.aux_labels = [] + self.subjects = [] + self.zooms = [] + + # Open file in reading mode + start = time.time() + logger.info(f"Opening file {dataset_path} for reading...") + assert(h5py.is_hdf5(dataset_path)), f"File \"{dataset_path}\" is not a valid hdf5 file." + with h5py.File(dataset_path, "r") as hf: + for size in hf.keys(): # iterate over image sizes + if not cfg.TRAIN.DEBUG: + self.load_data(hf, size) + else: + self.load_subset(hf, size) + + + self.transforms = transforms + logger.info("Successfully loaded {} slices from {} with plane {} in {:.3f} seconds".format(len(self.images), + dataset_path, + cfg.DATA.PLANE, + time.time()-start)) + + def __getitem__(self, index) -> dict: + sample = super().__getitem__(index) + + if self.transforms is not None: + sample = self.transforms(sample) + + if 'aux_data' in sample.keys(): + sample['cutout_mask'] = sample['aux_data'] == 3 + if sample['cutout_mask'].sum() > 0: + #mid_slice = sample['image'][sample['image'].shape[0]//2] + sample['image'][sample['cutout_mask']] = 0 + #sample['image'][sample['image'].shape[0]//2] = mid_slice + else: + sample['cutout_mask'] = torch.zeros(sample['label'].size(), dtype=bool) + return sample + + +# class InpaintingDataset(VINNDataset): +# """ +# Class for loading aseg file with augmentations (transforms) +# """ +# #overrides +# def __init__(self, original_dataset): +# start = time.time() + +# self.default_size = original_dataset.max_size +# self.base_res = original_dataset.base_res + + + +# # Load the h5 file and save it to the dataset +# self.images = copy.deepcopy(original_dataset.images) +# self.orig_images = original_dataset.images +# self.labels = original_dataset.labels +# self.aux_labels = original_dataset.aux_labels +# self.weights = original_dataset.weights +# self.subjects = original_dataset.subjects +# self.zooms = original_dataset.zooms +# self.cutout_masks = []#np.zeros((len(self.labels),self.labels[0].shape[0], self.labels[0].shape[1]), dtype=bool) + +# self.cutout_images() + + +# if original_dataset.transforms is not None: +# raise NotImplementedError("Transforms are not yet supported for inpainting dataset") +# #self.transforms = original_dataset.transforms + +# #assert((self.images[0] != original_dataset.images[0]).any()), "Inpainting dataset is a copy of the original dataset" +# assert(self.images[0] is not original_dataset.images[0]), 'images should point to different objects, but do not' +# assert(self.orig_images is original_dataset.images), 'orig_images and images from donor dataset should point to the same object, but do not' +# assert(self.labels[0] is original_dataset.labels[0]), 'labels should point to the same object, but do not' +# assert(self.aux_labels[0] is original_dataset.aux_labels[0]), 'aux_labels should point to the same object, but do not' + +# logger.info("Successfully augmented {} slices with cutout in {:.3f} seconds".format(len(self.images), time.time()-start)) + +# def cutout_images(self): +# "Duplicate loaded images and generate deterministic cutout for validation" +# # start_idx_of_cutout_imgs = len(self.images) + +# # self.images.extend(self.images.copy()) +# # self.labels.extend(self.labels.copy()) +# # self.weights.extend(self.weights.copy()) +# # self.zooms.extend(self.zooms.copy()) + + +# cutout_size = 5 #th of the image size + +# # iterate over all images and cutout a square move it on a grid left to right and top to bottom +# # cutout size is the grid size +# for i in range(len(self.images)): +# grid_x = i%cutout_size +# x1 = int(grid_x/cutout_size*self.images[i].shape[0]) +# x2 = int((grid_x+1)/cutout_size*self.images[i].shape[0]) + + +# grid_y = ((i)//cutout_size)%cutout_size +# y1 = int(grid_y/cutout_size*self.images[i].shape[1]) +# y2 = int((grid_y+1)/cutout_size*self.images[i].shape[1]) + +# #print(x1,x2,y1,y2) + +# self.images[i][x1:x2,y1:y2] = 0 +# self.weights[i][x1:x2,y1:y2] *= 0.5 +# self.cutout_masks.append(np.zeros(self.labels[i].shape, dtype=bool)) +# self.cutout_masks[i][x1:x2,y1:y2] = True + +# #self.images[i,i/cutout_size*self.images[i].shape[1],i*cutout_size*self.images[i].shape[1]] = 0.0 +# #self.weights[i,i*cutout_size*self.images[i].shape[1],i*cutout_size*self.images[i].shape[1]] *= 0.5 + +# def __getitem__(self, index) -> dict: +# padded_img, padded_label, padded_weight, padded_orig, padded_cutout_mask, padded_auxiliary_labels = self.unify_imgs([self.images[index], self.labels[index], self.weights[index], self.orig_images[index], self.cutout_masks[index], self.aux_labels[index]]) + +# padded_img = padded_img.transpose((2, 0, 1)) # move slice thickness to first spatial dimension +# padded_img = torch.clamp(torch.from_numpy(padded_img).float() / padded_orig.max(), 0, 1) # use original max to normalize + +# padded_orig = padded_orig.transpose((2, 0, 1)) # move slice thickness to first spatial dimension +# padded_orig = torch.clamp(torch.from_numpy(padded_orig).float() / padded_orig.max(), 0, 1) +# padded_orig = padded_orig[padded_orig.shape[0] // 2] # take center slice + + +# padded_weight = torch.from_numpy(padded_weight).float() +# padded_label = torch.from_numpy(padded_label) +# padded_auxiliary_labels = torch.from_numpy(padded_auxiliary_labels) +# padded_cutout_mask = torch.from_numpy(padded_cutout_mask).bool() + +# scale_factor = self._get_scale_factor(torch.from_numpy(self.zooms[index])) + +# return {'image': padded_img, 'label': padded_label, 'weight': padded_weight, +# 'scale_factor': scale_factor, 'unmodified_center_slice': padded_orig, +# 'cutout_mask': padded_cutout_mask, 'subject_id': self.subjects[index], 'aux_labels': padded_auxiliary_labels} + + \ No newline at end of file diff --git a/CCNet/data_loader/loader.py b/CCNet/data_loader/loader.py new file mode 100644 index 00000000..e7d06b67 --- /dev/null +++ b/CCNet/data_loader/loader.py @@ -0,0 +1,232 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +from torchvision import transforms +from torch.utils.data import DataLoader +import torchio as tio + +from CCNet.data_loader import dataset as dset +from CCNet.data_loader.augmentation import ToTensor, ZeroPad2D, \ + RandomizeScaleFactor, SmartRandomCutout, RandomCutout, CutoutRandomHemisphere, FlipLeftRight, \ + CutoutBRATSTumor, CutoutBRATSTumorDeterministic, CutoutTumorMask, RandGridDistortiond +from FastSurferCNN.utils import logging +#from monai.utils import GridSampleMode, GridSamplePadMode + +logger = logging.getLogger(__name__) + + + + + +def create_all_augmentations(plane: str): + """" + create a dictionary of all augmentations + """ + # Cutout + cutout = RandomCutout(include=['image'], downweighting_factor=0.5) + + if plane is not None: + hemisphere = CutoutRandomHemisphere(orientation=plane, include=['image']) + else: + hemisphere = None + + + # brats_tumor = CutoutBRATSTumor(tumor_mask_hdf5='../../data/tumor_masks.hdf5', + # include=['image', 'weight'], downweighting_factor=0.9, random=False) + + # Flip + flip = FlipLeftRight(orientation=plane, include=['image', 'label', 'weight', 'unmodified_center_slice', 'aux_labels', 'cutout_mask']) + + # NOTE: Don't use this!! Its super slow + # Elastic + # elastic = tio.RandomElasticDeformation(num_control_points=7, + # max_displacement=(0, 30, 30), + # locked_borders=2, + # image_interpolation='linear', + # include=['image', 'label', 'weight', 'unmodified_center_slice']) + + # Grid Distortion + grid_distortion = RandGridDistortiond(keys=['image', 'weight', 'label', 'unmodified_center_slice', 'cutout_mask'], + num_cells=20, prob=1, distort_limit=(-0.4, 0.4), mode='nearest', padding_mode='zeros') #mode=GridSampleMode.NEAREST, padding_mode=GridSamplePadMode.ZEROS), + + # Scales + scaling = tio.RandomAffine(scales=(0.8, 1.15), + degrees=0, + translation=(0, 0, 0), + isotropic=True, # If True, scaling factor along all dimensions is the same + center='image', + default_pad_value='minimum', + image_interpolation='linear', + include=['image', 'label', 'weight', 'unmodified_center_slice', 'aux_labels', 'cutout_mask']) + + # Rotation + rot = tio.RandomAffine(scales=(1.0, 1.0), + degrees=(20,0,0), + translation=(0, 0, 0), + isotropic=True, # If True, scaling factor along all dimensions is the same + center='image', + default_pad_value='minimum', + image_interpolation='linear', + include=['image', 'label', 'weight', 'unmodified_center_slice', 'aux_labels', 'cutout_mask']) + + # Translation + tl = tio.RandomAffine(scales=(1.0, 1.0), + degrees=0, + translation=(15.0, 15.0, 0), + isotropic=True, # If True, scaling factor along all dimensions is the same + center='image', + default_pad_value='minimum', + image_interpolation='linear', + include=['image', 'label', 'weight', 'unmodified_center_slice', 'aux_labels', 'cutout_mask']) + + # Customzied Affine with scaling and rotation for better performance + custom_affine = tio.RandomAffine(scales=(0.8, 1.15), + degrees=(10,0,0), + translation=(0, 0, 0), + isotropic=True, # If True, scaling factor along all dimensions is the same + center='image', + default_pad_value='minimum', + image_interpolation='linear', + include=['image', 'label', 'weight', 'unmodified_center_slice', 'aux_labels', 'cutout_mask']) + + # Random Anisotropy (Downsample image along an axis, then upsample back to initial space + ra = tio.transforms.RandomAnisotropy(axes=(0, 1), + downsampling=(1.1, 1.5), + image_interpolation="linear", + include=['image', 'unmodified_center_slice']) + + # Bias Field + bias_field = tio.transforms.RandomBiasField(coefficients=0.5, order=3, include=['image', 'unmodified_center_slice']) + + # Gamma + random_gamma = tio.transforms.RandomGamma(log_gamma=(-0.1, 0.1), include=['image', 'unmodified_center_slice']) + + return {"Cutout": cutout, + "Hemisphere": hemisphere, "Flip": flip, + "Scaling": scaling, "Rotation": rot, + "CustomAffine": custom_affine, "Translation": tl, + "RAnisotropy": ra, "BiasField": bias_field, "RGamma": random_gamma, + "GridDistortion": grid_distortion} + + +def get_dataloader(cfg, mode, val_dataset=None, data_path=None): + """ + Creating the dataset and pytorch data loader + + :param cfg: + :param mode: loading data for train, val and test mode + :return: + """ + if mode == 'train': + batch_size = cfg.TRAIN.BATCH_SIZE + if "None" in cfg.DATA.AUG: + tfs = [ToTensor()]#ZeroPad2D((padding_size, padding_size)), + # old transform + if "Gaussian" in cfg.DATA.AUG: + tfs.append(RandomizeScaleFactor(mean=0, std=0.1)) + + if data_path is None: + data_path = cfg.DATA.PATH_HDF5_TRAIN + shuffle = True + + logger.info(f"Loading {mode.capitalize()} data ... from {data_path}. Using legacy data augmentation (no torchio)") + + dataset = dset.MultiScaleDatasetVal(data_path, cfg, transforms.Compose(tfs)) + elif len(cfg.DATA.AUG) == 0: + if data_path is None: + data_path = cfg.DATA.PATH_HDF5_TRAIN + shuffle = True + + logger.info(f"Loading {mode.capitalize()} data from {data_path}. No data augmentation") + + dataset = dset.MultiScaleDataset(data_path, cfg, scale_aug=False, transforms=None) + else: + + all_augmentations = create_all_augmentations(cfg.DATA.PLANE) + + all_tfs = {all_augmentations[aug] for aug in cfg.DATA.AUG if aug != "Gaussian"} # TODO: adjust hard coded probability (is this the same as hardcoded below?) + gaussian_noise = True if "Gaussian" in cfg.DATA.AUG else False + + # remove "Gaussian" from cfg.DATA.AUG + cfg.DATA.AUG = [aug for aug in cfg.DATA.AUG if aug != "Gaussian"] + + # get probabilities for each augmentation + if cfg.DATA.AUG_LIKELYHOOD is None: + cfg.DATA.AUG_LIKELYHOOD = 1.0 + if isinstance(cfg.DATA.AUG_LIKELYHOOD, float): + cfg.DATA.AUG_LIKELYHOOD = [cfg.DATA.AUG_LIKELYHOOD] * len(cfg.DATA.AUG) + + for prob, aug in zip(cfg.DATA.AUG_LIKELYHOOD, all_tfs): + aug.p = prob # set probability for each augmentation + #all_tfs = [tio.OneOf(all_tfs, p=prob) for prob in cfg.DATA.AUG_LIKELYHOOD] + + transform = tio.Compose(all_tfs, include=["image", "label", "weight"]) + + data_path = cfg.DATA.PATH_HDF5_TRAIN + shuffle = True + + logger.info(f"Loading {mode.capitalize()} data from {data_path}. Using torchio data augmentation") + + dataset = dset.MultiScaleDataset(data_path, cfg, gaussian_noise, transforms=transform) + + + + elif mode.startswith('val'): + if data_path is None: + data_path = cfg.DATA.PATH_HDF5_VAL + + + shuffle = False + batch_size = cfg.TEST.BATCH_SIZE + + if mode.endswith('inpainting'): + logger.info(f"Loading {mode.capitalize()} data from {data_path}. Using inpainting data augmentation") + transform = CutoutBRATSTumorDeterministic(tumor_mask_hdf5='../../data/tumor_masks.hdf5', include=['image']) + dataset = dset.MultiScaleDataset(data_path, cfg, scale_aug=False, transforms=transform) + else: #mode == 'val_inpainting' + logger.info(f"Loading {mode.capitalize()} data from {data_path}") + transform = CutoutTumorMask() + #dataset = dset.InpaintingDataset(val_dataset) + dataset = dset.MultiScaleDataset(data_path, cfg, scale_aug=False, transforms=transform) + + + + else: + raise ValueError(f"Unknown dataloader mode {mode}") + + assert(len(dataset) > 0), f"Dataset {data_path} is empty" + + + if cfg.DATA_LOADER.NUM_WORKERS > 0: + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=cfg.DATA_LOADER.NUM_WORKERS, + shuffle=shuffle, + pin_memory=cfg.DATA_LOADER.PIN_MEMORY, + prefetch_factor=cfg.DATA_LOADER.PREFETCH_FACTOR, # prefetch doesn't do anything if we don't have workers + drop_last=False + ) + else: + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=cfg.DATA_LOADER.NUM_WORKERS, + shuffle=shuffle, + pin_memory=cfg.DATA_LOADER.PIN_MEMORY, # prefetch doesn't do anything if we don't have workers + drop_last=False + ) + return dataloader diff --git a/CCNet/download_checkpoints.py b/CCNet/download_checkpoints.py new file mode 100644 index 00000000..5393eec9 --- /dev/null +++ b/CCNet/download_checkpoints.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +# Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from FastSurferCNN.utils.checkpoint import check_and_download_ckpts, get_checkpoints, VINN_AXI, VINN_COR, VINN_SAG, URL + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Check and Download Network Checkpoints') + parser.add_argument("--all", default=False, action="store_true", + help="Check and download all default checkpoints") + parser.add_argument("--vinn", default=False, action="store_true", + help="Check and download VINN default checkpoints") + parser.add_argument("--url", type=str, default=URL, + help="Specify you own base URL. Default: {}".format(URL)) + parser.add_argument('files', nargs='*', + help="Checkpoint file paths to download, e.g. checkpoints/aparc_vinn_axial_v2.0.0.pkl ...") + args = parser.parse_args() + + # download all sets of weights: + if args.all: + args.vinn = True + + if not args.vinn and not args.files: + print("Specify either files to download or --vinn, see help -h.") + exit(1) + + if args.vinn or args.all: + get_checkpoints(VINN_AXI, VINN_COR, VINN_SAG, args.url) + + # later we can add more defaults here (for other sub-segmentation networks, or old CNN) + + for fname in args.files: + check_and_download_ckpts(fname, args.url) diff --git a/CCNet/generate_hdf5.py b/CCNet/generate_hdf5.py new file mode 100644 index 00000000..5e006a4e --- /dev/null +++ b/CCNet/generate_hdf5.py @@ -0,0 +1,688 @@ +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ast import Tuple +from calendar import c +import os.path +from re import T, X +# IMPORTS +import time +import glob +from os.path import join +from collections import defaultdict +import json +import csv +from typing import Optional, Union + +import numpy as np +import nibabel as nib +import h5py +from sympy import true + +from FastSurferCNN.data_loader.data_utils import (transform_axial, transform_sagittal, map_aparc_aseg2label, + create_weight_mask, get_thick_slices, filter_blank_slices_thick, + read_classes_from_lut, get_labels_from_lut, unify_lateralized_labels, + map_incrementing, subset_volume_plane) +from FastSurferCNN.utils import logging +from CCNet.utils.misc import calculate_centers_of_comissures + + +LOGGER = logging.getLogger(__name__) + + +class H5pyDataset: + + def __init__(self, params, processing="aparc"): + + self.debug = params["debug"] + self.dataset_name = params["dataset_name"] + self.data_path = params["data_path"] + self.slice_thickness = params["thickness"] + self.orig_name = params["image_name"] + self.aparc_name = params["gt_name"] + self.aparc_nocc = params["gt_nocc"] + self.aux_data = params["aux_data"] + self.processing = processing + + self.max_weight = params["max_weight"] + self.edge_weight = params["edge_weight"] + self.hires_weight = params["hires_weight"] + self.gradient = params["gradient"] + self.gm_mask = params["gm_mask"] + self.pad = params["pad"] + self.crop = params["crop"] + + self.lut = read_classes_from_lut(params["lut"]) + self.labels, self.labels_sag = get_labels_from_lut(self.lut, params["sag-mask"]) + self.lateralization = unify_lateralized_labels(self.lut, params["combi"]) + + self.centers = {} # RAS coordinates of the center of the commisures [PC, AC] + + if params["center"] is not None: + with open(params["center"], 'r') as center_csv_file: + rows = list(csv.reader(center_csv_file, delimiter=",")) + for row in rows[1:]: + pc_center = row[1:4] + ac_center = row[4:7] + self.centers[row[0]] = [pc_center, ac_center] + + + + # init files paths + self.subject_dirs = [] + self.gt_files = [] + self.gt_nocc_files = None + + + + if params["csv_file"] is not None: + # read from csv file + with open(params["csv_file"], "r") as s_dirs: + complete_list = list(csv.reader(s_dirs, delimiter=",")) + + # csv file contains subjects + self.subject_dirs = [subject[0] for subject in complete_list] + + if (len(complete_list[0]) >= 2) and (complete_list[0][1] is not None): + # csv file contains gt in second column + self.gt_files = [subject[1] for subject in complete_list] + + if (len(complete_list[0]) >= 3) and (complete_list[0][2] is not None): + # csv file contains gt_nocc in 3rd column + self.gt_nocc_files = [subject[2] for subject in complete_list] + + elif (self.data_path is not None) and (params["pattern"] is not None): + self.search_pattern = join(self.data_path, "/", params["pattern"]) + self.s = glob.glob(self.search_pattern) + else: + raise ValueError("No valid subject path list could be created") + + # Set subject file names + if os.path.isfile(self.subject_dirs[0]): + self.subject_files = self.subject_dirs + self.subject_dirs = [os.path.dirname(subject) for subject in self.subject_dirs] + elif self.orig_name is not None: + self.subject_files = [join(subject, self.orig_name) for subject in self.subject_dirs] + else: + raise ValueError("No valid subject file list could be created") + + # Set gt file name + if self.gt_files is None: + self.gt_files = self.subject_dirs + + if os.path.isdir(self.gt_files[0]): + if self.aparc_name is not None: + self.gt_files = [join(subject, self.aparc_name) for subject in self.gt_files] + else: + raise ValueError("No valid gt file list could be created") + + #if self.gt_nocc_files is not []: + # if (os.path.isdir(self.gt_nocc_files[0]) and (self.aparc_nocc is not None)): + # self.gt_nocc_files = [join(subject, self.aparc_nocc) for subject in self.subject_dirs] + + self.data_set_size = len(self.subject_dirs) + + for subject in self.subject_files: + assert os.path.isfile(subject), f"{subject} is not a file" + for subject in self.gt_files: + assert os.path.isfile(subject), f"{subject} is not a file" + + + def _load_volumes(self, gt_file, orig_file, gt_nocc_file=None, subject_path=None): + # Load the orig and extract voxel spacing information (x, y, and z dim) + LOGGER.info('Processing intensity image {} and ground truth segmentation {}'.format(self.orig_name, gt_file)) + + # Load the orignal image and zoom + orig = nib.load(orig_file) + zoom = orig.header.get_zooms() + orig = np.asarray(orig.get_fdata(), dtype=np.uint8) + + # Load the segmentation ground truth + aseg = np.asarray(nib.load(gt_file).get_fdata(), dtype=np.uint16) + + if gt_nocc_file is not None: + aseg_nocc = nib.load(gt_nocc_file) + aseg_nocc = np.asarray(aseg_nocc.get_fdata(), dtype=np.uint16) + assert (aseg_nocc.shape == aseg.shape), "Aseg and aseg_nocc must have the same shape" + else: + aseg_nocc = None + + # LUT for aux data: + # 0: unknown + # 1: brain + # 2: midplane + # 3: anomaly + + if self.aux_data is not None: + if 'brainmask' in self.aux_data.keys(): + aux_data_brainmask = np.asarray(nib.load(join(subject_path, self.aux_data['brainmask'])).get_fdata()) + aux_data_brainmask = (aux_data_brainmask > 0).astype(np.uint8) + + # assert(np.unique(aux_data).shape[0] == 2), "Auxiliary data must be binary" + assert ( + aux_data_brainmask.shape == orig.shape), "Auxiliary data must have the same shape as the original image" + else: + aux_data_brainmask = np.zeros_like(orig, dtype=np.uint8) + + if aseg_nocc is not None: + try: + aseg_cc = nib.load(join(subject_path, 'mri/aparc+aseg.mgz')).get_fdata() + except: + aseg_cc = nib.load(gt_file).get_fdata() + cc_center_mask = aseg_cc == 253 # get midplane from CC mask + assert (cc_center_mask.any()), "Fornix mask is empty" + cc_coordinates = np.where(cc_center_mask) + midplane = np.zeros_like(orig, dtype=bool) + LR_coordinate = int(np.median(cc_coordinates[0])) + # print("midplane at ", LR_coordinate) + midplane[LR_coordinate, :, :] = True + + if 'brainmask' in self.aux_data.keys(): + midplane = midplane & (aux_data_brainmask == 1) + aux_data_brainmask[midplane] = 2 + aux_data_brainmask = aux_data_brainmask + else: + aux_data_brainmask[midplane] = 2 + + if 'anomaly' in self.aux_data.keys(): + aux_data_anomaly = np.asarray(nib.load(join(subject_path, self.aux_data['anomaly'])).get_fdata(), + dtype=bool) + aux_data_brainmask[aux_data_anomaly] = 3 + + # import matplotlib.pyplot as plt + # plt.figure(figsize=(20, 8)) + # plt.subplot(1, 3, 1) + # plt.imshow(orig[LR_coordinate, :, :], cmap='gray') + # plt.subplot(1, 3, 2) + # plt.title("aux_data") + # plt.imshow(aux_data[LR_coordinate, :, :], cmap='gray') + # plt.subplot(1, 3, 3) + # plt.title("aux_data") + # plt.imshow(aux_data[:, 128, :], cmap='Pastel1') + # plt.savefig("aux_data.png", dpi=300) + + # nib.save(nib.Nifti1Image(aux_data, np.eye(4)), "aux_data.nii.gz") + else: + aux_data_brainmask = None + return orig, aseg, aseg_nocc, aux_data_brainmask, zoom + + def transform(self, plane, imgs, zoom): + """ + Change axis of all volumes so that the slice thickness dimension is the first dimension + and select the zooms in 2d accordingly + """ + if plane == "sagittal": + for i in range(len(imgs)): + imgs[i] = transform_sagittal(imgs[i]) + zoom = [zoom[i] for i in [1, 0]] # zoom[::-1][:2] 21 + elif plane == "axial": + for i in range(len(imgs)): + imgs[i] = transform_axial(imgs[i]) + zoom = [zoom[i] for i in [2, 0]] # zoom[1:] 12 + elif plane == 'coronal': # no image axis changes for coronal plane + zoom = [zoom[i] for i in [1, 2]] # zooms = zoom[:2] 01 + else: + raise ValueError("Plane {} not supported".format(plane)) + return imgs, zoom + + def _pad_image(self, img, max_out): + # Get correct size = max along shape + h, w, d = img.shape + LOGGER.info("Padding image from {0} to {1}x{1}x{1}".format(img.shape, max_out)) + padded_img = np.zeros((max_out, max_out, max_out), dtype=img.dtype) + padded_img[0: h, 0: w, 0:d] = img + return padded_img + + + + def calculate_center_volume(self, orig_file, centers : np.ndarray, l : int = 5, sig : float = 3., offset : int = 0): + """Calculates a gaussian kernel around the given centers and adds them to a volume + + Arguments: + orig_file (str): path to original image + centers (np.ndarray): array of RAS coordinates of the centers (In the same space as orig file) + l (int): size of the kernel (default: {5}) + sig (float): sigma of the kernel (default: {1.}) + + Returns: + np.ndarray: volume with gaussian kernels around the centers + + """ + assert l % 2 == 1, "l must be odd" + orig = nib.load(orig_file) + + # create empty volume + center_volume = np.zeros_like(orig.get_fdata(), dtype=np.float32) + + centers_vox = np.array(centers, dtype=np.float32) + + # add gkern around centers + for center in centers_vox: + center = np.rint(center).astype(int) + kernel = self.gkern(l=l, sig=sig) + + # Calculate the start and end indices for adding the kernel to the center_volume + start_idx = np.maximum(center - (l - 1) // 2, 0).astype(int) + end_idx = np.minimum(center + (l - 1) // 2 + 1, center_volume.shape).astype(int) + + # Calculate the start and end indices for the kernel + kernel_start_idx = np.maximum((l - 1) // 2 - center, 0).astype(int) + kernel_end_idx = np.minimum((l - 1) // 2 + np.array(center_volume.shape) - center, l).astype(int) + + # check if centers are overlapping + assert np.all(center_volume[start_idx[0]:end_idx[0], + start_idx[1]:end_idx[1], + start_idx[2]:end_idx[2]] == 0), "Centers are overlapping" + + # Add the weighted kernel to the center_volume + center_volume[start_idx[0]:end_idx[0], + start_idx[1]:end_idx[1], + start_idx[2]:end_idx[2]] += kernel[ + kernel_start_idx[0]:kernel_end_idx[0], + kernel_start_idx[1]:kernel_end_idx[1], + kernel_start_idx[2]:kernel_end_idx[2]] + + tmp = calculate_centers_of_comissures(center_volume) + return center_volume + + + def gkern(self, l=5, sig=1.): + """\ + creates gaussian kernel with side length `l` and a sigma of `sig` + """ + ax = np.linspace(-(l - 1) / 2., (l - 1) / 2., l) + + gauss = np.exp(-0.5 * np.square(ax) / np.square(sig)) + kernel = np.outer(gauss, gauss) + + kernel = kernel.T[:,:,None]*kernel.T[:,None] + + return kernel / np.max(kernel) + + + def create_hdf5_dataset(self, plane='axial'): + data_per_size = defaultdict(lambda: defaultdict(list)) + start_d = time.time() + failed = 0 + + for idx in range(len(self.subject_dirs)): + if true: + current_path = self.subject_dirs[idx] + current_orig = self.subject_files[idx] + current_gt = self.gt_files[idx] + + if current_path.endswith('/'): + sub_name = current_path.split("/")[-2] + else: + sub_name = current_path.split("/")[-1] + + if self.gt_nocc_files != None: + current_gt_nocc = self.gt_nocc_files[idx] + else: + current_gt_nocc = None + + LOGGER.info("Volume Nr: {} Processing MRI Data from {}".format(idx + 1, current_path)) + + orig, aseg, aseg_nocc, aux_data, zoom = self._load_volumes(gt_file=current_gt, + orig_file=current_orig, + gt_nocc_file=current_gt_nocc, + subject_path=current_path) + + # Create ACPC volumes + if self.centers != {}: + center_volume = self.calculate_center_volume(current_orig, self.centers[sub_name]) + #center_volume = self.calculate_center_volume2(current_orig, self.centers[sub_name]) + + + if self.crop: + axis = 0 #sagittal + mid_plane = 128 #mri_cc default + volume_thickness = np.ceil((5/zoom[axis])) + if volume_thickness % 2 == 0: + volume_thickness += 1 + + if plane=="sagittal" and not self.pad: + orig = subset_volume_plane(orig, plane=mid_plane, thickness=volume_thickness+args.thickness*2, axis=axis) + else: + orig = subset_volume_plane(orig, plane=mid_plane, thickness=volume_thickness, axis=axis) + + + aseg = subset_volume_plane(aseg, plane=mid_plane, thickness=volume_thickness, axis=axis) + if aux_data is not None: + aux_data = subset_volume_plane(aux_data, plane=mid_plane, thickness=volume_thickness, axis=axis) + if aseg_nocc is not None: + aseg_nocc = subset_volume_plane(aseg_nocc, plane=mid_plane, thickness=volume_thickness, axis=axis) + if self.centers != {}: + center_volume = subset_volume_plane(center_volume, plane=mid_plane, thickness=volume_thickness, axis=axis) + + + + mapped_aseg, mapped_aseg_sag = map_aparc_aseg2label(aseg, self.labels, self.labels_sag, + self.lateralization, aseg_nocc, + processing=self.processing) + + if plane == 'sagittal': + mapped_aseg = mapped_aseg_sag + + mapped_aseg = map_incrementing(mapped_aseg.copy(), lut=self.lut) + weights = create_weight_mask(mapped_aseg.copy(), max_weight=self.max_weight, ctx_thresh=19 if plane == 'sagittal' else 33, + max_edge_weight=self.edge_weight, max_hires_weight=self.hires_weight, + cortex_mask=self.gm_mask, gradient=self.gradient) + + + LOGGER.info("Created weights with max_w {}, gradient {}," + " edge_w {}, hires_w {}, gm_mask {}".format(self.max_weight, self.gradient, self.edge_weight, + self.hires_weight, self.gm_mask)) + + # transform volumes to correct shape + if aux_data is not None: + [orig, mapped_aseg, weights, aux_data], zoom = self.transform(plane, [orig, mapped_aseg, weights, + aux_data], zoom) + elif (self.centers != {}): + [orig, mapped_aseg, weights, center_volume], zoom = self.transform(plane, [orig, mapped_aseg, weights, center_volume], zoom) + else: + [orig, mapped_aseg, weights], zoom = self.transform(plane, [orig, mapped_aseg, weights], zoom) + + assert (len(zoom) == 2), "Zoom should be 2D" + + + # Create Thick Slices, filter out blanks + orig_thick = get_thick_slices(orig, self.slice_thickness, pad=self.pad) + + if aux_data is not None: + aux_data = get_thick_slices(aux_data, self.slice_thickness, pad = self.pad) if aux_data is not None else None + orig, mapped_aseg, weights, aux_data = filter_blank_slices_thick( + [orig_thick, mapped_aseg, weights, aux_data], label_vol=mapped_aseg) + elif (self.centers != {}): + orig, mapped_aseg, weights, center_volume = filter_blank_slices_thick( + [orig_thick, mapped_aseg, weights, center_volume], label_vol=mapped_aseg) + else: + orig, mapped_aseg, weights = filter_blank_slices_thick([orig_thick, mapped_aseg, weights], + label_vol=mapped_aseg) + + num_batch = orig.shape[2] + orig = np.transpose(orig, (2, 0, 1, + 3)) # shape: (plane1, plane2, no_thick_slices, slice_thickness) -> (no_thick_slices, plane1, plane2, slice_thickness) + mapped_aseg = np.transpose(mapped_aseg, (2, 0, 1)) # put no_thick_slices as first dimension + weights = np.transpose(weights, (2, 0, 1)) # put no_thick_slices as first dimension + if aux_data is not None: + aux_data = np.transpose(aux_data, (2, 0, 1, 3)) # put no_thick_slices as first dimension + + if self.centers != {}: + center_volume = np.transpose(center_volume, (2, 0, 1)) # put no_thick_slices as first dimension + + assert (orig.shape[0] == mapped_aseg.shape[0] == weights.shape[0]), "Number of slices does not match" + assert (orig.shape[1] == mapped_aseg.shape[1] == weights.shape[1]), "Number of rows does not match" + assert (orig.shape[2] == mapped_aseg.shape[2] == weights.shape[2]), "Number of columns does not match" + + if orig.shape[1] == orig.shape[2]: + size = orig.shape[1] + else: + LOGGER.info("Image is not isotropic; using both dimenstions as key") + size = f"{orig.shape[1]}-{orig.shape[2]}" + + assert not (np.sum(mapped_aseg, axis=(1, 2)) == 0).any(), "Empty labels in mapped_aseg" + + data_per_size[f'{size}']['orig'].extend(orig) # add slices to list + data_per_size[f'{size}']['aseg'].extend(mapped_aseg) + + data_per_size[f'{size}']['weight'].extend(weights) + if aux_data is not None: + data_per_size[f'{size}']['aux_data'].extend(aux_data) + data_per_size[f'{size}']['zoom'].extend((zoom,) * num_batch) + + data_per_size[f'{size}']['subject'].extend([sub_name.encode("ascii", "ignore")] * len(orig)) # add subject name to each slice + + if self.centers != {}: + # assert not (np.sum(center_volume, axis=(1, 2)) == 0).all(), "Empty center_volume" + data_per_size[f'{size}']['center_volume'].extend(center_volume) + + if self.debug and idx == 20: + break + + #except AssertionError as e: + # LOGGER.warning("Volume: {} Failed Reading Data. Error: {}".format(idx, e)) + # failed += 1 + # continue + + for key, data_dict in data_per_size.items(): + data_per_size[key]['orig'] = np.asarray(data_dict['orig'], dtype=np.uint8) + data_per_size[key]['aseg'] = np.asarray(data_dict['aseg'], dtype=np.uint8) + data_per_size[key]['weight'] = np.asarray(data_dict['weight'], dtype=float) + + if 'center_volume' in data_dict.keys(): + data_per_size[key]['center_volume'] = np.asarray(data_dict['center_volume'], dtype=float) + + assert (data_dict['orig'].shape[0] == data_dict['aseg'].shape[0] == data_dict['weight'].shape[0] == len( + data_dict['zoom']) + ), 'Data does have not the same length, but orig:{} aseg:{} weight:{} zoom:{}'.format( + data_dict['orig'].shape[0], data_dict['aseg'].shape[0], data_dict['weight'].shape[0], + data_dict['zoom'].shape[0]) + assert (len(data_dict['subject']) == data_dict['orig'].shape[ + 0]), f'Subject and data do not have the same length, but subject:{data_dict["subject"].shape[0]} data:{data_dict["orig"].shape[0]}' + + with h5py.File(self.dataset_name, "w") as hf: + dt = h5py.special_dtype(vlen=str) + for key, data_dict in data_per_size.items(): + group = hf.create_group(f"{key}") + group.create_dataset("orig_dataset", data=data_dict['orig'], chunks=True) + group.create_dataset("aseg_dataset", data=data_dict['aseg']) + group.create_dataset("weight_dataset", data=data_dict['weight']) + if 'aux_data' in data_dict.keys(): + group.create_dataset("aux_data", data=data_dict['aux_data']) + if 'center_volume' in data_dict.keys(): + group.create_dataset("center_dataset", data=data_dict['center_volume']) + group.create_dataset("zoom_dataset", data=data_dict['zoom']) + group.create_dataset("subject", data=data_dict['subject'], dtype=dt) + + end_d = time.time() - start_d + LOGGER.info("Successfully written {} in {:.3f} seconds.".format(self.dataset_name, end_d)) + if failed > 0: + LOGGER.warning("Failed to read {} volumes.".format(failed)) + else: + LOGGER.info("No failures while reading volumes.") + + +class H5pyDatasetMasks(H5pyDataset): + + def __init__(self, params): + self.debug = params["debug"] + self.dataset_name = params["dataset_name"] + self.data_path = params["data_path"] + self.slice_thickness = params["thickness"] + self.mask_name = params["image_name"] + + # self.lateralization = unify_lateralized_labels(self.lut, params["combi"]) + + if params["csv_file"] is not None: + with open(params["csv_file"], "r") as s_dirs: + self.subject_dirs = [line.strip() for line in s_dirs.readlines()] + + else: + self.search_pattern = join(self.data_path, params["pattern"]) + self.subject_dirs = glob.glob(self.search_pattern) + + self.data_set_size = len(self.subject_dirs) + + self.aux_data = None + self.aseg_nocc = None + + def _load_volumes(self, subject_path): + # return super()._load_volumes(subject_path) + # Load the orig and extract voxel spacing information (x, y, and z dim) + LOGGER.info(f'Processing mask image {self.mask_name}') + mask = nib.load(join(subject_path, self.mask_name) if self.mask_name != "" else subject_path) + + zoom = mask.header.get_zooms() + mask = np.asarray(mask.get_fdata(), dtype=bool) + + return mask, zoom + + def create_hdf5_dataset(self, plane='all'): + data_per_size = defaultdict(lambda: defaultdict(list)) + start_d = time.time() + failed = 0 + + for idx, current_subject in enumerate(self.subject_dirs): + + try: + LOGGER.info( + "Volume Nr: {} Processing MRI Data from {}/{}".format(idx + 1, current_subject, self.mask_name)) + + mask, zoom = self._load_volumes(current_subject) + + if not (mask.shape[0] == mask.shape[1] == mask.shape[2]): + # LOGGER.warning(f"Image is not isotropic, but has size {mask.shape}") + # np.pad(mask, ((0, 0), (0, 0), (0, 0)), mode='constant', constant_values=0) + mask = self._pad_image(mask, max_out=max(mask.shape)) + + assert (zoom[0] == zoom[1] == zoom[2]), f"Zoom should be isotropic, but is {zoom}" + + size = mask.shape[0] + + # assert(mask.shape[1] == mask.shape[2] == size), f"Image is not isotropic, but has size {mask.shape}" + + # transform volumes to correct shape + if plane == 'all': + [mask_axial], _ = self.transform('axial', [mask], zoom) + [mask_coronal], _ = self.transform('coronal', [mask], zoom) + [mask_sagittal], zoom = self.transform('sagittal', [mask], zoom) + mask = np.concatenate([mask_axial, mask_coronal, mask_sagittal], axis=2) + else: + [mask], zoom = self.transform(plane, [mask], zoom) + + assert (len(zoom) == 2), "Zoom should be 2D" + + # Create Thick Slices, filter out blanks + mask_thick = get_thick_slices(mask, self.slice_thickness, pad=self.pad) + + # filter blank slices + select_slices = (np.sum(mask_thick, + axis=(0, 1, 3)) > 10) # select thick-slices with more than 10 voxels of mask + mask_thick = mask_thick[:, :, select_slices, :] + + num_batch = mask_thick.shape[2] + mask = np.transpose(mask_thick, (2, 0, 1, + 3)) # shape: (plane1, plane2, no_thick_slices, slice_thickness) -> (no_thick_slices, plane1, plane2, slice_thickness) + + data_per_size[f'{size}']['mask'].extend(mask) # add slices to list + data_per_size[f'{size}']['zoom'].extend((zoom,) * num_batch) + sub_name = current_subject.split("/")[-1] + data_per_size[f'{size}']['subject'].extend( + [sub_name.encode("ascii", "ignore")] * len(mask)) # add subject name to each slice + + if self.debug and idx == 20: + break + + except Exception as e: + LOGGER.info("Volume: {} Failed Reading Data. Error: {}".format(idx, e)) + failed += 1 + continue + + for key, data_dict in data_per_size.items(): + data_per_size[key]['mask'] = np.asarray(data_dict['mask'], dtype=bool) + + assert (len(data_dict['subject']) == data_dict['mask'].shape[0] == len(data_dict['zoom'])), \ + f'Subject and data do not have the same length, but subject:{data_dict["subject"].shape[0]} data:{data_dict["mask"].shape[0]}, zoom:{len(data_dict["zoom"])}' + + with h5py.File(self.dataset_name, "w") as hf: + dt = h5py.special_dtype(vlen=str) + for key, data_dict in data_per_size.items(): + group = hf.create_group(f"{key}") + group.create_dataset("mask_dataset", data=data_dict['mask']) + group.create_dataset("zoom_dataset", data=data_dict['zoom']) + group.create_dataset("subject", data=data_dict['subject'], dtype=dt) + + end_d = time.time() - start_d + LOGGER.info("Successfully written {} in {:.3f} seconds.".format(self.dataset_name, end_d)) + if failed > 0: + LOGGER.warning("Failed to read {} volumes.".format(failed)) + else: + LOGGER.info("No failures while reading volumes.") + + +if __name__ == '__main__': + import argparse + + # Training settings + parser = argparse.ArgumentParser(description='HDF5-Creation') + + parser.add_argument('--hdf5_name', type=str, default="../data/hdf5_set/Multires_coronal.hdf5", + help='path and name of hdf5-data_loader (default: ../data/hdf5_set/Multires_coronal.hdf5)') + parser.add_argument('--plane', type=str, default="axial", choices=["axial", "coronal", "sagittal", "all"], + help="Which plane to put into file (axial (default), coronal or sagittal)") + parser.add_argument('--data_dir', type=str, default="/data", help="Directory with images to load") + parser.add_argument('--thickness', type=int, default=3, help="Number of pre- and succeeding slices (default: 3)") + parser.add_argument('--csv_file', type=str, default=None, help="Csv-file listing subjects to include in file") + parser.add_argument('--pattern', type=str, help="Pattern to match files in directory.") + parser.add_argument('--image_name', type=str, + help="Default name of original images. FreeSurfer orig.mgz is default (mri/orig.mgz)") + parser.add_argument('--gt_name', type=str, default=None, + help="Default name for ground truth segmentations. Default: mri/aparc.DKTatlas+aseg.mgz." + " If Corpus Callosum segmentation is already removed, do not set gt_nocc." + " (e.g. for our internal training set mri/aparc.DKTatlas+aseg.filled.mgz exists already" + " and should be used here instead of mri/aparc.DKTatlas+aseg.mgz). ") + parser.add_argument('--gt_nocc', type=str, default=None, + help="Segmentation without corpus callosum (used to mask this segmentation in ground truth)." + " If the used segmentation was already processed, do not set this argument." + " For a normal FreeSurfer input, use mri/aseg.auto_noCCseg.mgz.") + parser.add_argument('--aux_data', type=str, default=None + , help="Auxiliary data to load (e.g. mri/brainmask.mgz).") + parser.add_argument('--lut', type=str, default='./config/FastSurfer_ColorLUT.tsv', + help="FreeSurfer-style Color Lookup Table with labels to use in final prediction. " + "Has to have columns: ID LabelName R G B A" + "Default: ./config/FastSurfer_ColorLUT.tsv.") + parser.add_argument('--combi', action='append', default=["Left-", "Right-"], + help="Suffixes of labels names to combine. Default: Left- and Right-.") + parser.add_argument('--sag_mask', default=("Left-", "ctx-rh"), + help="Suffixes of labels names to mask for final sagittal labels. Default: Left- and ctx-rh.") + parser.add_argument('--max_w', type=int, default=5, + help="Overall max weight for any voxel in weight mask. Default=5") + parser.add_argument('--edge_w', type=int, default=5, help="Weight for edges in weight mask. Default=5") + parser.add_argument('--hires_w', type=int, default=None, + help="Weight for hires elements (sulci, WM strands, cortex border) in weight mask. Default=None") + parser.add_argument('--no_grad', action='store_true', default=False, + help="Turn on to only use median weight frequency (no gradient)") + parser.add_argument('--gm', action="store_true", default=False, + help="Turn on to add cortex mask for hires-processing.") + parser.add_argument('--processing', type=str, default="aparc", choices=["aparc", "aseg", "none"], + help="Use aseg, aparc or no specific mapping processing") + parser.add_argument('--mask_dataset', action="store_true", default=False, + help="Turn on to create a dataset with masks instead of images - useful to create masks for data augmentation") + parser.add_argument('--debug', action="store_true", default=False, + help="Only process 20 subjects for debugging") + parser.add_argument('--center', type=str, default=None, help="CSV file with RAS coordinates of the center of the commisures [PC, AC]") + parser.add_argument('--pad', action="store_true", default=False, help="Pad images at the edge of the volume for thick slices") + parser.add_argument('--crop', action="store_true", default=False, help="Crop images to the center of the volume for thick slices") + + args = parser.parse_args() + + if args.aux_data is not None: + args.aux_data = json.loads(args.aux_data) + + dataset_params = {"dataset_name": args.hdf5_name, "data_path": args.data_dir, "thickness": args.thickness, + "csv_file": args.csv_file, "pattern": args.pattern, "image_name": args.image_name, + "gt_name": args.gt_name, "gt_nocc": args.gt_nocc, + "max_weight": args.max_w, "edge_weight": args.edge_w, + "lut": args.lut, "combi": args.combi, "sag-mask": args.sag_mask, + "hires_weight": args.hires_w, "gm_mask": args.gm, "gradient": not args.no_grad, + "aux_data": args.aux_data, "debug": args.debug, "center": args.center, "pad": args.pad, "crop": args.crop} + + logging.setup_logging() + + if not args.mask_dataset: + dataset_generator = H5pyDataset(params=dataset_params, processing=args.processing) + dataset_generator.create_hdf5_dataset(plane=args.plane) + else: + dataset_generator = H5pyDatasetMasks(params=dataset_params) + dataset_generator.create_hdf5_dataset() diff --git a/CCNet/inference.py b/CCNet/inference.py new file mode 100644 index 00000000..b5da3848 --- /dev/null +++ b/CCNet/inference.py @@ -0,0 +1,363 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import time +from typing import Optional, Dict, Tuple +import nibabel as nib + +import numpy as np +import torch +from tqdm.contrib.logging import logging_redirect_tqdm # why is this not at the top? +from tqdm import tqdm + +from torch.utils.data import DataLoader +from torchvision import transforms + +from FastSurferCNN.utils import logging +from CCNet.models.networks import build_model +from CCNet.data_loader.augmentation import ToTensorTest, CutoutTumorMaskInference +from FastSurferCNN.data_loader.data_utils import map_prediction_sagittal2full +from CCNet.data_loader.dataset import MultiScaleOrigDataThickSlices +from FastSurferCNN.config.global_var import get_class_names +from CCNet.utils.misc import calculate_centers_of_comissures + + +logger = logging.getLogger(__name__) + + +class Inference: + + permute_order: Dict[str, Tuple[int, int, int, int]] + device: Optional[torch.device] + default_device: torch.device + + def __init__(self, cfg, device: torch.device, ckpt: str = "", inference_weights: Optional[Dict[str, float]] = None): + # Set random seed from configs. + np.random.seed(cfg.RNG_SEED) + torch.manual_seed(cfg.RNG_SEED) + self.cfg = cfg + + self.doing_inpainting = self.cfg.MODEL.MODEL_NAME == 'FastSurferPaint' + self.doing_localisation = self.cfg.MODEL.MODEL_NAME == 'FastSurferLocalisation' + + # Switch on denormal flushing for faster CPU processing + # seems to have less of an effect on VINN than old CNN + torch.set_flush_denormal(True) + + self.default_device = device + + # Options for parallel run + self.model_parallel = torch.cuda.device_count() > 1 and \ + self.default_device.type == "cuda" and \ + self.default_device.index is None + + # Initial model setup + self.model = None + self._model_not_init = None + self.setup_model(cfg, device=self.default_device) + self.model_name = self.cfg.MODEL.MODEL_NAME + + self.alpha = {"sagittal": 0.6, "axial": 0.2, "coronal": 0.2} if inference_weights is None else inference_weights + self.permute_order = {"axial": (3, 0, 2, 1), "coronal": (2, 3, 0, 1), "sagittal": (0, 3, 2, 1)} + self.post_prediction_mapping_hook = {"sagittal": map_prediction_sagittal2full, "axial": map_prediction_sagittal2full, "coronal": map_prediction_sagittal2full} + self.alpha_loc = {"sagittal": 1.0, "axial": 0.0, "coronal": 0.0} + + # Initial checkpoint loading + if ckpt: + # this also moves the model to the para + self.load_checkpoint(ckpt) + + self._debug = False + + def setup_model(self, cfg=None, device: torch.device = None): + if cfg is not None: + self.cfg = cfg + if device is None: + device = self.default_device + + # Set up model + self._model_not_init = build_model(self.cfg) # ~ model = CCNet(params_network) + self._model_not_init.to(device) + self.device = None + + def set_cfg(self, cfg): + self.cfg = cfg + + def to(self, device: Optional[torch.device] = None): + if self.model_parallel: + raise RuntimeError("Moving the model to other devices is not supported for multi-device models.") + _device = self.default_device if device is None else device + self.device = _device + self.model.to(device=_device) + + def load_checkpoint(self, ckpt): + logger.info("Loading checkpoint {}".format(ckpt)) + + self.model = self._model_not_init + # If device is None, the model has never been loaded (still in random initial configuration) + if self.device is None: + self.device = self.default_device + + # workaround for mps (directly loading to map_location=mps results in zeros) + device = self.device + if self.device.type == 'mps': + self.model.to('cpu') + device = 'cpu' + else: + # make sure the model is, where it is supposed to be + self.model.to(self.device) + + model_state = torch.load(ckpt, map_location=device) + self.model.load_state_dict(model_state['model_state']) + + # workaround for mps (move the model back to mps) + if self.device.type == 'mps': + self.model.to(self.device) + + if self.model_parallel: + self.model = torch.nn.DataParallel(self.model) + + def get_modelname(self): + return self.model_name + + def get_cfg(self): + return self.cfg + + def get_num_classes(self): + return self.cfg.MODEL.NUM_CLASSES + + def get_plane(self): + return self.cfg.DATA.PLANE + + def get_model_height(self): + return self.cfg.MODEL.HEIGHT + + def get_model_width(self): + return self.cfg.MODEL.WIDTH + + def get_max_size(self): + if self.cfg.MODEL.OUT_TENSOR_WIDTH == self.cfg.MODEL.OUT_TENSOR_HEIGHT: + return self.cfg.MODEL.OUT_TENSOR_WIDTH + else: + return self.cfg.MODEL.OUT_TENSOR_WIDTH, self.cfg.MODEL.OUT_TENSOR_HEIGHT + + def get_device(self): + return self.device + + @torch.no_grad() + def eval(self, init_pred: torch.Tensor, + val_loader: DataLoader, + *, out_scale=None, + out: Optional[torch.Tensor] = None, + localisation : Optional[torch.Tensor] = None, + out_localisation : Optional[torch.Tensor] = None, + init_pred_localisation : Optional[torch.Tensor] = None, + sdir : Optional[str] = None): + """Perform prediction and inplace-aggregate views into pred_prob. Return pred_prob. + + Args: + init_pred: Initial prediction (e.g. from the segmentation network) + val_loader: DataLoader for the test data + out_scale: Output scale factor + out: Output tensor to add the predictions to + + Returns: + out: Output tensor with the predictions added + + """ + self.model.eval() + # we should check here, whether the DataLoader is a Random or a SequentialSampler, but we cannot easily. + if not isinstance(val_loader.sampler, torch.utils.data.SequentialSampler): + logger.warning("The Validation loader seems to not use the SequentialSampler. This might interfere with " + "the assumed sorting of batches.") + + start_index = 0 + plane = self.cfg.DATA.PLANE + index_of_current_plane = self.permute_order[plane].index(0) + target_shape = init_pred.shape + ii = [slice(None) for _ in range(4)] + pred_ii = tuple(slice(i) for i in target_shape[:3]) + + if out is None: + out = init_pred.detach().clone() + + if self.doing_inpainting and localisation is None: + localisation = torch.zeros([*target_shape[:3], 1], dtype=float, device=out.device) + + if self.doing_localisation and out_localisation is None: + if init_pred_localisation is not None: + out_localisation = init_pred_localisation.detach().clone() + else: + out_localisation = torch.zeros([*target_shape[:3], 1], dtype=float, device=out.device) + + + assert not (self.doing_localisation and out_localisation is None), "Localisation is only possible if the model is a localisation model" + + batch_idx = 0 # prevent error in throws + with logging_redirect_tqdm(): + try: + for batch_idx, batch in tqdm(enumerate(val_loader), total=len(val_loader), unit="batch"): + + # move data to the model device + images, scale_factors = batch['image'].to(self.device), batch['scale_factor'].to(self.device) + + # predict the current batch, outputs logits + + + pred = self.model(images, scale_factors, out_scale) + if self.doing_inpainting: + output_slice = pred[1] + pred = pred[0] + + elif self.doing_localisation: + loc_pred = pred[:, -1:, :, :] + pred = pred[:, :-1, :, :] + + + # if not np.sum(output_slice.shape) == 256*2+2: + # import pdb; pdb.set_trace() + + batch_size = pred.shape[0] + end_index = start_index + batch_size + + # check if we need a special mapping (e.g. as for sagittal) + if self.post_prediction_mapping_hook.get(plane) is not None: + pred = self.post_prediction_mapping_hook.get(plane)(pred, num_classes=self.get_num_classes()) + + # permute the prediction into the out slice order + pred = pred.permute(*self.permute_order[plane]).to(out.device) # the to-operation is implicit + + # cut prediction to the image size + pred = pred[pred_ii] + + # add prediction logits into the output (same as multiplying probabilities) + ii[index_of_current_plane] = slice(start_index, end_index) + + + out[tuple(ii)].add_(pred, alpha=self.alpha.get(plane)) + + if self.doing_inpainting: + + # if plane == 'sagittal': + # import pdb; pdb.set_trace() + #output_slice = self.post_predition_mapping_hook.get(plane, lambda x: x)(output_slice) + output_slice = output_slice.permute(*self.permute_order[plane]).to(out.device) + output_slice = output_slice[pred_ii] + try: + localisation[tuple(ii)[:3]].add_(output_slice) + except: + print(output_slice.shape) + print(localisation[tuple(ii)[:3]].shape) + import pdb; pdb.set_trace() + + # if doing localisation, add the localisation prediction + if self.doing_localisation: + # permute the prediction into the out slice order + loc_pred = loc_pred.permute(*self.permute_order[plane]).to(out.device) # the to-operation is implicit + + # cut prediction to the image size + loc_pred = loc_pred[pred_ii] + + out_localisation[tuple(ii)].add_(loc_pred, alpha=self.alpha_loc.get(plane)) + + + if self._debug: + # debug save + assert batch_size == 1, "Debug save only works for batch size 1" + + orig_image = nib.load(f'{sdir}/mri/orig.mgz') + max_value = np.max(orig_image.get_fdata()) + + props = (loc_pred - loc_pred.min()) / (loc_pred.max() - loc_pred.min()) # Normalize values between 0 and 1 + props = props * max_value # Scale values to [0, 255] + props= torch.round(props) + + localisation = torch.zeros_like(out_localisation) + localisation[tuple(ii)].add_(props, alpha=self.alpha_loc.get(plane)) + + #print(torch.unique(props)) + localisation_img = nib.MGHImage(localisation.cpu()[:,:,:,0], orig_image.affine, orig_image.header) + nib.save(localisation_img, f'{sdir}/mri/localisation_{plane}_{batch_idx}.mgz') + + + # permute the image into the out slice order + images = images.permute(1,3,2,0).to(out.device) # the to-operation is implicit + + # cut image to the image size + #images = images[pred_ii].cpu() + images = images.cpu() * 255 + images = torch.round(images) + + print(torch.unique(images)) + input_img = nib.MGHImage(images[:,:,:,0], orig_image.affine, orig_image.header) + nib.save(input_img, f'{sdir}/mri/img_{plane}_{batch_idx}.mgz') + + + #AC_PC = calculate_centers_of_comissures(loc_pred) + #with open(f'{sdir}/mri/AC_PC_{plane}.txt', 'a') as f: + # f.write(f'{batch_idx},{AC_PC[0]},{AC_PC[1]}\n') + + + + start_index = end_index + + except: + logger.exception("Exception in batch {} of {} inference.".format(batch_idx, plane)) + raise + else: + logger.info("Inference on {} batches for {} successful".format(batch_idx+1, plane)) + + if self.doing_localisation: + return out, out_localisation + + if self.doing_inpainting: + return out, localisation.squeeze() + else: + return out + + @torch.no_grad() + def run(self, + init_pred: torch.Tensor, + img_filename, + orig_data, + orig_zoom, + out: Optional[torch.Tensor] = None, + out_res = None, + lesion_mask = None, + batch_size: Optional[int] = None, + out_localisation : Optional[torch.Tensor] = None, + init_pred_localisation : Optional[torch.Tensor] = None, + sdir : Optional[str] = None, + pad : bool = False): + """Run the loaded model on the data (T1) from orig_data and filename img_filename with scale factors orig_zoom.""" + # Set up DataLoader + test_dataset = MultiScaleOrigDataThickSlices(img_filename, orig_data, orig_zoom, self.cfg, lesion_mask=lesion_mask, + transforms=transforms.Compose([ToTensorTest(include=['image','cutout_mask']), CutoutTumorMaskInference()]), + pad = pad) + + test_data_loader = DataLoader(dataset=test_dataset, shuffle=False, + batch_size=self.cfg.TEST.BATCH_SIZE if batch_size is None else batch_size) + + # Run evaluation + start = time.time() + + out = self.eval(init_pred, test_data_loader, out=out, out_scale=out_res, out_localisation=out_localisation, init_pred_localisation=init_pred_localisation, sdir=sdir) + time_delta = time.time() - start + logger.info(f"{self.cfg.DATA.PLANE.capitalize()} inference on {img_filename} finished in " + f"{time_delta:0.4f} seconds") + + return out + diff --git a/CCNet/models/__init__.py b/CCNet/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CCNet/models/losses.py b/CCNet/models/losses.py new file mode 100644 index 00000000..cd2407f1 --- /dev/null +++ b/CCNet/models/losses.py @@ -0,0 +1,390 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS +from FastSurferCNN.utils import logging +from doctest import FAIL_FAST +from math import e, log +from numpy import NaN +import torch +from torch import nn, normal +from torch.nn.modules.loss import _Loss +from torch.nn import functional as F +from pytorch_msssim import SSIM as SSIM_loss #, ms_ssim, SSIM, MS_SSIM +from torch.nn.modules.loss import KLDivLoss, CrossEntropyLoss + +logger = logging.getLogger(__name__) + +class DiceLoss(_Loss): + """ + Dice Loss + """ + + def forward(self, output, target, weights=None, ignore_index=None): + """ + :param output: N x C x H x W Variable + :param target: N x C x W LongTensor with starting class at 0 + :param weights: C FloatTensor with class wise weights + :param int ignore_index: ignore label with index x in the loss calculation + :return: + """ + eps = 0.001 + + encoded_target = output.detach() * 0 + + if ignore_index is not None: + mask = target == ignore_index + target = target.clone() + target[mask] = 0 + encoded_target.scatter_(1, target.unsqueeze(1), 1) + mask = mask.unsqueeze(1).expand_as(encoded_target) + encoded_target[mask] = 0 + + else: + encoded_target.scatter_(1, target.unsqueeze(1), 1) + + if weights is None: + weights = 1 + + intersection = output * encoded_target + numerator = 2 * intersection.sum(0).sum(1).sum(1) + denominator = output + encoded_target + + if ignore_index is not None: + denominator[mask] = 0 + + denominator = denominator.sum(0).sum(1).sum(1) + eps + loss_per_channel = weights * (1 - (numerator / denominator)) # Channel-wise weights + + return loss_per_channel.sum() / output.size(1) + + +class CrossEntropy2D(nn.Module): + """ + 2D Cross-entropy loss implemented as negative log likelihood + """ + + def __init__(self, weight=None, reduction='none'): + super(CrossEntropy2D, self).__init__() + self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction) + print(f"Initialized {self.__class__.__name__} with weight: {weight} and reduction: {reduction}") + + def forward(self, inputs, targets): + return self.nll_loss(inputs, targets) + + +class SSIMLoss(_Loss): + + def __init__(self, window_size=11, window_sigma=1.5, data_range=255, size_average=True, channel=1): + """ from https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py + Args: + data_range (float or int, optional): value range of input images. (usually 1.0 or 255) + size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar + win_size: (int, optional): the size of gauss kernel + win_sigma: (float, optional): sigma of normal distribution + channel (int, optional): input channels (default: 3) + K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu. + """ + super(SSIMLoss, self).__init__() + ssim = SSIM_loss(win_size=window_size, win_sigma=window_sigma, data_range=data_range, size_average=size_average, channel=channel) + self.ssim_loss = lambda pred, orig: 1-ssim(pred.unsqueeze(1), orig.unsqueeze(1)) + #self.mask_weight = mask_weight + self.window_size = window_size + + def dilate_mask(self, min_x, max_x, img_size): + # make rectangular mask bigger to enable gaussian in SSIM to work + if torch.abs(min_x - max_x) < self.window_size * 2: + min_x = min_x - self.window_size + max_x = max_x + self.window_size + if min_x < 0: + min_x = 0 + if max_x > img_size: + max_x = img_size + + return min_x, max_x + + def ssim_loss_masked(self, pred, orig, mask): + # get min and max values of mask + + mean_ssims = 0 + ssim_count = 0 + + for i in range(pred.shape[0]): + indices = torch.nonzero(mask[i]) + + if len(indices) == 0: + continue + + min_x, min_y = torch.min(indices, dim=0).values + max_x, max_y = torch.max(indices, dim=0).values + + max_x += 1 + max_y += 1 + + min_x, max_x = self.dilate_mask(min_x, max_x, pred.shape[1]) + min_y, max_y = self.dilate_mask(min_y, max_y, pred.shape[2]) + + mean_ssims += self.ssim_loss(pred[i, min_x:max_x, min_y:max_y].unsqueeze(0), orig[i, min_x:max_x, min_y:max_y].unsqueeze(0)) + ssim_count += 1 + + return mean_ssims / ssim_count if ssim_count > 0 else 0 + + def forward(self, pred, orig, mask=None): + if mask is None: #self.mask_weight <= 0 or mask is None: + return self.ssim_loss(pred, orig) #* (1 - self.mask_weight) + else: + return self.ssim_loss_masked(pred, orig, mask) #* self.mask_weight # + self.ssim_loss(pred, orig) * (1 - self.mask_weight) \ + + + +class GradientLoss(_Loss): + + def __init__(self, alpha=1): + super(GradientLoss, self).__init__() + self.alpha = alpha + + @staticmethod + def gradient(image): + """Returns image gradients (dy, dx) for each color channel. + + Both output tensors have the same shape as the input: [batch_size, h, w, + d]. The gradient values are organized so that [I(x+1, y) - I(x, y)] is in + location (x, y). That means that dy will always have zeros in the last row, + and dx will always have zeros in the last column. + + Arguments: + image: Tensor with shape [batch_size, h, w, d]. + + Returns: + Pair of tensors (dy, dx) holding the vertical and horizontal image + gradients (1-step finite difference). + + Raises: + ValueError: If `image` is not a 4D tensor. + """ + + if image.dim() != 4: + raise ValueError( + 'image_gradients expects a 4D tensor ' + '[batch_size, d, h, w], not %s.', image.shape) + + # idea from tf.image.image_gradients(image) + # https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512 + # x: (b,c,h,w), float32 or float64 + # dx, dy: (b,c,h,w) + + # gradient step=1 + left = image + right = F.pad(image, [0, 1, 0, 0])[:, :, :, 1:] + top = image + bottom = F.pad(image, [0, 0, 0, 1])[:, :, 1:, :] + + # dx, dy = torch.abs(right - left), torch.abs(bottom - top) + dx, dy = right - left, bottom - top + # dx will always have zeros in the last column, right-left + # dy will always have zeros in the last row, bottom-top + dx[:, :, :, -1] = 0 + dy[:, :, -1, :] = 0 + + + # ported tf implmentation + + # batch_size, depth, height, width, = torch.unbind(image.shape) + # dy = image[:,:, 1:, :] - image[:,:, :-1, :] + # dx = image[:,:, :, 1:] - image[:,:, :, :-1] + + # # Return tensors with same size as original image by concatenating + # # zeros. Place the gradient [I(x+1,y) - I(x,y)] on the base pixel (x, y). + # shape = torch.stack([batch_size, 1, width, depth]) + # dy = torch.concat([dy, torch.zeros(shape, image.dtype)], 1) + # dy = torch.reshape(dy, image.shape) + + # shape = torch.stack([batch_size, height, 1, depth]) + # dx = torch.concat([dx, torch.zeros(shape, image.dtype)], 2) + # dx = torch.reshape(dx, image.shape) + + return dx, dy + + + #loss = tf.reduce_mean(dx * dx + dy * dy, axis=axis) + #return loss + + def masked_loss(self, gen_frames, gt_frames, mask): + + loss_per_img = torch.zeros(gen_frames.shape[0], device=gen_frames.device) + + for i in range(gen_frames.shape[0]): + indices = torch.nonzero(mask[i]) + + if len(indices) == 0: + continue + # get min and max values of mask + min_x, min_y = torch.min(indices, dim=0).values + max_x, max_y = torch.max(indices, dim=0).values + + max_x += 1 + max_y += 1 + + # gradient + loss_per_img[i] = self.gradient_loss(gen_frames[None,None,i,min_x:max_x, min_y:max_y], gt_frames[None,None,i,min_x:max_x, min_y:max_y]) + + return torch.mean(loss_per_img) + + + def gradient_loss(self, gen_frames, gt_frames): + # gradient + dx, dy = self.gradient(gen_frames - gt_frames) + + # grad_diff_x = torch.abs(gt_dx - gen_dx) + # grad_diff_y = torch.abs(gt_dy - gen_dy) + + # condense into one tensor and avg + return torch.mean(dx ** self.alpha + dy ** self.alpha) + + def forward(self, gen_frames, gt_frames, mask=None): + if mask is None: + return self.gradient_loss(gen_frames, gt_frames) + else: + return self.masked_loss(gen_frames, gt_frames, mask) + + +class MSELoss(_Loss): + + def __init__(self): + super(MSELoss, self).__init__() + self.mse_loss = nn.MSELoss() + + def forward(self, pred, orig, mask=None): + if mask is None: + return self.mse_loss(pred, orig) + else: + return self.mse_loss(pred * mask.to(pred.device), orig * mask.to(orig.device)) + + + + +class CombinedLoss(nn.Module): + """ + For CrossEntropy the input has to be a long tensor + Args: + -- inputx N x C x H x W + -- target - N x H x W - int type + -- weight - N x H x W - float + """ + + def __init__(self, weight_dice=1, weight_ce=1): + super(CombinedLoss, self).__init__() + self.cross_entropy_loss = CrossEntropy2D() + self.dice_loss = DiceLoss() + self.weight_dice = weight_dice + self.weight_ce = weight_ce + + def forward(self, inputx, target, weight): + # Typecast to long tensor --> labels are bytes initially (uint8), + # index operations requiere LongTensor in pytorch + + target = target.to(device=target.device, dtype=torch.long) + + #target = target.type(torch.LongTensor) + # Due to typecasting above, target needs to be shifted to gpu again + #if inputx.is_cuda: + # target = target.cuda() + + input_soft = F.softmax(inputx, dim=1) # Along Class Dimension + dice_val = torch.mean(self.dice_loss(input_soft, target)) + ce_val = torch.mean(torch.mul(self.cross_entropy_loss.forward(inputx, target), weight)) + total_loss = torch.add(torch.mul(dice_val, self.weight_dice), torch.mul(ce_val, self.weight_ce)) + + return total_loss, dice_val, ce_val + +class ComissureLoss(nn.Module): + + def __init__(self, weight_seg = 1e-5, weight_loc = 1, weight_dist = 0): + super(ComissureLoss, self).__init__() + self.seg_loss = CombinedLoss() + self.loc_loss = KLDivLoss(reduction='batchmean', log_target=True) + self.dist_loss = MSELoss() + self.weight_seg = weight_seg + self.weight_loc = weight_loc + self.weight_dist = weight_dist + + + + def forward(self, inputx, target, weight): + """ + Args: + inputx: N x C+1 x H x W + target: N x 2H x W + weight: N x H x W + """ + + # calculate segmentation loss + input_seg = inputx[:, :-1, :, :] + target_seg = target[:, :target.shape[1]//2, :] + + seg_loss, dice_val, ce_val = self.seg_loss(input_seg, target_seg, weight) + + # calculate localisation loss + input_loc = torch.select(inputx, 1, -1).float() # Assuming the last channel is the location map + input_loc = torch.sigmoid(input_loc) # input_loc was in logit space. Sigmoid is its reverse function + target_loc = target[:, target.shape[1]//2:, :].float() + # KLDIVloss requires input to be probabilitie distributions (sum to 1) in logspace + input_loc = torch.log_softmax(input_loc.view(input_loc.shape[0], -1), dim=1).view_as(input_loc) + target_loc = torch.log_softmax(target_loc.view(target_loc.shape[0], -1), dim=1).view_as(target_loc) + + if torch.all(target_loc == 0): # empty slice. Would result in inf/nan total loss + logger.warn("target_loc is 0") + target_loc += 1e-5 + + weight_loc = self.weight_loc + assert (abs(torch.sum(torch.exp(input_loc), dim=(1,2))-1)< 2e-5).all(), f"input_loc is not a probability distribution in logspace" + assert (abs(torch.sum(torch.exp(input_loc), dim=(1,2))-1)< 2e-5).all(), f"target_loc is not a probability distribution in logspace" + loc_loss = self.loc_loss(input_loc, target_loc) + + if loc_loss <= 0: + logger.warn("loc_loss less than 0") + loc_loss = 1e-5 + weight_loc = 0 + + # calculate distance loss + input_dist = torch.select(inputx, 1, -1).float() # Assuming the last channel is the location map + target_dist = target[:, target.shape[1]//2:, :].float() + + input_dist = torch.sigmoid(input_dist) # input_dist was in logit space. Sigmoid is its reverse function + input_dist = torch.clamp(input_dist, 0, 1) + + dist_loss = self.dist_loss(input_dist, target_dist) + + # calculate total loss + #loc_loss = (1/loc_loss) + + total_loss = self.weight_seg * seg_loss + weight_loc * loc_loss + self.weight_dist * dist_loss + + return total_loss, seg_loss, dice_val, ce_val, loc_loss, dist_loss + +def get_loss_func(cfg): + if cfg.MODEL.LOSS_FUNC == 'combined': + return CombinedLoss() + elif cfg.MODEL.LOSS_FUNC == 'ce': + return CrossEntropy2D() + elif cfg.MODEL.LOSS_FUNC == "dice": + return DiceLoss() + elif cfg.MODEL.LOSS_FUNC == "localisation": + return ComissureLoss(cfg.MODEL.WEIGHT_SEG, cfg.MODEL.WEIGHT_LOC, cfg.MODEL.WEIGHT_DIST) + else: + raise NotImplementedError(f"{cfg.MODEL.LOSS_FUNC}" + f" loss function is not supported") diff --git a/CCNet/models/networks.py b/CCNet/models/networks.py new file mode 100644 index 00000000..0c53daac --- /dev/null +++ b/CCNet/models/networks.py @@ -0,0 +1,279 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import FastSurferCNN.models.sub_module as sm +import FastSurferCNN.models.interpolation_layer as il + + +class CCNetBase(nn.Module): + """ + Network Definition of Fully Competitive Network network + * Spatial view aggregation (input 7 slices of which only middle one gets segmented) + * Same Number of filters per layer (normally 64) + * Dense Connections in blocks + * Unpooling instead of transpose convolutions + * Concatenationes are replaced with Maxout (competitive dense blocks) + * Global skip connections are fused by Maxout (global competition) + * Loss Function (weighted Cross-Entropy and dice loss) + """ + def __init__(self, params, padded_size=256): + super(CCNetBase, self).__init__() + + # Parameters for the Descending Arm + self.encode1 = sm.CompetitiveEncoderBlockInput(params) + params['num_channels'] = params['num_filters'] + self.encode2 = sm.CompetitiveEncoderBlock(params) + self.encode3 = sm.CompetitiveEncoderBlock(params) + self.encode4 = sm.CompetitiveEncoderBlock(params) + self.bottleneck = sm.CompetitiveDenseBlock(params) + + # Parameters for the Ascending Arm + params['num_channels'] = params['num_filters'] + self.decode4 = sm.CompetitiveDecoderBlock(params) + self.decode3 = sm.CompetitiveDecoderBlock(params) + self.decode2 = sm.CompetitiveDecoderBlock(params) + params["num_filters_last"] = params["num_filters"] + self.decode1 = sm.CompetitiveDecoderBlock(params) + + # Code for Network Initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x, scale_factor=None, scale_factor_out=None): + """ + Computational graph + :param tensor x: input image + :return tensor: prediction logits + """ + encoder_output1, skip_encoder_1, indices_1 = self.encode1.forward(x) + encoder_output2, skip_encoder_2, indices_2 = self.encode2.forward(encoder_output1) + encoder_output3, skip_encoder_3, indices_3 = self.encode3.forward(encoder_output2) + encoder_output4, skip_encoder_4, indices_4 = self.encode4.forward(encoder_output3) + + bottleneck = self.bottleneck(encoder_output4) + + decoder_output4 = self.decode4.forward(bottleneck, skip_encoder_4, indices_4) + decoder_output3 = self.decode3.forward(decoder_output4, skip_encoder_3, indices_3) + decoder_output2 = self.decode2.forward(decoder_output3, skip_encoder_2, indices_2) + decoder_output1 = self.decode1.forward(decoder_output2, skip_encoder_1, indices_1) + + return decoder_output1 + + +class CCNet(CCNetBase): + def __init__(self, params, padded_size): + super(CCNet, self).__init__(params) + params['num_channels'] = params['num_filters'] + self.classifier = sm.ClassifierBlock(params) + + # Code for Network Initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x, scale_factor=None, scale_factor_out=None): + """ + + :param x: [N, C, H, W] + :param scale_factor: [N, 1] + :return: + """ + net_out = super().forward(x, scale_factor) + output = self.classifier.forward(net_out) + + return output + + +class FastSurferVINN(CCNetBase): + """ + Network Definition of Fully Competitive Network + * Spatial view aggregation (input 7 slices of which only middle one gets segmented) + * Same Number of filters per layer (normally 64) + * Dense Connections in blocks + * Unpooling instead of transpose convolutions + * Concatenationes are replaced with Maxout (competitive dense blocks) + * Global skip connections are fused by Maxout (global competition) + * Loss Function (weighted Cross-Entropy and dice loss) + """ + def __init__(self, params, padded_size=256): + num_c = params["num_channels"] + params["num_channels"] = params["num_filters_interpol"] + super(FastSurferVINN, self).__init__(params) + + # Flex options + self.height = params['height'] + self.width = params['width'] + + self.out_tensor_shape = tuple(params.get('out_tensor_' + k, padded_size) for k in ['width', 'height']) + + self.interpolation_mode = params['interpolation_mode'] if 'interpolation_mode' in params else 'bilinear' + self.crop_position = params['crop_position'] if 'crop_position' in params else 'top_left' + + # Reset input channels to original number (overwritten in super call) + params["num_channels"] = num_c + + self.inp_block = sm.InputDenseBlock(params) + + params['num_channels'] = params['num_filters'] + params['num_filters_interpol'] + self.outp_block = sm.OutputDenseBlock(params) + + self.interpol1 = il.Zoom2d((self.width, self.height), + interpolation_mode=self.interpolation_mode, + crop_position=self.crop_position) + + self.interpol2 = il.Zoom2d(self.out_tensor_shape, + interpolation_mode=self.interpolation_mode, + crop_position=self.crop_position) + + # Classifier logits options + params['num_channels'] = params['num_filters'] + self.classifier = sm.ClassifierBlock(params) + + # Code for Network Initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def encoder_decoder(self, x, scale_factor, scale_factor_out=None): + + # Input block + Flex to 1 mm + skip_encoder_0 = self.inp_block(x) + encoder_output0, rescale_factor = self.interpol1(skip_encoder_0, scale_factor) + + # CCNet Base + decoder_output1 = super().forward(encoder_output0, scale_factor=scale_factor) + + # Flex to original res + if scale_factor_out is None: + scale_factor_out = rescale_factor + else: + scale_factor_out = np.asarray(scale_factor_out) * np.asarray(rescale_factor) / np.asarray(scale_factor) + + #prior_target_shape = self.interpol2.target_shape + self.interpol2.target_shape = skip_encoder_0.shape[2:] + + # try: + decoder_output0, _ = self.interpol2(decoder_output1, scale_factor_out, rescale=True) + # finally: # TODO: this should catch an error + # self.interpol2.target_shape = prior_target_shape + + return self.outp_block(decoder_output0, skip_encoder_0) + + def forward(self, x, scale_factor, scale_factor_out=None): + """ + Computational graph + :param tensor x: input image + :return tensor: prediction logits + """ + outblock = self.encoder_decoder(x, scale_factor, scale_factor_out) + + # Final logits layer + logits = self.classifier.forward(outblock) # 1x1 convolution + + return logits + + +class FastSurferPaint(FastSurferVINN): + """ + Network Definition of Fully Competitive Network + * Spatial view aggregation (input 7 slices of which only middle one gets segmented) + * Same Number of filters per layer (normally 64) + * Dense Connections in blocks + * Unpooling instead of transpose convolutions + * Concatenationes are replaced with Maxout (competitive dense blocks) + * Global skip connections are fused by Maxout (global competition) + * Loss Function (weighted Cross-Entropy and dice loss) + """ + def __init__(self, params, padded_size=256): + super().__init__(params, padded_size) + + assert(params['num_classes'] > 0), "Number of classes must be > 0" + + # Modify params for intensity layer + params['num_classes'] = 1 + params['num_channels'] = params['num_filters'] + self.intensity_layer = sm.ClassifierBlock(params) + + def forward(self, x, scale_factor, scale_factor_out=None): + """ + Computational graph + :param tensor x: input image + :return tensor: prediction logits + """ + outblock = self.encoder_decoder(x, scale_factor, scale_factor_out) + + # Generate intensity image + intensity_image = torch.clamp(self.intensity_layer(outblock), 0, 1) #torch.sigmoid(self.intensity_layer(outblock)) + + # Final logits layer + logits = self.classifier.forward(outblock) # 1x1 convolution + + return logits, intensity_image + + +class FastSurferLocalisation(FastSurferVINN): + def __init__(self, params, padded_size=256): + # temporarily save num_channels and num_classes + num_channels = params["num_channels"] + num_classes = params['num_classes'] + + + super(FastSurferLocalisation, self).__init__(params, padded_size=padded_size) + + + # Modify params for classifier to output another channel for localisation + params["num_channels"] = params["num_filters"] + params['num_classes'] += 1 + self.classifier = sm.ClassifierBlock(params) + + # reset num_channels and num_classes + params['num_classes'] = num_classes + params["num_channels"] = num_channels + + + def forward(self, x, scale_factor, scale_factor_out=None): + result = super().forward(x, scale_factor, scale_factor_out) + return result + +_MODELS = { + "CCNet": CCNet, + "FastSurferVINN": FastSurferVINN, + "FastSurferPaint": FastSurferPaint, + "FastSurferLocalisation": FastSurferLocalisation +} + + +def build_model(cfg): + assert(cfg.MODEL.MODEL_NAME in _MODELS.keys()), f"Model {cfg.MODEL.MODEL_NAME} not supported" + params = {k.lower(): v for k, v in dict(cfg.MODEL).items()} + model = _MODELS[cfg.MODEL.MODEL_NAME](params, padded_size=cfg.DATA.PADDED_SIZE[0]) + return model diff --git a/CCNet/run_model.py b/CCNet/run_model.py new file mode 100644 index 00000000..6f49b87c --- /dev/null +++ b/CCNet/run_model.py @@ -0,0 +1,87 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +from os.path import join +import sys +import argparse +import json + +from FastSurferCNN.utils import misc +from FastSurferCNN.utils.load_config import get_config +from CCNet.train import CCNetTrainer +from FastSurferCNN.config.global_var import get_class_names + +def setup_options(): + # Training settings + parser = argparse.ArgumentParser(description='Segmentation') + + parser.add_argument( + "--cfg", + dest="cfg_file", + help="Path to the config file", + default="config/FastSurferVINN.yaml", + type=str, + ) + parser.add_argument("--aug", action='append', help="List of augmentations to use.", default=None) + + parser.add_argument("--opt", action='append', help="List of augmentations to use.") + + parser.add_argument( + "opts", + help="See CCNet/config/defaults.py for all options", + default=None, + nargs=argparse.REMAINDER, + ) + + if len(sys.argv) == 1: + parser.print_help() + return parser.parse_args() + + + + +def main(): + args = setup_options() + cfg = get_config(args) + + if args.aug is not None: + cfg.DATA.AUG = args.aug + + if args.opt: + cfg.DATA.CLASS_OPTIONS = args.opt + + summary_path = misc.check_path(join(cfg.LOG_DIR, 'summary')) + if cfg.EXPR_NUM == "Default": + cfg.EXPR_NUM = str(misc.find_latest_experiment(join(cfg.LOG_DIR, 'summary')) + 1) + + if cfg.TRAIN.RESUME and cfg.TRAIN.RESUME_EXPR_NUM != "Default": + cfg.EXPR_NUM = cfg.TRAIN.RESUME_EXPR_NUM + + cfg.SUMMARY_PATH = misc.check_path(join(summary_path, '{}'.format(cfg.EXPR_NUM))) + cfg.CONFIG_LOG_PATH = misc.check_path(join(cfg.LOG_DIR, "config", '{}'.format(cfg.EXPR_NUM))) + + with open(join(cfg.CONFIG_LOG_PATH, "config.yaml"), "w") as json_file: + json.dump(cfg, json_file, indent=2) + + if cfg.MODEL.NUM_CLASSES == 0: + cfg.MODEL.NUM_CLASSES = len(get_class_names(cfg.DATA.PLANE, cfg.DATA.CLASS_OPTIONS))+1 + + trainer = CCNetTrainer(cfg=cfg) + trainer.run() + + +if __name__ == '__main__': + main() diff --git a/CCNet/run_prediction.py b/CCNet/run_prediction.py new file mode 100644 index 00000000..f8dad48f --- /dev/null +++ b/CCNet/run_prediction.py @@ -0,0 +1,533 @@ +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +from typing import Tuple, Union, Literal, Dict, Any, Optional +import glob +import sys +import argparse +import os +import copy + +import numpy as np +import torch +import nibabel as nib +import nibabel.processing + +from CCNet.inference import Inference +from FastSurferCNN.utils import logging +from CCNet.utils import parser_defaults +from CCNet.utils.checkpoint import get_checkpoints, VINN_AXI, VINN_COR, VINN_SAG +from FastSurferCNN.utils.load_config import load_config +from CCNet.utils.misc import find_device, calculate_centers_of_comissures +from FastSurferCNN.data_loader import data_utils as du, conform as conf +from FastSurferCNN.quick_qc import check_volume +import FastSurferCNN.reduce_to_aseg as rta +from FastSurferCNN.config.global_var import get_class_names + +# Set up logging +from FastSurferCNN.utils.logging import setup_logging + +## +# Global Variables +## + +LOGGER = logging.getLogger(__name__) + + +## +# Processing +## +def set_up_cfgs(cfg, args): + cfg = load_config(cfg) + cfg.OUT_LOG_DIR = args.out_dir if args.out_dir is not None else cfg.LOG_DIR + cfg.OUT_LOG_NAME = "fastsurfer" + cfg.TEST.BATCH_SIZE = args.batch_size + + cfg.MODEL.OUT_TENSOR_WIDTH = cfg.DATA.PADDED_SIZE + cfg.MODEL.OUT_TENSOR_HEIGHT = cfg.DATA.PADDED_SIZE + + if cfg.MODEL.NUM_CLASSES == 0: + cfg.MODEL.NUM_CLASSES = len(get_class_names(cfg.DATA.PLANE, cfg.DATA.CLASS_OPTIONS))+1 + return cfg + + +def args2cfg(args: argparse.Namespace): + """ + Extract the configuration objects from the arguments. + """ + cfg_cor = set_up_cfgs(args.cfg_cor, args) if args.cfg_cor is not None else None + cfg_sag = set_up_cfgs(args.cfg_sag, args) if args.cfg_sag is not None else None + cfg_ax = set_up_cfgs(args.cfg_ax, args) if args.cfg_ax is not None else None + cfg_fin = cfg_cor if cfg_cor is not None else cfg_sag if cfg_sag is not None else cfg_ax + return cfg_fin, cfg_cor, cfg_sag, cfg_ax + + +def removesuffix(string, suffix): + import sys + if sys.version_info.minor >= 9: + # removesuffix is a Python3.9 feature + return string.removesuffix(suffix) + else: + return string[:-len(suffix)] if string.endswith(suffix) else string + + +## +# Input array preparation +## +class RunModelOnData: + pred_name: str + conf_name: str + orig_name: str + vox_size: Union[float, Literal["min"]] + current_plane: str + models: Dict[str, Inference] + view_ops: Dict[str, Dict[str, Any]] + conform_to_1mm_threshold: Optional[float] + + def __init__(self, args): + self.pred_name = args.pred_name + self.conf_name = args.conf_name + self.orig_name = args.orig_name + self.remove_suffix = args.remove_suffix + + self.sf = 1.0 + self.out_dir = self.set_and_create_outdir(args.out_dir) + + device = find_device(args.device) + + if args.viewagg_device == "auto": + # check, if GPU is big enough to run view agg on it + # (this currently takes the memory of the passed device) + if device.type == "cuda" and torch.cuda.is_available(): # TODO update the automatic device selection + dev_num = torch.cuda.current_device() if device.index is None else device.index + total_gpu_memory = torch.cuda.get_device_properties(dev_num).__getattribute__("total_memory") + # TODO this rule here should include the batch_size ?! + self.viewagg_device = device if total_gpu_memory > 4000000000 else "cpu" + else: + self.viewagg_device = "cpu" + + else: + try: + self.viewagg_device = torch.device(args.viewagg_device) + except: + LOGGER.exception(f"Invalid device {args.viewagg_device}") + raise + # run view agg on the cpu (either if the available GPU RAM is not big enough (<8 GB), + # or if the model is anyhow run on cpu) + + LOGGER.info(f"Running view aggregation on {self.viewagg_device}") + + try: + self.lut = du.read_classes_from_lut(args.lut) + except FileNotFoundError as e: + raise ValueError( + f"Could not find the ColorLUT in {args.lut}, please make sure the --lut argument is valid.") + self.labels = self.lut["ID"].values + self.torch_labels = torch.from_numpy(self.lut["ID"].values) + self.names = ["SubjectName", "Average", "Subcortical", "Cortical"] + self.cfg_fin, cfg_cor, cfg_sag, cfg_ax = args2cfg(args) + # the order in this dictionary dictates the order in the view aggregation + self.view_ops = {"coronal": {"cfg": cfg_cor, "ckpt": args.ckpt_cor}, + "sagittal": {"cfg": cfg_sag, "ckpt": args.ckpt_sag}, + "axial": {"cfg": cfg_ax, "ckpt": args.ckpt_ax}} + self.num_classes = max(view["cfg"].MODEL.NUM_CLASSES for view in self.view_ops.values()) + self.does_inpainting = any(view["cfg"].MODEL.MODEL_NAME == 'FastSurferPaint' for view in self.view_ops.values()) + + self.does_localisation = {view : self.view_ops[view]["cfg"].MODEL.MODEL_NAME == 'FastSurferLocalisation' for view in self.view_ops} + + self.models = {} + for plane, view in self.view_ops.items(): + if view["cfg"] is not None and view["ckpt"] is not None: + self.models[plane] = Inference(view["cfg"], ckpt=view["ckpt"], device=device, inference_weights={"sagittal": args.inferece_weights[0], "coronal": args.inferece_weights[1], "axial": args.inferece_weights[2]}) + + vox_size = args.vox_size + if vox_size == "min": + self.vox_size = "min" + elif 0. < float(vox_size) <= 1.: + self.vox_size = float(vox_size) + else: + raise ValueError(f"Invalid value for vox_size, must be between 0 and 1 or 'min', was {vox_size}.") + self.conform_to_1mm_threshold = args.conform_to_1mm_threshold + + def set_and_create_outdir(self, out_dir: str) -> str: + if os.path.isabs(self.pred_name): + # Full path defined for input image, extract out_dir from there + tmp = os.path.dirname(self.pred_name) + + # remove potential subject/mri doubling in outdir name + out_dir = tmp if os.path.basename(tmp) != "mri" else os.path.dirname(os.path.dirname(tmp)) + LOGGER.info("Output will be stored in: {}".format(out_dir)) + + if not os.path.exists(out_dir): + LOGGER.info("Output directory does not exist. Creating it now...") + os.makedirs(out_dir) + return out_dir + + def get_out_dir(self) -> str: + return self.out_dir + + + def conform_and_save_orig(self, orig_str: str) -> Tuple[nib.analyze.SpatialImage, np.ndarray]: + orig, orig_data = self.get_img(orig_str) + + # Save input image to standard location + self.save_img(self.input_img_name, orig_data, orig) + + # if not conf.is_conform(orig, conform_vox_size=self.vox_size, check_dtype=True, verbose=False, + # conform_to_1mm_threshold=self.conform_to_1mm_threshold): + # LOGGER.info("Conforming image") + # orig = conf.conform(orig, + # conform_vox_size=self.vox_size, conform_to_1mm_threshold=self.conform_to_1mm_threshold) + # orig_data = np.asanyarray(orig.dataobj) + + # Save conformed input image + self.save_img(self.subject_conf_name, orig_data, orig, dtype=np.uint8) + + return orig, orig_data + + def set_subject(self, subject: str, sid: Union[str, None]): + self.subject_name = os.path.basename(removesuffix(subject, self.remove_suffix)) if sid is None else sid + self.subject_conf_name = os.path.join(self.out_dir, self.subject_name, self.conf_name) + self.input_img_name = os.path.join(self.out_dir, self.subject_name, + os.path.dirname(self.conf_name), 'orig', '001.mgz') + + def get_subject_name(self) -> str: + return self.subject_name + + def set_model(self, plane: str): + self.current_plane = plane + + def get_prediction(self, orig_f: str, orig_data: np.ndarray, zoom: Union[np.ndarray, tuple], lesion_mask=None, crop : bool = False, extra_slices : int = 0) -> np.ndarray: + + if not crop: + shape = orig_data.shape + (self.get_num_classes(),) + shape_loc = orig_data.shape + else: + shape = (orig_data.shape[0]-extra_slices, orig_data.shape[1], orig_data.shape[2],self.get_num_classes()) + shape_loc = (orig_data.shape[0]-extra_slices, orig_data.shape[1], orig_data.shape[2]) + + + kwargs = { + "device": self.viewagg_device, + "dtype": torch.float16, + "requires_grad": False + } + + if self.does_inpainting: + assert(lesion_mask is not None), "Inpainting requires a lesion mask, mismatch between chosen model and input data" + + pred_prob = torch.zeros(shape, **kwargs) + loc_map = torch.zeros(shape_loc, **kwargs) if any(self.does_localisation.values()) else None + + + # inference and view aggregation + for plane, model in self.models.items(): + LOGGER.info(f"Run {plane} prediction") + self.set_model(plane) + # pred_prob is updated inplace to conserve memory + if self.does_localisation[plane]: + + pred_prob, loc_map = model.run(init_pred=pred_prob, img_filename=orig_f, orig_data=orig_data, lesion_mask=lesion_mask, orig_zoom=zoom, out=pred_prob, sdir=f'{self.get_out_dir()}/{self.get_subject_name()}/', pad= not crop) + + elif self.does_inpainting: + pred_prob, inpainted_volume = pred_prob, None + else: + if not crop: + pred_prob = model.run(init_pred = pred_prob, img_filename=orig_f, orig_data=orig_data, lesion_mask=lesion_mask, orig_zoom=zoom, out=pred_prob, pad=True) + else: + pred_prob = model.run(init_pred = pred_prob, img_filename=orig_f, orig_data=orig_data[extra_slices//2:orig_data.shape[0]-extra_slices//2,:,:], lesion_mask=lesion_mask, orig_zoom=zoom, out=pred_prob, pad=True) + + # Get hard predictions + pred_classes = torch.argmax(pred_prob, 3) + del pred_prob + # map to freesurfer label space + pred_classes = du.map_label2aparc_aseg(pred_classes, self.labels) + # return numpy array TODO: split_cortex_labels requires a numpy ndarray input + pred_classes = du.split_cortex_labels(pred_classes.cpu().numpy()) + + # get AC and PC from loc_map + if any(self.does_localisation.values()): + # remove last dimension + loc_map = loc_map.cpu().squeeze() + # calculate AC and PC from loc_map + orig, orig_data = self.get_img(orig_f) + self.save_img(os.path.join(self.out_dir, self.subject_name, 'mri', 'loc_map.mgz'), loc_map, orig, dtype=np.float32) + + ac_pc = calculate_centers_of_comissures(loc_map) + + LOGGER.debug(f"AC: {ac_pc[0]}, PC: {ac_pc[1]}") + + if self.does_inpainting: + return pred_classes, inpainted_volume + elif loc_map is not None: + return pred_classes, ac_pc + else: + return pred_classes + + @staticmethod + def get_img(filename: Union[str, os.PathLike]) -> Tuple[nib.analyze.SpatialImage, np.ndarray]: + img = nib.load(filename) + data = np.asanyarray(img.dataobj) + + return img, data + + @staticmethod + def save_img(save_as: str, data: Union[np.ndarray, torch.Tensor], orig: nib.analyze.SpatialImage, + dtype: Union[None, type] = None): + # Create output directory if it does not already exist. + if not os.path.exists(os.path.dirname(save_as)): + LOGGER.info("Output image directory does not exist. Creating it now...") + os.makedirs(os.path.dirname(save_as)) + np_data = data if isinstance(data, np.ndarray) else data.cpu().numpy() + + if dtype is not None: + header = orig.header.copy() + header.set_data_dtype(dtype) + else: + header = orig.header + + du.save_image(header, orig.affine, np_data, save_as, dtype=dtype) + LOGGER.info("Successfully saved image as {}".format(save_as)) + + def set_up_model_params(self, plane, cfg, ckpt): + self.view_ops[plane]["cfg"] = cfg + self.view_ops[plane]["ckpt"] = ckpt + + def get_num_classes(self) -> int: + return self.num_classes + + +def handle_cuda_memory_exception(exception: RuntimeError, exit_on_out_of_memory: bool = True) -> bool: + if not isinstance(exception, RuntimeError): + return False + message = exception.args[0] + if message.startswith("CUDA out of memory. "): + LOGGER.critical("ERROR - INSUFFICIENT GPU MEMORY") + LOGGER.info("The memory requirements exceeds the available GPU memory, try using a smaller batch size " + "(--batch_size ) and/or view aggregation on the cpu (--viewagg_device 'cpu')." + "Note: View Aggregation on the GPU is particularly memory-hungry at approx. 5 GB for standard " + "256x256x256 images.") + memory_message = message[message.find("(") + 1:message.find(")")] + LOGGER.info(f"Using {memory_message}.") + if exit_on_out_of_memory: + sys.exit("----------------------------\nERROR: INSUFFICIENT GPU MEMORY\n") + else: + return True + else: + return False + + +def main(): + parser = argparse.ArgumentParser(description='Evaluation metrics') + + # 1. Options for input directories and filenames + parser = parser_defaults.add_arguments(parser, ["t1", "sid", "in_dir", "tag", "csv_file", "lut", "remove_suffix", "crop"]) + + # 2. Options for output + parser = parser_defaults.add_arguments(parser, ["asegdkt_segfile", "conformed_name", "brainmask_name", + "aseg_name", "sd", "seg_log", "qc_log"]) + + # 3. Checkpoint to load + parser = parser_defaults.add_plane_flags(parser, "checkpoint", + {"coronal": VINN_COR, "axial": VINN_AXI, "sagittal": VINN_SAG}) + + # 4. CFG-file with default options for network + parser = parser_defaults.add_plane_flags(parser, "config", + {"coronal": "CCNet/config/FastSurferVINN_coronal.yaml", + "axial": "CCNet/config/FastSurferVINN_axial.yaml", + "sagittal": "CCNet/config/FastSurferVINN_sagittal.yaml"}) + + # 5. technical parameters + parser = parser_defaults.add_arguments(parser, ["vox_size", "conform_to_1mm_threshold", "device", "viewagg_device", + "batch_size", "allow_root"]) + + parser.add_argument('--inferece_weights', nargs=3, type=float, default=[0.6,0.2,0.2], help='Weights for the inference of the three planes') + + args = parser.parse_args() + + # Warning if run as root user + if not args.allow_root and os.name == 'posix' and os.getuid() == 0: + sys.exit( + """---------------------------- + ERROR: You are trying to run 'run_prediction.py' as root. We advice to avoid running + FastSurfer as root, because it will lead to files and folders created as root. + If you are running FastSurfer in a docker container, you can specify the user with + '-u $(id -u):$(id -g)' (see https://docs.docker.com/engine/reference/run/#user). + If you want to force running as root, you may pass --allow_root to run_prediction.py. + """) + + # Check input and output options + if args.in_dir is None and args.csv_file is None and not os.path.isfile(args.orig_name): + parser.print_help(sys.stderr) + sys.exit( + '----------------------------\nERROR: Please specify data input directory or full path to input volume\n') + + if args.out_dir is None and not os.path.isabs(args.pred_name): + parser.print_help(sys.stderr) + sys.exit( + '----------------------------\nERROR: Please specify data output directory or absolute path to output volume' + ' (can be same as input directory)\n') + + qc_file_handle = None + if args.qc_log != "": + try: + qc_file_handle = open(args.qc_log, 'w') + except NotADirectoryError: + LOGGER.warning("The directory in the provided QC log file path does not exist!") + LOGGER.warning("The QC log file will not be saved.") + + + setup_logging(args.log_name) + + + # Download checkpoints if they do not exist + # see utils/checkpoint.py for default paths + LOGGER.info("Checking or downloading default checkpoints ...") + print("Checkpoints: " +args.ckpt_ax, args.ckpt_cor, args.ckpt_sag) + print("Working directory: " + os.getcwd()) + get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag) + + # Set Up Model + eval = RunModelOnData(args) + + # Get all subjects of interest + if args.csv_file is not None: + with open(args.csv_file, "r") as s_dirs: + s_dirs = [line.strip() for line in s_dirs.readlines()] + LOGGER.info("Analyzing all {} subjects from csv_file {}".format(len(s_dirs), args.csv_file)) + + elif args.in_dir is not None: + s_dirs = glob.glob(os.path.join(args.in_dir, args.search_tag)) + LOGGER.info("Analyzing all {} subjects from in_dir {}".format(len(s_dirs), args.in_dir)) + + else: + s_dirs = [os.path.dirname(args.orig_name)] + LOGGER.info("Analyzing single subject {}".format(args.orig_name)) + + qc_failed_subject_count = 0 + + for subject in s_dirs: + # Set subject and load orig + eval.set_subject(subject, args.sid) + orig_fn = args.orig_name if os.path.isfile(args.orig_name) else os.path.join(subject, args.orig_name) + orig_img, data_array = eval.conform_and_save_orig(orig_fn) + + original_size = orig_img.header.get_data_shape() + # TODO: Align image to midplan + # We currently assume the image is already aligned to the midplane of 128. + # Assumes midplane ON sagittal voxel plane 128 (not in between) + midplane = 128 + if args.crop: + volume_thickness = np.ceil((5/orig_img.header.get_zooms()[0])) + if volume_thickness % 2 == 0: + volume_thickness += 1 + offset = midplane - volume_thickness//2 + extra_slices = eval.view_ops['sagittal']['cfg'].MODEL.NUM_CHANNELS//2*2 + volume_thickness += extra_slices + data_array = du.subset_volume_plane(data_array, plane=midplane, thickness=volume_thickness, axis=0) + else: + # already cropped + offset = midplane-original_size[0]//2 + extra_slices = 0 + + lesion_mask = None + + # Set prediction name + out_dir, sbj_name = eval.get_out_dir(), eval.get_subject_name() + pred_name = args.pred_name if os.path.isabs(args.pred_name) else \ + os.path.join(out_dir, sbj_name, args.pred_name) + + # Run model + try: + # orig_f: str, orig_data: np.ndarray, zoom: Union[np.ndarray, tuple] + pred_data = eval.get_prediction(orig_f=orig_fn, orig_data=data_array, zoom=orig_img.header.get_zooms(), lesion_mask=lesion_mask, crop = args.crop, extra_slices=extra_slices) + if eval.does_inpainting: + pred_data, inpainted_volume = pred_data + + inpainted_volume = inpainted_volume.cpu().numpy() + inpainted_volume = np.round(inpainted_volume * 255).astype(np.uint8) + + # Save inpainted volume + eval.save_img(os.path.join(eval.out_dir, eval.subject_name, 'mri', 'inpainted_network_output.mgz'), inpainted_volume, orig_img, dtype=np.uint8) + + inpainted_volume[~lesion_mask] = data_array[~lesion_mask] + eval.save_img(os.path.join(eval.out_dir, eval.subject_name, 'mri', 'inpainted.mgz'), inpainted_volume, orig_img, dtype=np.uint8) + elif np.array([value for value in eval.does_localisation.values()]).any(): + pred_data, ac_pc = pred_data + # adjust to original image size + ac = ac_pc[0] + ac[0]+=offset + pc = ac_pc[1] + pc[0]+=offset + + points_file = os.path.join(eval.out_dir, eval.subject_name, 'mri', 'points.txt') + if not os.path.exists(points_file): + os.makedirs(os.path.dirname(points_file), exist_ok=True) + with open(points_file, 'w') as f: + f.write(f'Type,r,a,s\n') + f.write(f'PC,{pc[0]},{pc[1]},{pc[2]}\n') + f.write(f'AC,{ac[0]},{ac[1]},{ac[2]}') + + + # todo: pad to original size if crop + + if args.crop: + pred_data = du.pad_to_size(pred_data, original_size, plane = 128, axis = 0) + + eval.save_img(pred_name, pred_data, orig_img, dtype=np.int16) + +# Create aseg and brainmask + # Change datatype to np.uint8, else mri_cc will fail! + + # # get mask + # LOGGER.info("Creating brainmask based on segmentation...") + # bm = rta.create_mask(copy.deepcopy(pred_data), 5, 4) + # mask_name = os.path.join(out_dir, sbj_name, args.brainmask_name) + # eval.save_img(mask_name, bm, orig_img, dtype=np.uint8) + + # reduce aparc to aseg and mask regions + # eval.save_img(aseg_name, aseg, orig_img, dtype=np.uint8) + + # Run QC check + # LOGGER.info("Running volume-based QC check on segmentation...") + # seg_voxvol = np.product(orig_img.header.get_zooms()) + # if not check_volume(pred_data, seg_voxvol): + # LOGGER.warning("Total segmentation volume is too small. Segmentation may be corrupted.") + # if qc_file_handle is not None: + # qc_file_handle.write(subject.split('/')[-1] + "\n") + # qc_file_handle.flush() + # qc_failed_subject_count += 1 + except RuntimeError as e: + if not handle_cuda_memory_exception(e): + raise e + + if qc_file_handle is not None: + qc_file_handle.close() + + # Single case: exit with error if qc fails. Batch case: report ratio of failures. + if len(s_dirs) == 1: + if qc_failed_subject_count: + LOGGER.error("Single subject failed the volume-based QC check.") + sys.exit(1) + else: + LOGGER.info("Segmentations from {} out of {} processed cases failed the volume-based QC check.".format( + qc_failed_subject_count, len(s_dirs))) + + #sys.exit(0) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/CCNet/train.py b/CCNet/train.py new file mode 100644 index 00000000..40de8bca --- /dev/null +++ b/CCNet/train.py @@ -0,0 +1,548 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import pprint +import subprocess +import time +import os +from networkx import center +#import concurrent.futures +#import multiprocessing # TODO: replace this with concurrent.futures when python 3.8 is phased out +#from collections import defaultdict +#import sys + +import torch +import ignite # needed for ignite.metrics +from ignite.metrics import Metric +import ignite.metrics +from torch.utils.tensorboard import SummaryWriter +#from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM + + +import numpy as np +from tqdm import tqdm +import random +from tqdm.contrib.logging import logging_redirect_tqdm + +from CCNet.data_loader import loader +from CCNet.models.networks import build_model +from FastSurferCNN.models.optimizer import get_optimizer +from CCNet.models.losses import get_loss_func, SSIMLoss, GradientLoss, MSELoss +from FastSurferCNN.utils import logging, checkpoint as cp +from FastSurferCNN.utils.lr_scheduler import get_lr_scheduler +from CCNet.utils.meters import Meter, DiceScore, LocDistance +#from CCNet.utils.metrics import iou_score, precision_recall +from CCNet.utils.misc import update_num_steps, plot_predictions, plot_predictions_inpainting, calculate_centers_of_comissures +from FastSurferCNN.config.global_var import get_class_names + +logger = logging.getLogger(__name__) + + +class CCNetTrainer: + def __init__(self, cfg): + # Set random seed from configs. + #np.random.seed(cfg.RNG_SEED) + #torch.manual_seed(cfg.RNG_SEED) + + # self.set_determinism(cfg.RNG_SEED) + self.cfg = cfg + + # Create the checkpoint dir. + self.checkpoint_dir = cp.create_checkpoint_dir(cfg.LOG_DIR, cfg.EXPR_NUM) + logging.setup_logging(os.path.join(cfg.LOG_DIR, "logs", cfg.EXPR_NUM + ".log")) + logger.setLevel(cfg.LOG_LEVEL) + logger.info("New training run with config:") + logger.info(pprint.pformat(cfg)) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = build_model(cfg) + self.loss_func = get_loss_func(cfg) + + + #assert(16 % cfg.TRAIN.BATCH_SIZE == 0), "Batch size must be a divisor of 16" + + # set up class names + self.class_names = get_class_names(cfg.DATA.PLANE, cfg.DATA.CLASS_OPTIONS) + + # Set up logger format + self.format_placeholder_classes = "{}\t" * (cfg.MODEL.NUM_CLASSES - 2) + "{}" + self.num_classes = cfg.MODEL.NUM_CLASSES + self.plot_dir = os.path.join(cfg.LOG_DIR, "pred", str(cfg.EXPR_NUM)) + os.makedirs(self.plot_dir, exist_ok=True) + + #self.subepoch = False if self.cfg.TRAIN.BATCH_SIZE == 16 else True + #self.subepoch = False + + self.has_cutout = 'Cutout' in self.cfg.DATA.AUG or any('cutout' in i.lower() for i in self.cfg.DATA.AUG) + + if self.has_cutout: + self.ssim_loss = SSIMLoss() + #self.gradient_loss = GradientLoss() + #self.mse = MSELoss(reduction='mean') + self.mae = torch.nn.L1Loss(reduction='mean') + + self.inpainting_loss = lambda pred, orig: self.ssim_loss(pred, orig) #+ self.gradient_loss(pred, orig) + self.mse(pred) + self.inpainting_loss_mask = lambda pred, orig, mask: self.mae(pred[mask], orig[mask]) + self.mae(pred[~mask], torch.zeros_like(pred[~mask])) #+ self.ssim_loss(pred, orig, mask) #+ self.gradient_loss(pred, orig, mask) + self.mse(pred, orig, mask) + + def run_epoch(self, train: bool, data_loader, meter: Meter, epoch: int, log_name: str, optimizer: torch.optim.Optimizer = None, scheduler=None): + meter.reset() + + if train: + self.model.train() + else: + self.model.eval() # TODO: add nograd + + logger.info(f"{log_name} epoch started ") + epoch_start = time.time() + loss_batch = torch.zeros(1, device=self.device) + + with torch.set_grad_enabled(train), logging_redirect_tqdm(): #, multiprocessing.Pool() as process_executor: + for curr_iter, batch in tqdm(enumerate(data_loader), total=len(data_loader)): + + images, labels, weights, scale_factors, cutout_mask = batch['image'].to(self.device), \ + batch['label'].to(self.device), \ + batch['weight'].float().to(self.device), \ + batch['scale_factor'], \ + batch['cutout_mask'] + + loc_loss = None + mse_loss = None + loss_seg = None + + cutout_mask_midslice = cutout_mask[:, cutout_mask.shape[1]//2] # get middle slice + + if train: # and (not self.subepoch or (curr_iter)%(16/self.cfg.TRAIN.BATCH_SIZE) == 0): + optimizer.zero_grad() # every second epoch to get batchsize of 16 if using 8 + + if self.cfg.MODEL.MODEL_NAME == 'FastSurferPaint': + + cutout_mask_empty = (cutout_mask_midslice == 0).all() + + network_input = torch.concat([images, cutout_mask.to(self.device)], dim=1) + pred = self.model(network_input, scale_factors) + + orig_slice = batch['unmodified_center_slice'].to(self.device) + pred, estimated_slice = pred + estimated_slice = estimated_slice.squeeze(1) + loss_total, loss_dice, loss_ce = self.loss_func(pred, labels, weights) + + #inpainting_loss = self.ssim_loss(estimated_slice, orig_slice, cutout_mask if not cutout_mask_empty else None) + + if not cutout_mask_empty: + inpainting_loss = self.inpainting_loss_mask(estimated_slice, orig_slice, cutout_mask_midslice) + self.update_metrics(meter.ignite_metrics, pred, labels, pred_slice=estimated_slice, orig_slice=orig_slice, cutout_mask=cutout_mask_midslice) + else: + inpainting_loss = 0 + #inpainting_loss = self.inpainting_loss(estimated_slice, orig_slice) + self.update_metrics(meter.ignite_metrics, pred, labels) + + #assert(cutout_mask_empty == (torch.sum(cutout_mask_midslice) == 0)), 'cutout mask is not empty but sum is 0' + + loss_total = loss_total * (1 - self.cfg.INPAINT_LOSS_WEIGHT) + (self.cfg.INPAINT_LOSS_WEIGHT * inpainting_loss if not cutout_mask_empty else 0) + elif self.cfg.MODEL.MODEL_NAME == 'FastSurferLocalisation': + + pred = self.model(images, scale_factors) + loss_total, loss_seg, loss_dice, loss_ce, loc_loss, mse_loss = self.loss_func(pred, labels, weights) + # metrics only on classification + self.update_metrics(meter.ignite_metrics, pred, labels) + #logger.debug(f'Losses: loss_total: {loss_total}, loss_dice: {loss_dice}, loss_ce: {loss_ce}, seg_loss: {seg_loss}, loc_loss: {loc_loss}') + else: + pred = self.model(images, scale_factors) + loss_total, loss_dice, loss_ce = self.loss_func(pred, labels, weights) + self.update_metrics(meter.ignite_metrics, pred, labels) + + if torch.isnan(loss_total): + logger.info('loss is nan - stopping training and starting debugging') + #import pdb; pdb.set_trace() + + meter.update_stats(pred, labels, loss_total) + meter.log_iter(curr_iter, epoch) + + meter.write_summary(loss_total, + lr = scheduler.get_last_lr() if scheduler is not None else [self.cfg.OPTIMIZER.BASE_LR], + loss_ce = loss_ce, + loss_dice = loss_dice, + loss_seg = loss_seg, + loc_loss = loc_loss, + dist_loss = mse_loss) + + + + + if train: + loss_total.backward() # TODO: should this be loss_batch? + #if not self.subepoch or (curr_iter+1)%(16/self.cfg.TRAIN.BATCH_SIZE) == 0: + optimizer.step() # every second epoch to get batchsize of 16 if using 8 + if scheduler is not None: + scheduler.step(epoch + curr_iter / len(data_loader)) + + loss_batch += loss_total + + # Plot sample predictions + if curr_iter == 1: # try to log cutout images if possible + plt_title = 'Training Results Epoch ' + str(epoch) + file_save_name = os.path.join(self.plot_dir, + 'Epoch_' + str(epoch) + f'_{log_name}_Predictions.pdf') + + logger.debug(f'Plotting {file_save_name}') + + if self.cfg.MODEL.MODEL_NAME == 'FastSurferLocalisation': + logger.debug(f'pred shape: {pred.shape}') + logger.debug(f'labels shape: {labels.shape}') + seg_pred = pred[:,:-1,:,:].detach() + seg_labels = labels[:,:labels.shape[1]//2, :] + logger.debug(f'seg_pred shape: {seg_pred.shape}') + logger.debug(f'seg_labels shape: {seg_labels.shape}') + + else: + seg_pred = pred.detach() + seg_labels = labels + + _, batch_output = torch.max(seg_pred, dim=1) + + logger.debug(f'Unique labels: {np.unique(seg_labels.detach().cpu())}') + logger.debug(f'Unique batch output: {np.unique(batch_output.detach().cpu())}') + + plt_images = images.detach() + + assert (batch_output.cpu().numpy().astype(np.int64) == batch_output.cpu().numpy()).all(), 'batch output is not int' + + assert (seg_labels.cpu().numpy().astype(np.int64) == seg_labels.cpu().numpy()).all(), 'seg labels is not int' + + + + plot_predictions(plt_images, seg_labels.detach().int(), batch_output.int(), plt_title, file_save_name, process_executor=None, lut=self.cfg.DATA.LUT) #process_executor) + + + if self.cfg.MODEL.MODEL_NAME == 'FastSurferPaint': + file_save_name = os.path.join(self.plot_dir, + 'Epoch_' + str(epoch) + f'_{log_name}_Predictions_EstimatedSlice.pdf') + plot_predictions_inpainting(plt_images, orig_slice.detach(), estimated_slice.detach(), plt_title, file_save_name, process_executor=None) #process_executor) + + #process_executor.shutdown(wait=True) + + # temporarily removed + # process_executor.close() + # process_executor.join() + + + meter.log_epoch(epoch, runtime=time.time() - epoch_start) + logger.info("{} epoch {} finished in {:.04f} seconds".format(log_name, epoch, time.time() - epoch_start)) + + if 'IoU' in meter.ignite_metrics.keys(): + mIOU = meter.ignite_metrics['IoU'].compute().mean() + elif 'FastSurfer_dice' in meter.ignite_metrics.keys(): + mIOU = meter.ignite_metrics['FastSurfer_dice'].compute().mean() + else: + mIOU = None + + + return mIOU + + def create_ignite_metrics(self, train=True): + + #device = torch.device("cpu") # confusion matrix is faster on cpu + device = self.device + + ignite_metrics = {} + + if not train: + # ignite_metrics = {'confusion_matrix': ignite.metrics.confusion_matrix.ConfusionMatrix(self.num_classes, device=torch.device('cpu'))} + # ignite_metrics['DICE_ignite'] = ignite.metrics.DiceCoefficient(ignite_metrics['confusion_matrix'], ignore_index=0) # ignore background + # ignite_metrics['IoU'] = ignite.metrics.IoU(ignite_metrics['confusion_matrix'], ignore_index=0) # ignore background + # ignite_metrics['MeanRecall'] = ignite.metrics.Recall(average=True, device=device) + # ignite_metrics['MeanPrecision'] = ignite.metrics.Precision(average=True, device=device) + + + ignite_metrics['FastSurfer_dice'] = DiceScore(self.num_classes, device=device) + else: + ignite_metrics['FastSurfer_dice'] = DiceScore(self.num_classes, device=device) + + if self.has_cutout: + ignite_metrics['PSNR'] = ignite.metrics.PSNR(data_range=255, device=device) # TODO: should be 1? + ignite_metrics['PSNR_inpaint'] = ignite.metrics.PSNR(data_range=255, device=device) + ignite_metrics['SSIM'] = ignite.metrics.SSIM(data_range=255, kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True, device=torch.device('cpu')) # always cpu because of determinism TODO: make flag + ignite_metrics['SSIM_inpaint'] = ignite.metrics.SSIM(data_range=255, kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True, device=torch.device('cpu')) + + if self.cfg.MODEL.MODEL_NAME == 'FastSurferLocalisation': + # Add localisational distance error + ignite_metrics['locational_distance'] = LocDistance(device=device, axis = 2) + + return ignite_metrics + + + def update_metrics(self, ignite_metrics, pred, labels, pred_slice=None, orig_slice=None, cutout_mask=None): + if self.cfg.MODEL.MODEL_NAME == 'FastSurferLocalisation': + if 'locational_distance' in ignite_metrics.keys(): + ignite_metrics['locational_distance'].update(pred, labels) + pred = pred[:,:-1,:,:] + labels = labels[:,:labels.shape[1]//2, :] + + # dice and iou are calculated from confusion matrix + + if 'FastSurfer_dice' in ignite_metrics.keys(): + ignite_metrics['FastSurfer_dice'].update((pred, labels.long())) + + if 'confusion_matrix' in ignite_metrics.keys(): + # NOTE: this will give gibberish if given uint8 tensors + ignite_metrics['confusion_matrix'].update((torch.nn.functional.softmax(pred.detach().to(torch.float32), dim=1).cpu(), labels.detach().to(torch.int64)).cpu()) # this is slow on GPU - fixed in https://github.com/pytorch/pytorch/pull/97090, but we dont have that version yet + + if 'MeanRecall' in ignite_metrics.keys(): + ignite_metrics['MeanRecall'].update((pred, labels.long())) + if 'MeanPrecision' in ignite_metrics.keys(): + ignite_metrics['MeanPrecision'].update((pred, labels.long())) + + if self.has_cutout and pred_slice is not None and orig_slice is not None and cutout_mask.sum() > 0: + if 'PSNR' in ignite_metrics.keys(): + ignite_metrics['PSNR'].update((pred_slice.unsqueeze(1), orig_slice.unsqueeze(1))) + if 'SSIM' in ignite_metrics.keys(): + ignite_metrics['SSIM'].update((pred_slice.unsqueeze(1), orig_slice.unsqueeze(1))) + + + if ('PSNR_inpaint' in ignite_metrics.keys() or 'SSIM_inpaint' in ignite_metrics.keys()) and cutout_mask is not None: + for s in range(pred_slice.shape[0]): # iterate over samples in batch + i, j = np.where(cutout_mask[s]) + if len(i) == 0 or len(j) == 0: + continue + + maxi = np.max(i) + mini = np.min(i) + maxj = np.max(j) + minj = np.min(j) + + pred_slice_masked = pred_slice[s, mini:maxi+1, minj:maxj+1] + orig_slice_masked = orig_slice[s, mini:maxi+1, minj:maxj+1] + + if 'PSNR_inpaint' in ignite_metrics.keys(): + ignite_metrics['PSNR_inpaint'].update((pred_slice_masked[None, None, ...], orig_slice_masked[None, None, ...])) + # if np.isnan(ignite_metrics['PSNR_inpaint'].compute()): + # print('PSNR is nan - stopping training and starting debugging') + if 'SSIM_inpaint' in ignite_metrics.keys(): + try: + ignite_metrics['SSIM_inpaint'].update((pred_slice_masked[None, None, ...], orig_slice_masked[None, None, ...])) + except RuntimeError: + assert(pred_slice_masked.shape == orig_slice_masked.shape) + assert(pred_slice_masked.shape[0] <= 5 or pred_slice_masked.shape[1] <= 5), \ + 'SSIM on cutout region failed - but cutout area was of sufficient size - unexpected error!' + + return True + + + + def create_logging(self, start_epoch, train_loader, val_loader, val_loader2=None, val_loader_inpaint=None, has_cutout=False): + # Create tensorboard summary writer + writer = SummaryWriter(self.cfg.SUMMARY_PATH, flush_secs=15) + + create_meter = lambda name, is_train, data_loader: Meter(self.cfg, + mode=name, + global_step=start_epoch, + total_iter=len(data_loader), + total_epoch=self.cfg.TRAIN.NUM_EPOCHS, + writer=writer, + ignite_metrics=self.create_ignite_metrics(train=is_train)) + + + train_meter = create_meter('train', True, train_loader) + + val_meter = create_meter('val', False, val_loader) + + # if cutout is used, create a separate validation meter for inpainting + if self.has_cutout: + val_meter_inpaint = create_meter('val_inpaint', False, val_loader_inpaint) + else: + val_meter_inpaint = None + + if val_loader2: + val_meter_tumor = create_meter('val_tumor', False, val_loader2) + else: + val_meter_tumor = None + + return train_meter, val_meter, val_meter_inpaint, val_meter_tumor + + def make_data_loaders(self): + train_loader = loader.get_dataloader(self.cfg, "train") + val_loader = loader.get_dataloader(self.cfg, "val") + if self.has_cutout: + val_loader_inpaint = loader.get_dataloader(self.cfg, "val_inpainting", val_dataset=val_loader.dataset) + else: + val_loader_inpaint = None + if self.cfg.DATA.PATH_HDF5_VAL2: + val_loader2 = loader.get_dataloader(self.cfg, "val_tumor", data_path=self.cfg.DATA.PATH_HDF5_VAL2) + else: + val_loader2 = None + + return train_loader, val_loader, val_loader_inpaint, val_loader2 + + def save_code(self, dir_path): + dir_path = os.path.abspath(dir_path) + if not os.path.exists(dir_path): + os.mkdir(dir_path) + # run git ls-files | tar Tczf - mycode.zip + cmd = f'git ls-files | tar Tczf - {os.path.join(dir_path, "repository_code.zip")}' + #subprocess.Popen([cmd], shell=True) # TODO: Fix tar "Cannot stat: No such file or directory" + + + + + def run(self): + if self.cfg.NUM_GPUS > 1: + assert self.cfg.NUM_GPUS <= torch.cuda.device_count(), f"Trying to use {self.cfg.NUM_GPUS} GPUs, but only {torch.cuda.device_count()} are available" + logger.info(f"Using {self.cfg.NUM_GPUS} GPUs!") + self.model = torch.nn.DataParallel(self.model) + + train_loader, val_loader, val_loader_inpaint, val_loader2 = self.make_data_loaders() + + # check loaders + + sample_train = next(iter(train_loader)) + sample_val = next(iter(val_loader)) + assert(sample_train['image'].shape[1:] == sample_val['image'].shape[1:]), 'Train and val loader have different input shapes' + + # TODO: add additional logging for loaders + update_num_steps(train_loader, self.cfg) + + # Transfer the model to device(s) + self.model = self.model.to(self.device) + + optimizer = get_optimizer(self.model, self.cfg) + scheduler = get_lr_scheduler(optimizer, self.cfg) + + checkpoint_paths = cp.get_checkpoint_path(self.cfg.LOG_DIR, + self.cfg.TRAIN.RESUME_EXPR_NUM) + + if self.cfg.TRAIN.RESUME and checkpoint_paths: + try: + checkpoint_path = checkpoint_paths.pop() + checkpoint_epoch, best_metric = cp.load_from_checkpoint( + checkpoint_path, + self.model, + optimizer, + scheduler, + self.cfg.TRAIN.FINE_TUNE + ) + start_epoch = checkpoint_epoch + best_miou = best_metric + logger.info(f"Resume training from epoch {start_epoch}") + except Exception as e: + logger.warning("No model to restore. Resuming training from Epoch 0. {}".format(e)) + else: + logger.info("Training from scratch") + start_epoch = 0 + best_miou = 0 + + logger.info("Saving code to {}".format(self.cfg.LOG_DIR)) + self.save_code(self.cfg.LOG_DIR) + + logger.info("{} parameters in total".format(sum(x.numel() for x in self.model.parameters()))) + + + train_meter, val_meter, val_meter_inpaint, val_meter_tumor = self.create_logging( + start_epoch, train_loader, val_loader, val_loader2, val_loader_inpaint, has_cutout=self.has_cutout) + + + logger.info("Summary path {}".format(self.cfg.SUMMARY_PATH)) + # Perform the training loop. + logger.info("Start epoch: {}".format(start_epoch + 1)) + + # with torch.profiler.profile( + # schedule=torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=0), + # on_trace_ready=torch.profiler.tensorboard_trace_handler(self.cfg.SUMMARY_PATH), + # record_shapes=False, + # profile_memory=False, + # with_stack=True) as prof: #, torch.amp.autocast(enabled=True, device_type=self.device.__str__(), cache_enabled=None): TODO: not supported in prelu? + + + + for epoch in range(start_epoch, self.cfg.TRAIN.NUM_EPOCHS): + #self.train(train_loader, optimizer, scheduler, train_meter, epoch=epoch) + _ = self.run_epoch(train=True, data_loader=train_loader, optimizer=optimizer, scheduler=scheduler, meter=train_meter, epoch=epoch, log_name='Train') + + #miou = self.eval(val_loader, val_meter, epoch=epoch) + miou = self.run_epoch(train=False, data_loader=val_loader, optimizer=None, scheduler=None, meter=val_meter, epoch=epoch, log_name='Validation') + + if val_loader2: + #_ = self.eval(val_loader2, val_meter_tumor, epoch=epoch, log_name='Validation_Tumor') + _ = self.run_epoch(train=False, data_loader=val_loader2, optimizer=None, scheduler=None, meter=val_meter_tumor, epoch=epoch, log_name='Validation_Tumor') + + if self.has_cutout: + #_ = self.eval(val_loader_inpaint, val_meter_inpaint, epoch=epoch, log_name='Validation_Inpainting') + _ = self.run_epoch(train=False, data_loader=val_loader_inpaint, optimizer=None, scheduler=None, meter=val_meter_inpaint, epoch=epoch, log_name='Validation_Inpainting') + + + if (epoch+1) % self.cfg.TRAIN.CHECKPOINT_PERIOD == 0: + logger.info(f"Saving checkpoint at epoch {epoch+1}") + cp.save_checkpoint(self.checkpoint_dir, + epoch+1, + best_miou, + self.cfg.NUM_GPUS, + self.cfg, + self.model, + optimizer, + scheduler + ) + + if miou > best_miou: + best_miou = miou + logger.info(f"New best checkpoint reached at epoch {epoch+1} with miou of {best_miou}\nSaving new best model.") + cp.save_checkpoint(self.checkpoint_dir, + epoch+1, + best_miou, + self.cfg.NUM_GPUS, + self.cfg, + self.model, + optimizer, + scheduler, + best=True + ) + + @staticmethod + def set_determinism( # from monai.utils https://docs.monai.io/en/stable/_modules/monai/utils/misc.html#set_determinism + seed: int = 0, + use_deterministic_algorithms: bool = True) -> None: + """ + Set random seed for modules to enable or disable deterministic training. + + Args: + seed: the random seed to use, default is np.iinfo(np.int32).max. + It is recommended to set a large seed, i.e. a number that has a good balance + of 0 and 1 bits. Avoid having many 0 bits in the seed. + if set to None, will disable deterministic training. + use_deterministic_algorithms: Set whether PyTorch operations must use "deterministic" algorithms. + """ + seed = int(seed) + torch.manual_seed(seed) + + global _seed + _seed = seed + random.seed(seed) + np.random.seed(seed) + + if torch.backends.flags_frozen(): + logger.warn("PyTorch global flag support of backends is disabled, enable it to set global `cudnn` flags.") + torch.backends.__allow_nonbracketed_mutation_flag = True + + if seed is not None: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: # restore the original flags + torch.backends.cudnn.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.benchmark = torch.backends.cudnn.benchmark + if use_deterministic_algorithms: + if hasattr(torch, "use_deterministic_algorithms"): # `use_deterministic_algorithms` is new in torch 1.8.0 + torch.use_deterministic_algorithms(use_deterministic_algorithms, warn_only=True) + elif hasattr(torch, "set_deterministic"): # `set_deterministic` is new in torch 1.7.0 + torch.set_deterministic(use_deterministic_algorithms) + else: + logger.warn("use_deterministic_algorithms=True, but PyTorch version is too old to set the mode.") \ No newline at end of file diff --git a/CCNet/utils/__init_.py b/CCNet/utils/__init_.py new file mode 100644 index 00000000..33b04bb8 --- /dev/null +++ b/CCNet/utils/__init_.py @@ -0,0 +1,2 @@ + +__all__ = ["checkpoint", "load_config", "logging", "lr_scheduler", "meters", "metrics", "misc", "parser_defaults"] diff --git a/CCNet/utils/checkpoint.py b/CCNet/utils/checkpoint.py new file mode 100644 index 00000000..79befba1 --- /dev/null +++ b/CCNet/utils/checkpoint.py @@ -0,0 +1,177 @@ + +# Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import os +import glob + +import requests +import torch + +from FastSurferCNN.utils import logging + +LOGGER = logging.getLogger(__name__) + +# Defaults +URL = "TODO" +VINN_AXI = os.path.join(os.path.dirname(os.path.dirname(__file__)), "checkpoints/CCNet_sagittal_v0.1.0.pkl") +VINN_COR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "checkpoints/CCNet_sagittal_v0.1.0.pkl") +VINN_SAG = os.path.join(os.path.dirname(os.path.dirname(__file__)), "checkpoints/CCNet_sagittal_v0.1.0.pkl") + + +def create_checkpoint_dir(expr_dir, expr_num): + """ + Create the checkpoint dir if not exists + :param expr_dir: + :param expr_num: + :return: checkpoint path + """ + checkpoint_dir = os.path.join(expr_dir, "checkpoints", str(expr_num)) + os.makedirs(checkpoint_dir, exist_ok=True) + return checkpoint_dir + +def get_checkpoint(ckpt_dir, epoch): + checkpoint_dir = os.path.join(ckpt_dir, 'Epoch_{:05d}_training_state.pkl'.format(epoch)) + return checkpoint_dir + +def get_checkpoint_path(log_dir, resume_expr_num): + """ + + :param log_dir: + :param resume_expr_num: + :return: + """ + if resume_expr_num == "Default": + return None + checkpoint_path = os.path.join(log_dir, "checkpoints", str(resume_expr_num)) + prior_model_paths = sorted(glob.glob(os.path.join(checkpoint_path, 'Epoch_*')), key=os.path.getmtime) + if len(prior_model_paths) == 0: + return None + return prior_model_paths + + +def load_from_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None, fine_tune=False): + """ + Loading the model from the given experiment number + :param checkpoint_path: + :param model: + :param optimizer: + :param scheduler: + :param fine_tune: + :return: + epoch number + """ + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + try: + model.load_state_dict(checkpoint['model_state']) + except RuntimeError: + model.module.load_state_dict(checkpoint['model_state']) + + if not fine_tune: + if optimizer: + optimizer.load_state_dict(checkpoint['optimizer_state']) + if scheduler and "scheduler_state" in checkpoint.keys(): + scheduler.load_state_dict(checkpoint["scheduler_state"]) + + return checkpoint['epoch']+1, checkpoint['best_metric'] + + +def save_checkpoint(checkpoint_dir, epoch, best_metric, num_gpus, cfg, model, optimizer, scheduler=None, best=False): + """ + Saving the state of training for resume or fine-tune + :param checkpoint_dir: + :param epoch: + :param best_metric: + :param num_gpus: + :param cfg: + :param model: + :param optimizer: + :param scheduler: + :return: + """ + save_name = f"Epoch_{epoch:05d}_training_state.pkl" + saving_model = model.module if num_gpus > 1 else model + checkpoint = { + "model_state": saving_model.state_dict(), + "optimizer_state": optimizer.state_dict(), + "epoch": epoch, + "best_metric": best_metric, + "config": cfg.dump() + } + + if scheduler is not None: + checkpoint['scheduler_state'] = scheduler.state_dict() + + torch.save(checkpoint, checkpoint_dir + "/" + save_name) + + if best: + remove_ckpt(checkpoint_dir + "/Best_training_state.pkl") + torch.save(checkpoint, checkpoint_dir + "/Best_training_state.pkl") + + +def remove_ckpt(ckpt): + try: + os.remove(ckpt) + except FileNotFoundError: + pass + + +def download_checkpoint(download_url, checkpoint_name, checkpoint_path): + """ + Download a checkpoint file. Raises an HTTPError if the file is not found + or the server is not reachable. + :param download_url: str: URL of checkpoint hosting site + :param checkpoint_name: str: name of checkpoint + :param checkpoint_path: str: path of the file in which the checkpoint will be saved + :return: + """ + try: + response = requests.get(os.path.join(download_url, checkpoint_name), verify=True) + # Raise error if file does not exist: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + LOGGER.info('Response code: {}'.format(e.response.status_code)) + response = requests.get(os.path.join(download_url, checkpoint_name), verify=False) + response.raise_for_status() + + with open(checkpoint_path, 'wb') as f: + f.write(response.content) + + +def check_and_download_ckpts(checkpoint_path, url): + """ + Check and download a checkpoint file, if it does not exist. + :param checkpoint_path: str: path of the file in which the checkpoint will be saved + :param download_url: str: URL of checkpoint hosting site + :return: + """ + # Download checkpoint file from url if it does not exist + if not os.path.exists(checkpoint_path): + ckptdir, ckptname = os.path.split(checkpoint_path) + if not os.path.exists(ckptdir) and ckptdir: + os.makedirs(ckptdir) + download_checkpoint(url, ckptname, checkpoint_path) + + +def get_checkpoints(axi, cor, sag, url=URL): + """ + Check and download checkpoint files if not exist + :param download_url: str: URL of checkpoint hosting site + :return: + """ + check_and_download_ckpts(axi, url) + check_and_download_ckpts(cor, url) + check_and_download_ckpts(sag, url) diff --git a/CCNet/utils/meters.py b/CCNet/utils/meters.py new file mode 100644 index 00000000..a1077d58 --- /dev/null +++ b/CCNet/utils/meters.py @@ -0,0 +1,499 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import os + +from ignite.metrics import Metric +import ignite +from matplotlib import axis +import numpy as np +import torch + +from FastSurferCNN.utils import logging +from CCNet.utils.misc import calculate_centers_of_comissures, plot_confusion_matrix + +from scipy.ndimage import _ni_support +from scipy.ndimage.morphology import ( + distance_transform_edt, + binary_erosion, + generate_binary_structure, +) + + +logger = logging.getLogger(__name__) + + +class DiceScore(Metric): + """ + Accumulating the component of the dice coefficient i.e. the union and intersection + Args: + op (callable): a callable to update accumulator. Method's signature is `(accumulator, output)`. + For example, to compute arithmetic mean value, `op = lambda a, x: a + x`. + output_transform (callable, optional): a callable that is used to transform the + :class:`~ignite.engine.Engine`'s `process_function`'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + device (str of torch.device, optional): device specification in case of distributed computation usage. + In most of the cases, it can be defined as "cuda:local_rank" or "cuda" + if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is + initialized and available, device is set to `cuda`. + """ + + def __init__( + self, + num_classes, + class_ids=None, + device=None, + one_hot=False, + output_transform=lambda y_pred, y: (y_pred.data.max(1)[1], y), + ): + self._device = device + self.out_transform = output_transform + self.class_ids = class_ids + if self.class_ids is None: + self.class_ids = np.arange(num_classes) + self.n_classes = num_classes + assert len(self.class_ids) == self.n_classes, ( + f"Number of class ids is not correct," + f" given {len(self.class_ids)} but {self.n_classes} is needed." + ) + self.one_hot = one_hot + self.reset() + + def reset(self): + self.union = torch.zeros(self.n_classes, self.n_classes, device=self._device) + self.intersection = torch.zeros(self.n_classes, self.n_classes, device=self._device) + + def _check_output_type(self, output): + if not (isinstance(output, tuple)): + raise TypeError( + "Output should a tuple consist of of torch.Tensors, but given {}".format( + type(output) + ) + ) + + def _update_union_intersection(self, batch_output: torch.Tensor, labels_batch: torch.Tensor): + """Update the union intersection. + + Parameters + ---------- + batch_output : torch.Tensor + batch output (prediction, labels) + labels_batch : torch.Tensor + + """ + for i in range(self.n_classes): + gt = (labels_batch == i).float() + pred = (batch_output == i).float() + self.intersection[i, i] += torch.sum(torch.mul(gt, pred)) + self.union[i, i] += torch.sum(gt) + torch.sum(pred) + + def update(self, output): + self._check_output_type(output) + + if self.out_transform is not None: + y_pred, y = self.out_transform(*output) + else: + y_pred, y = output + + if not isinstance(y, torch.Tensor): + y = torch.from_numpy(y) + + if not isinstance(y_pred, torch.Tensor): + y_pred = torch.from_numpy(y_pred) + + if self._device is not None: + y = y.to(self._device) + y_pred = y_pred.to(self._device) + + self._update_union_intersection(y_pred, y) + + def compute(self, per_class=False, class_idxs=None): + dice_cm_mat = self._dice_confusion_matrix(class_idxs) + dice_score_per_class = dice_cm_mat.diagonal() + dice_score = dice_score_per_class.mean() + + self.reset() + if per_class: + return dice_score_per_class #, dice_cm_mat + else: + return dice_score #, dice_cm_mat + + def _dice_confusion_matrix(self, class_idxs): + dice_intersection = self.intersection.cpu().numpy() + dice_union = self.union.cpu().numpy() + if class_idxs is not None: + dice_union = dice_union[class_idxs[:, None], class_idxs] + dice_intersection = dice_intersection[class_idxs[:, None], class_idxs] + if not (dice_union > 0).all(): + logger.info("Union of some classes are all zero") + dice_cnf_matrix = 2 * np.divide(dice_intersection, dice_union) + return dice_cnf_matrix + +class LocDistance(Metric): + """ + Calculates the Locational Distance between pred and reference. + """ + + def __init__(self, device=None, axis = 2): + self._device = device + self.axis = axis + self.reset() + + def reset(self): + self.loc_distance = 0.0 + + def update(self, pred, reference): + """ + Updates the Locational Distance metric with new predictions and reference. + """ + if not isinstance(pred, torch.Tensor): + pred = torch.from_numpy(pred) + + if not isinstance(reference, torch.Tensor): + reference = torch.from_numpy(reference) + + if self._device is not None: + reference = reference.to(self._device) + pred = pred.to(self._device) + + self.loc_pred = pred[:,-1,:,:].detach() + self.loc_ref = reference[:,reference.shape[1]//2:, :] + + def compute(self): + """ + Computes the localication Distance metric. + """ + self.reset() + + self.loc_distance = localisation_distance(self.loc_pred, self.loc_ref, axis=self.axis) + return self.loc_distance + + + +def dice_score(pred, gt): + """ + Calculates the Dice Similarity between pred and gt. + """ + from scipy.spatial.distance import dice + + return dice(pred.flat, gt.flat) + + +def volume_similarity(pred, gt): + """ + Calculate the Volume Similarity between pred and gt. + """ + pred_vol, gt_vol = np.sum(pred), np.save(gt) + return 1.0 - np.abs(pred_vol - gt_vol) / (pred_vol + gt_vol) + + +# https://github.com/amanbasu/3d-prostate-segmentation/blob/master/metric_eval.py +def hd(result, reference, voxelspacing=None, connectivity=1): + """ + Hausdorff Distance. + Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two + images. It is defined as the maximum surface distance between the objects. + Parameters + ---------- + result : array_like + Input data containing objects. Can be any type but will be converted + into binary: background where 0, object everywhere else. + reference : array_like + Input data containing objects. Can be any type but will be converted + into binary: background where 0, object everywhere else. + voxelspacing : float or sequence of floats, optional + The voxelspacing in a distance unit i.e. spacing of elements + along each dimension. If a sequence, must be of length equal to + the input rank; if a single number, this is used for all axes. If + not specified, a grid spacing of unity is implied. + connectivity : int + The neighbourhood/connectivity considered when determining the surface + of the binary objects. This value is passed to + `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. + Note that the connectivity influences the result in the case of the Hausdorff distance. + Returns + ------- + hd : float + The symmetric Hausdorff Distance between the object(s) in ```result``` and the + object(s) in ```reference```. The distance unit is the same as for the spacing of + elements along each dimension, which is usually given in mm. + See also + -------- + :func:`assd` + :func:`asd` + Notes + ----- + This is a real metric. The binary images can therefore be supplied in any order. + """ + hd1 = __surface_distances(result, reference, voxelspacing, connectivity) + hd2 = __surface_distances(reference, result, voxelspacing, connectivity) + hd = max(hd1.max(), hd2.max()) + hd95 = np.percentile(np.hstack((hd1, hd2)), 95) + hd50 = np.percentile(np.hstack((hd1, hd2)), 50) + return hd, hd50, hd95 + + +def hd95(result, reference, voxelspacing=None, connectivity=1): + """ + 95th percentile of the Hausdorff Distance. + Computes the 95th percentile of the (symmetric) Hausdorff Distance (HD) between the binary objects in two + images. Compared to the Hausdorff Distance, this metric is slightly more stable to small outliers and is + commonly used in Biomedical Segmentation challenges. + Parameters + ---------- + result : array_like + Input data containing objects. Can be any type but will be converted + into binary: background where 0, object everywhere else. + reference : array_like + Input data containing objects. Can be any type but will be converted + into binary: background where 0, object everywhere else. + voxelspacing : float or sequence of floats, optional + The voxelspacing in a distance unit i.e. spacing of elements + along each dimension. If a sequence, must be of length equal to + the input rank; if a single number, this is used for all axes. If + not specified, a grid spacing of unity is implied. + connectivity : int + The neighbourhood/connectivity considered when determining the surface + of the binary objects. This value is passed to + `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. + Note that the connectivity influences the result in the case of the Hausdorff distance. + Returns + ------- + hd : float + The symmetric Hausdorff Distance between the object(s) in ```result``` and the + object(s) in ```reference```. The distance unit is the same as for the spacing of + elements along each dimension, which is usually given in mm. + See also + -------- + :func:`hd` + Notes + ----- + This is a real metric. The binary images can therefore be supplied in any order. + """ + hd1 = __surface_distances(result, reference, voxelspacing, connectivity) + hd2 = __surface_distances(reference, result, voxelspacing, connectivity) + hd95 = np.percentile(np.hstack((hd1, hd2)), 95) + return hd95 + + +def __surface_distances(result, reference, voxelspacing=None, connectivity=1): + """ + The distances between the surface voxel of binary objects in result and their + nearest partner surface voxel of a binary object in reference. + """ + result = np.atleast_1d(result.astype(np.bool)) + reference = np.atleast_1d(reference.astype(np.bool)) + if voxelspacing is not None: + voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim) + voxelspacing = np.asarray(voxelspacing, dtype=np.float64) + if not voxelspacing.flags.contiguous: + voxelspacing = voxelspacing.copy() + + # binary structure + footprint = generate_binary_structure(result.ndim, connectivity) + + # test for emptiness + if 0 == np.count_nonzero(result): + raise RuntimeError( + "The first supplied array does not contain any binary object." + ) + if 0 == np.count_nonzero(reference): + raise RuntimeError( + "The second supplied array does not contain any binary object." + ) + + # extract only 1-pixel border line of objects + result_border = result ^ binary_erosion(result, structure=footprint, iterations=1) + reference_border = reference ^ binary_erosion( + reference, structure=footprint, iterations=1 + ) + + # compute average surface distance + # Note: scipys distance transform is calculated only inside the borders of the + # foreground objects, therefore the input has to be reversed + dt = distance_transform_edt(~reference_border, sampling=voxelspacing) + sds = dt[result_border] + + return sds + +def localisation_distance(result, reference, axis=2): + """ + Calculates the distances between the centers of the comissures of the predicted and reference volume + + Parameters + ---------- + result : np.ndarray + Predicted volume + reference : np.ndarray + Reference volume + axis : int + Axis along which the commisure centers are separated (default: 2) + """ + + assert result.shape == reference.shape, f'Localisation shapes do not match: {result.shape} vs {reference.shape}' + + batch_size = reference.shape[0] + + # calculate centers of comissures + commissures_pred = [] + commissures_ref = [] + + for i in range(batch_size): + try: + ac_pc_pred = calculate_centers_of_comissures(result[i, ...], axis) + ac_pc_ref = calculate_centers_of_comissures(reference[i, ...], axis) + + except ValueError: + logger.warn(f'No valid comissures found.') + continue + + commissures_pred.append(ac_pc_pred) + commissures_ref.append(ac_pc_ref) + + commissures_pred = np.array(commissures_pred, dtype=np.float32) + commissures_ref = np.array(commissures_ref, dtype=np.float32) + + # catch case where all distances are None + if len(commissures_pred) == 0: + logger.warn(f'No valid distances found. Skipping.') + return float('nan') + + # calculate distances + distances = np.array([np.linalg.norm(commissures_pred - commissures_ref, axis=1)], dtype=np.float32) + + # calculate average distance + avg_distance = distances[distances != None].flatten().mean() + + logger.debug(f'Average distance between predicted and reference points: {avg_distance}') + + if np.isnan(avg_distance): + logger.warn(f'Average distance is NaN') + return float('nan') + + return avg_distance + +class Meter: + def __init__(self, + cfg, + mode, + global_step, + total_iter=None, + total_epoch=None, + class_names=None, + #device=None, + writer=None, + ignite_metrics={}): + self._cfg = cfg + self.mode = mode.capitalize() + self.class_names = class_names + if self.class_names is None: + self.class_names = [f"{c+1}" for c in range(cfg.MODEL.NUM_CLASSES)] + + self.batch_losses = [] + self.writer = writer + self.global_iter = global_step + self.total_iter_num = total_iter + self.total_epochs = total_epoch + + self.ignite_metrics = ignite_metrics + + self.save_path = os.path.join(cfg.LOG_DIR, "pred", str(cfg.EXPR_NUM)) + + def reset(self): + self.batch_losses = [] + for n, metric in self.ignite_metrics.items(): + metric.reset() + + def add_ignitemetric(self, name, ignitemetric): + self.ignite_metrics[name] = ignitemetric + + + def update_stats(self, pred, labels, batch_loss): + # self.dice_score.update((pred, labels)) + self.batch_losses.append(batch_loss.item()) + + def write_summary(self, loss_total, lr=None, loss_ce=None, loss_dice=None, loss_seg = None, loc_loss=None, dist_loss=None): + self.global_iter += 1 + if self.writer is None: + raise ValueError("Writer is None. Cannot write summary.") + + # add standart metrics + self.writer.add_scalar(f"{self.mode}/total_loss", loss_total.item(), self.global_iter) + + + # add metrics only for training + if self.mode == 'Train': + if lr: + self.writer.add_scalar("Train/lr", lr[0], self.global_iter) + if loss_ce: + self.writer.add_scalar("Train/ce_loss", loss_ce.item(), self.global_iter) + if loss_dice: + self.writer.add_scalar("Train/dice_loss", loss_dice.item(), self.global_iter) + if loc_loss: + self.writer.add_scalar("Train/loc_loss", loc_loss, self.global_iter) + if dist_loss: + self.writer.add_scalar("Train/dist_loss", dist_loss, self.global_iter) + if loss_seg: + self.writer.add_scalar("Train/seg_loss", loss_seg, self.global_iter) + + + def log_iter(self, cur_iter, cur_epoch): + if (cur_iter+1) % self._cfg.TRAIN.LOG_INTERVAL == 0: + logger.info("{} Epoch [{}/{}] Iter [{}/{}] with loss {:.4f}".format(self.mode, + cur_epoch + 1, self.total_epochs, + cur_iter + 1, self.total_iter_num, + np.array(self.batch_losses).mean() + )) + + def log_epoch(self, cur_epoch, runtime=None): + if self.writer is None: + raise ValueError("Writer is None. Cannot log the epoch.") + + if runtime is not None: + self.writer.add_scalar(f"{self.mode}/runtime_per_epoch", runtime, cur_epoch) + + for metric_name in self.ignite_metrics.keys(): + if metric_name == 'confusion_matrix': + confusion_mat = self.ignite_metrics['confusion_matrix'].compute() + # file_save_name = os.path.join(self.save_path, 'Epoch_' + str(cur_epoch) + f'_{self.mode}_ConfusionMatrix.pdf') + # plot_confusion_matrix(cm=confusion_mat, file_save_name=file_save_name, classes=self.class_names) + # #self.writer.add_figure(f"{self.mode}/confusion_mat", fig, cur_epoch) + # #plt.close('all') + continue + + if metric_name == 'locational_distance': + loc_distance = self.ignite_metrics['locational_distance'].compute() + self.writer.add_scalar(f"{self.mode}/locational_distance", loc_distance, cur_epoch) + continue + + metric = self.ignite_metrics[metric_name] + + try: + log_output = metric.compute() + except ignite.exceptions.NotComputableError: + logger.warn(f"{metric_name} is not computable. Skipping.") + continue + + # todo: Change to different types of float or check length + if isinstance(log_output, float): + self.writer.add_scalar(f"{self.mode}/{metric_name}", log_output, cur_epoch) + else: + print(f"WARNING: {metric_name} has more than one value. Only the mean is logged.") + self.writer.add_scalar(f"{self.mode}/{metric_name}", log_output.mean(), cur_epoch) + + + diff --git a/CCNet/utils/misc.py b/CCNet/utils/misc.py new file mode 100644 index 00000000..802f3e39 --- /dev/null +++ b/CCNet/utils/misc.py @@ -0,0 +1,409 @@ + +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import os +from itertools import product +from typing import Union +import scipy + +import torch +import numpy as np +import pandas as pd +import matplotlib +import matplotlib.pyplot as plt +from mpl_toolkits.axes_grid1 import make_axes_locatable +from torchvision import utils + +from FastSurferCNN.utils import logging + +logger = logging.getLogger(__name__) + +matplotlib.use('agg') # standard QT backend does not allow to be run in thread +#matplotlib.use('TkAgg') # standard QT backend does not allow to be run in thread + + +assert(os.environ['FASTSURFER_HOME'] is not None), 'Please set the environment variable FASTSURFER_HOME to the root of the CCNet repository!' +DEFAULT_LUT_PATH = os.path.join(os.environ['FASTSURFER_HOME'] ,'CCNet/config/FastSurfer_ColorLUT.tsv') + + +def calculate_centers_of_comissures(volume: Union[torch.Tensor, np.ndarray], axis: int = 3) -> np.ndarray : + """ + Calculate the centers of the commisures from a probaility volume. + + Args: + volume np.ndarray or torch.Tensor: + volume to calculate the centers from in logits or probability space. + axis int: + axis (counting from 1) for which the minimum coordinate value is considered to be the pc and the max the ac. + If negative the max of coordinate value is considered to be the pc and the min the ac. + Default: 2 (Anterior-axis in RAS space) + + Returns: + tuple[int, int, int]: voxel space coordinates of the center of the anterior commisure + tuple[int, int, int]: voxel space coordinates of the center of the posterior commisure + + """ + + # convert to numpy + if isinstance(volume, torch.Tensor): + volume = volume.detach().cpu().numpy().astype(np.float32) + + if np.sum(volume) == 0: + raise ValueError('Volume is empty!') + + if np.min(volume) < 0: + # map logits space to probability space + volume = torch.nn.functional.sigmoid(torch.from_numpy(volume)).detach().cpu().numpy() + + # put gaußian filter over volume + volume = scipy.ndimage.gaussian_filter(volume, sigma=5, mode='constant', cval=0.0) + + # calculate two centers + + # Apply maximum filter + size = [8 for _ in range(len(volume.shape))] + local_max = scipy.ndimage.maximum_filter(volume, size=size) == volume + + # Filter out the 0s + local_max = np.logical_and(local_max, volume != 0) + + # Flatten the volume and get the indices of the maxima + maxima_indices = np.argwhere(local_max) + + # Get the values at these indices + maxima_values = volume[local_max] + + # Sort the indices by the values + sorted_indices = np.argsort(maxima_values)[::-1] + + + # Get the coordinates of the 5 strongest maxima + strongest_maxima = maxima_indices[sorted_indices[:100]] + strongest_maxima_values = maxima_values[sorted_indices[:100]] + + # Calculate the pairwise distance between all maxima + distances = scipy.spatial.distance.cdist(strongest_maxima, strongest_maxima) + + # Create a mask of the maxima to remove + mask = np.ones(len(strongest_maxima), dtype=bool) + + for i in range(len(strongest_maxima)): + if mask[i] == 0: + continue + + # Get the distances to the other maxima + dist_to_others = distances[i] + + # Find the maxima within a distance of 5 voxels + close_maxima = np.where((dist_to_others < 10) & (dist_to_others > 0))[0] + + # Mark the close maxima for removal + mask[close_maxima] = False + + # Apply the mask to remove the close maxima + strongest_maxima = strongest_maxima[mask][:2] + strongest_maxima_values = strongest_maxima_values[mask][:2] + + # differentiate between AC and PC + + # PC is the one in posterior direction + if axis < 0: + PC_idx = int(np.argmax(strongest_maxima[:, (-axis)-1])) + else: + PC_idx = int(np.argmin(strongest_maxima[:, axis-1])) + + # AC is the one in anterior direction + AC_idx = 1 - PC_idx + + if len(strongest_maxima) == 2: + return np.array([strongest_maxima[AC_idx], strongest_maxima[PC_idx]], dtype=np.int32) + else: + raise ValueError('Could not find two maxima!') + + + + +def get_lut(lookup_table: str = DEFAULT_LUT_PATH): + # Load lookup table + # try: # to map with original fastsurfer lookup table + lut_df = pd.read_csv(lookup_table, sep='\t')#.set_index('ID') + lut_df = lut_df[['R','G','B']] + + np_lut = np.full((lut_df.index.max() +1, 3), -1, dtype=int) + + for i in lut_df.index: + np_lut[i, :] = lut_df.iloc[i, :] + + # except: + # colors = None + # color_grid = color.label2rgb(label=grid.numpy(), image=img_grid, colors=colors, bg_label=0) + + return np_lut + + +def save_imgage_label_plot(img_grid: np.ndarray, label_grid: np.ndarray, net_output_grid: np.ndarray, plt_title: str, file_save_name: str, lut_name: str = DEFAULT_LUT_PATH): + np_lut = get_lut(lookup_table=lut_name) + + logger.debug(f'Saving image: {file_save_name}') + f = plt.figure(figsize=(20, 10)) + plt.subplot(211) + + logger.debug(f'Creating GT') + logger.debug(f'Unique labels: {np.unique(label_grid)}') + + color_grid = np_lut[label_grid] + if not (color_grid >= 0).all(): + logger.warn('Some labels are not in the lookup table!') + color_grid[color_grid < 0] = 0 + + plt.imshow(img_grid) + plt.imshow(color_grid, alpha=0.5) + + plt.title('Ground Truth') + + plt.subplot(212) + #color_grid = color.label2rgb(grid.numpy(), bg_label=0) + + logger.debug(f'Creating predictions') + logger.debug(f'Unique labels: {np.unique(net_output_grid)}') + + color_grid = np_lut[net_output_grid] + + if not (color_grid >= 0).all(): + logger.warn('Some labels are not in the lookup table!') + color_grid[color_grid < 0] = 0 + + plt.imshow(img_grid) + plt.imshow(color_grid, alpha=0.5) + plt.title('Prediction') + + plt.suptitle(plt_title) + plt.tight_layout() + + f.savefig(file_save_name, bbox_inches='tight') + #plt.close(f) + #plt.gcf().clear() + logger.debug(f'Successfully created {file_save_name}') + return 0 + + +def plot_predictions(images_batch, labels_batch, batch_output, plt_title, file_save_name, process_executor=None, lut = DEFAULT_LUT_PATH): #, lookup_table='./config/FastSurfer_ColorLUT.tsv'): + """ + Function to plot predictions from validation set. + :param images_batch: input images + :param labels_batch: ground truth + :param batch_output: output from the network + :param plt_title: title of the plot + :param file_save_name: path to save the plot + :param lookup_table: path to lookup table for colors + :return: + """ + logger.debug('Plotting predictions') + + c = images_batch.shape[1] # slice dimension + mid_slice = c // 2 + images_batch = torch.unsqueeze(images_batch[:, mid_slice, :, :], 1) + + logger.debug('Creating grid') + img_grid = utils.make_grid(images_batch, nrow=4).cpu().numpy().transpose((1, 2, 0)) + label_grid = utils.make_grid(labels_batch.unsqueeze(1), nrow=4)[0].cpu().numpy() + net_output_grid = utils.make_grid(batch_output.unsqueeze(1), nrow=4)[0].cpu().numpy() + logger.debug('Done creating grid') + + #with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # do IO in a seperate process + if process_executor is not None: + logger.debug(f'Saving image in process executor') + #_ = process_executor.submit(save_imgage_label_plot, img_grid, label_grid, net_output_grid, plt_title, file_save_name) + _ = process_executor.apply_async(save_imgage_label_plot, (img_grid, label_grid, net_output_grid, plt_title, file_save_name, lut)) + else: + save_imgage_label_plot(img_grid, label_grid, net_output_grid, plt_title, file_save_name, lut) + + # # get the result from the task + # exception = plot1_future.exception() + # # handle exceptional case + # if exception: + # print(exception) + # else: + # result = plot1_future.result() + # print(result) + + +def save_inpainting_plot(img_grid: np.ndarray, label_grid: np.ndarray, net_output_grid: np.ndarray, plt_title: str, file_save_name: str, dpi=100): + f = plt.figure(figsize=(20, 20)) + + plt.subplot(511) + plt.imshow(label_grid, vmin=0, vmax=1) + plt.title('Ground Truth') + + plt.subplot(512) + plt.imshow(net_output_grid, vmin=0, vmax=1) + plt.title('Prediction') + + plt.subplot(513) + plt.imshow(img_grid, vmin=0, vmax=1) + plt.title('Input') + + plt.subplot(514) + plt.imshow(np.abs(label_grid - net_output_grid), vmin=0, vmax=1) + plt.title('Ground truth - Prediction') + + plt.subplot(515) + plt.imshow(np.abs(img_grid - label_grid), vmin=0, vmax=1) + plt.title('Input - Ground truth') + + plt.suptitle(plt_title) + plt.tight_layout() + + f.savefig(file_save_name, bbox_inches='tight', dpi=dpi) + plt.close(f) + plt.gcf().clear() + + return 0 + +def plot_predictions_inpainting(images_batch, labels_batch, batch_output, plt_title, file_save_name, process_executor=None, dpi=100): + """ + Function to plot predictions from validation set. + :param images_batch: + :param labels_batch: + :param batch_output: + :param plt_title: + :param file_save_name: + :return: + """ + + c = images_batch.shape[1] # slice dimension + mid_slice = c // 2 + img_grid = utils.make_grid(images_batch[:, mid_slice, :, :].unsqueeze(1), nrow=4).cpu().numpy().transpose((1, 2, 0)) + label_grid = utils.make_grid(labels_batch.unsqueeze(1), nrow=4).cpu().numpy().transpose((1, 2, 0)) + output_grid = utils.make_grid(batch_output.unsqueeze(1), nrow=4).cpu().numpy().transpose((1, 2, 0)) + + #with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # do IO in a seperate process + + if process_executor is not None: + #_ = process_executor.submit(save_inpainting_plot, img_grid, label_grid, output_grid, plt_title, file_save_name) + _ = process_executor.apply_async(save_inpainting_plot, (img_grid, label_grid, output_grid, plt_title, file_save_name)) + else: + save_inpainting_plot(img_grid, label_grid, output_grid, plt_title, file_save_name, dpi=dpi) + + # # get the result from the task + # exception = plot2_future.exception() + # # handle exceptional case + # if exception: + # print(exception) + # else: + # result = plot2_future.result() + # print(result) + + + + +def plot_confusion_matrix(cm, + classes, + file_save_name=None, + title='Confusion matrix', + cmap=plt.cm.Blues, + ): + n_classes = len(classes) + + fig, ax = plt.subplots(figsize=(n_classes, n_classes)) + im_ = ax.imshow(cm, interpolation='nearest', cmap=cmap) + text_ = None + ax.set_title(title) + + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.08) + fig.colorbar(im_, cax=cax) + + tick_marks = np.arange(n_classes) + ax.set(xticks=tick_marks, + yticks=tick_marks, + xticklabels=classes, + yticklabels=classes, + ylabel="True label", + xlabel="Predicted label") + + cmap_min, cmap_max = im_.cmap(0), im_.cmap(256) + + text_ = np.empty_like(cm, dtype=object) + + values_format = '.2f' + thresh = (cm.max() + cm.min()) / 2.0 + + for i, j in product(range(n_classes), range(n_classes)): + color = cmap_max if cm[i, j] < thresh else cmap_min + text_[i, j] = ax.text(j, i, + format(cm[i, j], values_format), + ha="center", va="center", + color=color) + + ax.set_ylim((n_classes - 0.5, -0.5)) + plt.setp(ax.get_xticklabels(), rotation='horizontal') + + if file_save_name is not None: + plt.savefig(file_save_name) + + return fig + + +def find_latest_experiment(path): + list_of_experiments = os.listdir(path) + list_of_int_experiments = [] + for exp in list_of_experiments: + try: + int_exp = int(exp) + except ValueError: + continue + list_of_int_experiments.append(int_exp) + + if len(list_of_int_experiments) == 0: + return 0 + + return max(list_of_int_experiments) + + +def check_path(path): + os.makedirs(path, exist_ok=True) + return path + + +def update_num_steps(dataloader, cfg): + cfg.TRAIN.NUM_STEPS = len(dataloader) + + +def find_device(device: str = "auto", flag_name:str = "device") -> torch.device: + """Create a device object from the device string passed, including detection of devices if device is not defined + or "auto". + """ + logger = logging.get_logger(__name__ + ".auto_device") + # if specific device is requested, check and stop if not available: + if device.split(':')[0] == "cuda" and not torch.cuda.is_available(): + logger.info(f"cuda not available, try switching to cpu: --{flag_name} cpu") + raise ValueError(f"--device cuda not available, try --{flag_name} cpu !") + if device == "mps" and not torch.backends.mps.is_available(): + logger.info(f"mps not available, try switching to cpu: --{flag_name} cpu") + raise ValueError(f"--device mps not available, try --{flag_name} cpu !") + # If auto detect: + if device == "auto" or not device: + # 1st check cuda + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + # Define device and transfer model + logger.info(f"Using {flag_name}: {device}") + return torch.device(device) diff --git a/CCNet/utils/parser_defaults.py b/CCNet/utils/parser_defaults.py new file mode 100644 index 00000000..3f71afdb --- /dev/null +++ b/CCNet/utils/parser_defaults.py @@ -0,0 +1,375 @@ +# Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Contains the ALL_FLAGS dictionary, which can be used as follows to add default flags. + +>>> parser = argparse.ArgumentParser() +>>> ALL_FLAGS["allow_root"](parser, dest="root") +>>> args = parser.parse_args() +>>> allows_root = args.root # instead of the default dest args.allow_root + +Values can also be extracted by +>>> print(ALL_FLAGS["allow_root"](dict, dest="root") +>>> # {'flag': '--allow_root', 'flags': ('--allow_root',), 'action': 'store_true', 'dest': 'root', +>>> # 'help': 'Allow execution as root user.'} +""" + +import argparse +import multiprocessing +from os import path +from typing import Iterable, Mapping, Union, Literal, Dict, Protocol, TypeVar, Type + +from FastSurferCNN.utils.arg_types import ( + vox_size as __vox_size, + float_gt_zero_and_le_one as __conform_to_one_mm, + unquote_str, +) + +FASTSURFER_ROOT = path.dirname(path.dirname(path.dirname(__file__))) +PLANE_SHORT = {"checkpoint": "ckpt", "config": "cfg"} +PLANE_HELP = { + "checkpoint": "{} checkpoint to load", + "config": "Path to the {} config file", +} +VoxSize = Union[Literal["min"], float] + + +class CanAddArguments(Protocol): + """[MISSING].""" + + def add_argument(self, *args, **kwargs): + """[MISSING].""" + ... + + +def __arg(*default_flags, **default_kwargs): + """Create stub function, which sets default settings for argparse arguments. + + The positional and keyword arguments function as if they were directly passed to parser.add_arguments(). + + The result will be a stub function, which has as first argument a parser (or other object with an + add_argument method) to which the argument is added. The stub function also accepts positional and + keyword arguments, which overwrite the default arguments. Additionally, these specific values can be callables, + which will be called upon the default values (to alter the default value). + + This function is private for this module. + """ + def _stub(parser: Union[CanAddArguments, Type[Dict]], *flags, **kwargs): + # prefer the value passed to the "new" call + for kw, arg in kwargs.items(): + if callable(arg) and kw in default_kwargs.keys(): + kwargs[kw] = arg(default_kwargs[kw]) + # if no new value is provided to _stub (which is the callable in ALL_FLAGS), use the + # default value (stored in the callable/passed to the default below) + for kw, default in default_kwargs.items(): + if kw not in kwargs.keys(): + kwargs[kw] = default + + _flags = flags if len(flags) != 0 else default_flags + if hasattr(parser, "add_argument"): + return parser.add_argument(*_flags, **kwargs) + elif parser == dict: + return {"flag": _flags[0], "flags": _flags, **kwargs} + else: + raise ValueError( + f"Unclear parameter, should be dict or argparse.ArgumentParser, not {type(parser).__name__}." + ) + + return _stub + + +ALL_FLAGS = { + "t1": __arg( + "--t1", + type=str, + dest="orig_name", + default="mri/orig.mgz", + help="Name of T1 full head MRI. Absolute path if single image else " + "common image name. Default: mri/orig.mgz", + ), + "remove_suffix": __arg( + "--remove_suffix", + type=str, + dest="remove_suffix", + default="", + help="Optional: remove suffix from path definition of input file to yield correct subject name " + "(e.g. /ses-x/anat/ for BIDS or /mri/ for FreeSurfer input). Default: do not remove anything.", + ), + "aparc_aseg_segfile": __arg( + "--aparc_aseg_segfile", + type=str, + dest="pred_name", + default="", + help="aparc_aseg_segfile", + ), + "lesion_mask": __arg( + "--lesion_mask", + type=str, + dest="lesion_mask", + default=None, + help="aparc_aseg_segfile", + ), + "sid": __arg( + "--sid", + type=str, + dest="sid", + default=None, + help="Optional: directly set the subject id to use. Can be used for single subject input. For multi-subject " + "processing, use remove suffix if sid is not second to last element of input file passed to --t1", + ), + "asegdkt_segfile": __arg( + "--asegdkt_segfile", + "--aparc_aseg_segfile", + type=str, + dest="pred_name", + default="mri/aparc.DKTatlas+aseg.deep.mgz", + help="Name of intermediate DL-based segmentation file (similar to aparc+aseg). " + "When using FastSurfer, this segmentation is already conformed, since inference " + "is always based on a conformed image. Absolute path if single image else common " + "image name. Default: mri/aparc.DKTatlas+aseg.deep.mgz", + ), + "conformed_name": __arg( + "--conformed_name", + type=str, + dest="conf_name", + default="mri/orig.mgz", + help="Name under which the conformed input image will be saved, in the same directory " + "as the segmentation (the input image is always conformed first, if it is not " + "already conformed). The original input image is saved in the output directory " + "as $id/mri/orig/001.mgz. Default: mri/orig.mgz.", + ), + "norm_name": __arg( + "--norm_name", + type=str, + dest="norm_name", + default="mri/norm.mgz", + help="Name under which the bias field corrected image is stored. Default: mri/norm.mgz.", + ), + "brainmask_name": __arg( + "--brainmask_name", + type=str, + dest="brainmask_name", + default="mri/mask.mgz", + help="Name under which the brainmask image will be saved, in the same directory " + "as the segmentation. The brainmask is created from the aparc_aseg segmentation " + "(dilate 5, erode 4, largest component). Default: mri/mask.mgz.", + ), + "aseg_name": __arg( + "--aseg_name", + type=str, + dest="aseg_name", + default="mri/aseg.auto_noCCseg.mgz", + help="Name under which the reduced aseg segmentation will be saved, in the same directory " + "as the aparc-aseg segmentation (labels of full aparc segmentation are reduced to aseg). " + "Default: mri/aseg.auto_noCCseg.mgz.", + ), + "seg_log": __arg( + "--seg_log", + type=str, + dest="log_name", + default="", + help="Absolute path to file in which run logs will be saved. If not set, logs will " + "not be saved.", + ), + "device": __arg( + "--device", + default="auto", + help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu " + "(e.g. cuda:1), default: auto", + ), + "viewagg_device": __arg( + "--viewagg_device", + dest="viewagg_device", + type=str, + default="auto", + help="Define the device, where the view aggregation should be run. By default, the program checks " + "if you have enough memory to run the view aggregation on the gpu (cuda). The total memory is " + "considered for this decision. If this fails, or you actively overwrote the check with setting " + "> --viewagg_device cpu <, view agg is run on the cpu. Equivalently, if you define " + "> --viewagg_device cuda <, view agg will be run on the gpu (no memory check will be done).", + ), + "in_dir": __arg( + "--in_dir", + type=str, + default=None, + help="Directory in which input volume(s) are located. " + "Optional, if full path is defined for --t1.", + ), + "tag": __arg( + "--tag", + type=unquote_str, + dest="search_tag", + default="*", + help="Search tag to process only certain subjects. If a single image should be analyzed, " + "set the tag with its id. Default: processes all.", + ), + "csv_file": __arg( + "--csv_file", + type=str, + help="Csv-file with subjects to analyze (alternative to --tag)", + default=None, + ), + "batch_size": __arg( + "--batch_size", type=int, default=1, help="Batch size for inference. Default=1" + ), + "sd": __arg( + "--sd", + type=str, + default=None, + dest="out_dir", + help="Directory in which evaluation results should be written. " + "Will be created if it does not exist. Optional if full path is defined for --pred_name.", + ), + "qc_log": __arg( + "--qc_log", + type=str, + dest="qc_log", + default="", + help="Absolute path to file in which a list of subjects that failed QC check (when processing multiple " + "subjects) will be saved. If not set, the file will not be saved.", + ), + "vox_size": __arg( + "--vox_size", + type=__vox_size, + default="min", + dest="vox_size", + help="Choose the primary voxelsize to process, must be either a number between 0 and 1 (below 0.7 is " + "experimental) or 'min' (default). A number forces processing at that specific voxel size, 'min' " + "determines the voxel size from the image itself (conforming to the minimum voxel size, or 1 if " + "the minimum voxel size is above 0.95mm). ", + ), + "conform_to_1mm_threshold": __arg( + "--conform_to_1mm_threshold", + type=__conform_to_one_mm, + default=0.95, + dest="conform_to_1mm_threshold", + help="The voxelsize threshold, above which images will be conformed to 1mm isotropic, if the --vox_size " + "argument is also 'min' (the --vox_size default setting). Contrary to conform.py, the default behavior" + "of %(prog)s is to resample all images _above 0.95mm_ to 1mm.", + ), + "lut": __arg( + "--lut", + type=str, + help="Path and name of LUT to use.", + default=path.join( + FASTSURFER_ROOT, "CCNet/config/FastSurfer_ColorLUT.tsv" + ), + ), + "allow_root": __arg( + "--allow_root", + action="store_true", + dest="allow_root", + help="Allow execution as root user.", + ), + "threads": __arg( + "--threads", + dest="threads", + default=multiprocessing.cpu_count(), + type=int, + help=f"Number of threads to use (defaults to number of hardware threads: {multiprocessing.cpu_count()})", + ), + "async_io": __arg( + "--async_io", + dest="async_io", + action="store_true", + help="Allow asynchronous file operations (default: off). Note, this may impact the order of " + "messages in the log, but speed up the segmentation specifically for slow file systems.", + ), + "crop": __arg( + "--crop", + dest="crop", + action="store_true", + help="Crop the input image to 5mm around the midplane 128", #TODO: variable midplane + ), + +} + +T_AddArgs = TypeVar("T_AddArgs", bound=CanAddArguments) + + +def add_arguments(parser: T_AddArgs, flags: Iterable[str]) -> T_AddArgs: + """Add default flags to the parser from the flags list in order. + + Parameters + ---------- + parser : T_AddArgs + The parser to add flags to. + flags : Iterable[str] + the flags to add from 'device', 'viewagg_device'. + + Returns + ------- + T_AddArgs + The parser object + + """ + for flag in flags: + if flag.startswith("--"): + flag = flag[2:] + add_flag = ALL_FLAGS.get(flag, None) + if add_flag is not None: + add_flag(parser) + else: + raise ValueError( + f"The flag '{flag}' is not defined in CCNet.utils.parse.add_arguments()." + ) + return parser + + +def add_plane_flags( + parser: argparse.ArgumentParser, + type: Literal["checkpoint", "config"], + files: Mapping[str, str], +) -> argparse.ArgumentParser: + """Add plane arguments. + + Arguments will be added for each entry in files, where the key is the "plane" + and the values is the file name (relative for path relative to FASTSURFER_HOME. + + Parameters + ---------- + parser : argparse.ArgumentParser + The parser to add flags to. + type : Literal["checkpoint", "config"] + The type of files (for help text and prefix from "checkpoint" and "config". + "checkpoint" will lead to flags like "--ckpt_{plane}", "config" to "--cfg_{plane}" + files : Mapping[str, str] + A dictionary of plane to filename. Relative files are assumed to be relative to the FastSurfer root + directory. + + Returns + ------- + argparse.ArgumentParser + The parser object. + + """ + if type not in PLANE_SHORT: + raise ValueError("type must be either config or checkpoint.") + + for key, filepath in files.items(): + if not path.isabs(filepath): + filepath = path.join(FASTSURFER_ROOT, filepath) + # find the first vowel in the key + flag = key.strip().lower() + index = min(i for i in (flag.find(v) for v in "aeiou") if i >= 0) + flag = flag[: index + 2] + parser.add_argument( + f"--{PLANE_SHORT[type]}_{flag}", + type=str, + dest=f"{PLANE_SHORT[type]}_{flag}", + help=PLANE_HELP[type].format(key), + default=filepath, + ) + return parser diff --git a/FastSurferCNN/config/defaults.py b/FastSurferCNN/config/defaults.py index 4008fbef..d9cc5a02 100644 --- a/FastSurferCNN/config/defaults.py +++ b/FastSurferCNN/config/defaults.py @@ -16,6 +16,12 @@ # Loss function, combined = dice loss + cross entropy, combined2 = dice loss + boundary loss _C.MODEL.LOSS_FUNC = "combined" +_C.MODEL.WEIGHT_SEG = 1e-5 + +_C.MODEL.WEIGHT_LOC = 1.0 + +_C.MODEL.WEIGHT_DIST = 0.0 + # Filter dimensions for DenseNet (all layers same) _C.MODEL.NUM_FILTERS = 71 @@ -112,6 +118,9 @@ # Delta = change below which early stopping starts (previous - current < delta = stop) _C.TRAIN.EARLY_STOPPING_DELTA = 0.00001 +# Flag to enable debugging run (smaller dataset, less epochs, etc.) +_C.TRAIN.DEBUG = False + # ---------------------------------------------------------------------------- # # Testing options # ---------------------------------------------------------------------------- # @@ -131,6 +140,7 @@ # path to validation hdf5-dataset _C.DATA.PATH_HDF5_VAL = "" +_C.DATA.PATH_HDF5_VAL2 = "" # The plane to load ['axial', 'coronal', 'sagittal'] _C.DATA.PLANE = "coronal" @@ -143,11 +153,16 @@ _C.DATA.SIZES = [256, 311, 320] # the size that all inputs are padded to -_C.DATA.PADDED_SIZE = 320 +_C.DATA.PADDED_SIZE = [320, 320, 320] # Augmentations _C.DATA.AUG = ["Scaling", "Translation"] +# Augmentation probability +_C.DATA.AUG_LIKELYHOOD = [0.3, 0.3] + +_C.DATA.LUT = "" + # ---------------------------------------------------------------------------- # # DataLoader options (common for test and train) # ---------------------------------------------------------------------------- # @@ -159,6 +174,9 @@ # Load data to pinned host memory. _C.DATA_LOADER.PIN_MEMORY = True +# How many batches to prefetch (maximum) +_C.DATA_LOADER.PREFETCH_FACTOR = 2 + # ---------------------------------------------------------------------------- # # Optimizer options # ---------------------------------------------------------------------------- # @@ -222,6 +240,8 @@ # log directory for run _C.LOG_DIR = "./experiments" +_C.LOG_LEVEL = 20 + # experiment number _C.EXPR_NUM = "Default" @@ -229,8 +249,8 @@ # operator implementations in GPU operator libraries. _C.RNG_SEED = 1 -_C.SUMMARY_PATH = "FastSurferVINN/summary/FastSurferVINN_coronal" -_C.CONFIG_LOG_PATH = "FastSurferVINN/config/FastSurferVINN_coronal" +# Predict healthy tissue for inpainting / anomaly detection +_C.INPAINT_LOSS_WEIGHT = 0.5 # Weight of intensity image in loss function def get_cfg_defaults(): diff --git a/FastSurferCNN/data_loader/data_utils.py b/FastSurferCNN/data_loader/data_utils.py index b5e9abd0..6b857b6e 100644 --- a/FastSurferCNN/data_loader/data_utils.py +++ b/FastSurferCNN/data_loader/data_utils.py @@ -391,9 +391,8 @@ def get_thick_slices( def filter_blank_slices_thick( - img_vol: npt.NDArray, + volume_list: list, label_vol: npt.NDArray, - weight_vol: npt.NDArray, threshold: int = 50 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ @@ -401,12 +400,10 @@ def filter_blank_slices_thick( Parameters ---------- - img_vol : npt.NDArray - Orig image volume. + volume_list : list + List of image volumes (thick slices). label_vol : npt.NDArray Label images (ground truth). - weight_vol : npt.NDArray - Weight corresponding to labels. threshold : int Threshold for number of pixels needed to keep slice (below = dropped). (Default value = 50). @@ -423,11 +420,11 @@ def filter_blank_slices_thick( select_slices = np.sum(label_vol, axis=(0, 1)) > threshold # Retain only slices with more than threshold labels/pixels - img_vol = img_vol[:, :, select_slices, :] - label_vol = label_vol[:, :, select_slices] - weight_vol = weight_vol[:, :, select_slices] + return_list = [] + for v in volume_list: + return_list.append(v[:, :, select_slices]) - return img_vol, label_vol, weight_vol + return return_list # weight map generator @@ -1122,6 +1119,13 @@ def map_prediction_sagittal2full( _idx.extend([[20, 22, 27], r(29, 32), [33, 34], r(38, 43), [45]]) elif num_classes == 21: _idx = [[0], r(5, 15), r(1, 4), [15, 16, 4], r(17, 20), r(5, 21)] + elif num_classes == 1: + _idx = [[0, 1, 1]] + elif num_classes == 3: + _idx = [[0, 1, 2]] + elif num_classes == 2: + _idx = [[0, 1]] + if _idx: from itertools import chain idx_list = list(chain(*_idx)) @@ -1192,3 +1196,96 @@ def get_largest_cc(segmentation: npt.NDArray) -> np.ndarray: largest_cc = labels == np.argmax(bincount) return largest_cc + + +def map_incrementing(mapped_aseg, lut): + """Map labels to an incrementing space""" + for idx, label in enumerate(lut['ID']): + mapped_aseg[mapped_aseg == label] = idx + return mapped_aseg + +def subset_volume_plane(volume: npt.NDArray, plane: Union[int, float] = 128, thickness: int = 5, axis: int = 0) -> np.ndarray: + """Return a subset of the volume around the plane + + Parameters: + ---------- + volume : npt.NDArray + volume to subset + plane : Union[int, float] + plane of the subset. Defaults to 127.5. (optional) + thickness : int + Total thickness of the subset. Defaults to 7. If odd, the plane should be inbetween two indices. (optional) + axis : int + axis of the plane. Defaults to 0. (optional) + + Returns: + np.ndarray: subset of the volume. + + Raises: + ValueError: If the thickness and plane combination is unknown. + ValueError: If the axis is unknown. + + Notes: + If plane is an integer, it is included in the subset, the subset will need to have an odd thickness. Otherwise, the subset will have an even thickness. + """ + assert 2*plane == int(2*plane), "Plane must be an integer or between two indices" + assert axis in [0, 1, 2], "axis must be 0, 1 or 2" + + # Calculate the lower and upper bounds of the subset + lower_bound, upper_bound = 0, 0 + if (thickness % 2 == 1) and plane == int(plane): + lower_bound = int(plane - thickness//2) + upper_bound = int(plane + thickness//2 + 1) + elif (thickness % 2 == 0) and plane != int(plane): + lower_bound = int(np.ceil(plane - thickness/2)) + upper_bound = int(np.floor(plane + thickness/2)) + else: + raise ValueError("Unknown thickness and plane combination.") + + + if axis == 0: + return volume[lower_bound: upper_bound, :, :] + elif axis == 1: + return volume[:, lower_bound: upper_bound, :] + elif axis == 2: + return volume[:, :, lower_bound: upper_bound] + else: + raise ValueError("Unknown axis") + +def pad_to_size(volume: np.ndarray, size: tuple, plane: Union[int, float] = 128, axis: int = 0): + """ Pad a volume to a given size. Functions as a reverse to subset_volume_plane + + Parameters: + ---------- + volume : np.ndarray + volume to pad + size : int + size to pad to + plane : Union[int, float] + plane of the subset. Defaults to 128. (optional) + axis (int, optional): + axis of the plane. Defaults to 0. + + """ + thickness = volume.shape[axis] + + pad_left = int(plane - thickness//2) + pad_right = int(size[axis] - pad_left - thickness) + + assert pad_left >= 0, "Volume is larger than the size" + assert pad_right >= 0, "Volume is larger than the size" + assert pad_left + pad_right + thickness == size[axis] , "Padding is not correct" + + if (size[2]-volume.shape[2]) != 0: + print("Padding volume") + + if (size[1]-volume.shape[1]) != 0: + print("Padding volume") + + # Calculate the lower and upper bounds of the subset + if axis == 0: + return np.pad(volume, ((pad_left, pad_right), (0, size[1]-volume.shape[1]), (0, size[2]-volume.shape[2])), mode='constant', constant_values=0) + elif axis == 1: + return np.pad(volume, ((0, size[0]-volume.shape[0]), (pad_left, pad_right), (0, size[2]-volume.shape[2])), mode='constant', constant_values=0) + elif axis == 2: + return np.pad(volume, ((0, size[0]-volume.shape[0]), (0, size[1]-volume.shape[1]), (pad_left, pad_right)), mode='constant', constant_values=0) diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 767c990d..2d2f4fd6 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -343,6 +343,19 @@ class SubjectDirectoryConfig: "the order of messages in the log, but speed up the segmentation " "specifically for slow file systems.", ), + "crop": __arg( + "--crop", + dest="crop", + action="store_true", + help="Crop the input image to 5mm around the midplane 128", #TODO: variable midplane + ), + "lesion_mask": __arg( + "--lesion_mask", + type=str, + dest="lesion_mask", + default=None, + help="aparc_aseg_segfile", + ), } T_AddArgs = TypeVar("T_AddArgs", bound=CanAddArguments) diff --git a/checkpoints/CCNet_axial_v0.1.0.pkl b/checkpoints/CCNet_axial_v0.1.0.pkl new file mode 100755 index 00000000..52dc1ccf Binary files /dev/null and b/checkpoints/CCNet_axial_v0.1.0.pkl differ diff --git a/checkpoints/CCNet_coronal_v0.1.0.pkl b/checkpoints/CCNet_coronal_v0.1.0.pkl new file mode 100755 index 00000000..56d5661f Binary files /dev/null and b/checkpoints/CCNet_coronal_v0.1.0.pkl differ diff --git a/checkpoints/CCNet_sagittal_v0.1.0.pkl b/checkpoints/CCNet_sagittal_v0.1.0.pkl new file mode 100755 index 00000000..63311ef3 Binary files /dev/null and b/checkpoints/CCNet_sagittal_v0.1.0.pkl differ diff --git a/data/axial_train.hdf5 b/data/axial_train.hdf5 new file mode 120000 index 00000000..a08cf99f --- /dev/null +++ b/data/axial_train.hdf5 @@ -0,0 +1 @@ +/groups/ag-reuter/projects/corpus_callosum_fornix/FStumor/data/cropped-nocom-loc/axial_train.hdf5 \ No newline at end of file diff --git a/data/axial_val.hdf5 b/data/axial_val.hdf5 new file mode 120000 index 00000000..860e6d5a --- /dev/null +++ b/data/axial_val.hdf5 @@ -0,0 +1 @@ +/groups/ag-reuter/projects/corpus_callosum_fornix/FStumor/data/cropped-nocom-loc/axial_val.hdf5 \ No newline at end of file diff --git a/data/coronal_train.hdf5 b/data/coronal_train.hdf5 new file mode 120000 index 00000000..9de438a6 --- /dev/null +++ b/data/coronal_train.hdf5 @@ -0,0 +1 @@ +/groups/ag-reuter/projects/corpus_callosum_fornix/FStumor/data/cropped-nocom-loc/coronal_train.hdf5 \ No newline at end of file diff --git a/data/coronal_val.hdf5 b/data/coronal_val.hdf5 new file mode 120000 index 00000000..163d7d35 --- /dev/null +++ b/data/coronal_val.hdf5 @@ -0,0 +1 @@ +/groups/ag-reuter/projects/corpus_callosum_fornix/FStumor/data/cropped-nocom-loc/coronal_val.hdf5 \ No newline at end of file diff --git a/data/sagittal_train.hdf5 b/data/sagittal_train.hdf5 new file mode 120000 index 00000000..69e91a79 --- /dev/null +++ b/data/sagittal_train.hdf5 @@ -0,0 +1 @@ +/groups/ag-reuter/projects/corpus_callosum_fornix/FStumor/data/cropped-nocom-loc/sagittal_train.hdf5 \ No newline at end of file diff --git a/data/sagittal_val.hdf5 b/data/sagittal_val.hdf5 new file mode 120000 index 00000000..9b99e478 --- /dev/null +++ b/data/sagittal_val.hdf5 @@ -0,0 +1 @@ +/groups/ag-reuter/projects/corpus_callosum_fornix/FStumor/data/cropped-nocom-loc/sagittal_val.hdf5 \ No newline at end of file diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index 1e1c51cf..a09dd32c 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -33,12 +33,14 @@ fi fastsurfercnndir="$FASTSURFER_HOME/FastSurferCNN" cerebnetdir="$FASTSURFER_HOME/CerebNet" reconsurfdir="$FASTSURFER_HOME/recon_surf" +ccnetdir="$FASTSURFER_HOME/CCNet" # Regular flags defaults subject="" t1="" merged_segfile="" cereb_segfile="" +cc_segfile="" asegdkt_segfile="" asegdkt_segfile_default="\$SUBJECTS_DIR/\$SID/mri/aparc.DKTatlas+aseg.deep.mgz" asegdkt_statsfile="" @@ -58,6 +60,7 @@ surf_flags=() vox_size="min" run_asegdkt_module="1" run_cereb_module="1" +run_cc_module="1" threads="1" # python3.10 -s excludes user-directory package inclusion python="python3.10 -s" @@ -322,6 +325,11 @@ case $key in shift # past argument shift # past value ;; + --cc_segfile) + cc_segfile="$2" + shift # past argument + shift # past value + ;; --cereb_statsfile) cereb_statsfile="$2" shift # past argument @@ -393,6 +401,10 @@ case $key in run_cereb_module="0" shift # past argument ;; + --no_cc) + run_cc_module="0" + shift # past argument + ;; --tal_reg) run_talairach_registration="true" shift @@ -595,6 +607,11 @@ if [[ -z "$cereb_segfile" ]] cereb_segfile="${sd}/${subject}/mri/cerebellum.CerebNet.nii.gz" fi +if [[ -z "$cc_segfile" ]] + then + cc_segfile="${sd}/${subject}/mri/cc_aseg.deep.mgz" +fi + if [[ -z "$cereb_statsfile" ]] then cereb_statsfile="${sd}/${subject}/stats/cerebellum.CerebNet.stats" @@ -835,6 +852,35 @@ if [[ "$run_seg_pipeline" == "1" ]] exit 1 fi fi + + if [[ "$run_cc_module" == "1" ]] + then + + cmd=($python "$ccnetdir/run_prediction.py" --t1 "$t1" + --sid "$subject" + --aparc_aseg_segfile "$cc_segfile" + --seg_log "$seg_log" + --batch_size "$batch_size" --viewagg_device "$viewagg" --device "$device" + "${allow_root[@]}" --crop + --lut "$ccnetdir/config/CC_ColorLUT.tsv" + --ckpt_sag "$FASTSURFER_HOME/checkpoints/CCNet_sagittal_v0.1.0.pkl" + --ckpt_cor "$FASTSURFER_HOME/checkpoints/CCNet_coronal_v0.1.0.pkl" + --ckpt_ax "$FASTSURFER_HOME/checkpoints/CCNet_axial_v0.1.0.pkl" + --cfg_sag "$ccnetdir/config/CCNet_sagittal.yaml" + --cfg_cor "$ccnetdir/config/CCNet_coronal.yaml" + --cfg_ax "$ccnetdir/config/CCNet_axial.yaml") + # specify the subject dir $sd, if asegdkt_segfile explicitly starts with it + if [[ "$sd" == "${cc_segfile:0:${#sd}}" ]] ; then cmd=("${cmd[@]}" --sd "$sd"); fi + echo "${cmd[@]}" 2>&1 | tee -a "$seg_log" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" -ne 0 ]] + then + echo "ERROR: CC Segmentation failed" 2>&1 | tee -a "$seg_log" + exit 1 + fi + + + fi # if [[ ! -f "$merged_segfile" ]] # then