Skip to content

Commit

Permalink
update init_data_config with contrast method
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanMolinier committed Mar 11, 2024
1 parent 2190249 commit 273fc47
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
30 changes: 21 additions & 9 deletions scripts/init_data_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Based on https://github.com/spinalcordtoolbox/disc-labeling-hourglass
"""
Script copied from https://github.com/spinalcordtoolbox/disc-labeling-hourglass
"""

import os
import argparse
Expand All @@ -7,10 +9,7 @@
import itertools
import numpy as np

from utils import CONTRAST, get_img_path_from_mask_path, fetch_contrast


CONTRAST_LOOKUP = {tuple(sorted(value)): key for key, value in CONTRAST.items()}
from utils import get_img_path_from_mask_path, get_cont_path_from_other_cont, fetch_contrast, fetch_subject_and_session


# Determine specified contrasts
Expand All @@ -26,10 +25,17 @@ def init_data_config(args):
file_paths = [os.path.abspath(path.replace('\n', '')) for path in open(args.txt)]
if args.type == 'LABEL':
label_paths = file_paths
img_paths = [get_img_path_from_mask_path(lp) for lp in label_paths]
img_paths = [get_img_path_from_label_path(lp) for lp in label_paths]
file_paths = label_paths + img_paths
elif args.type == 'IMAGE':
img_paths = file_paths
elif args.type == 'CONTRAST':
if not args.cont: # If the target contrast is not specified
raise ValueError(f'When using the type CONTRAST, please specify the target contrast using the flag "--cont"')
img_paths = file_paths
new_contrast = args.cont
label_paths = [get_cont_path_from_other_cont(ip) for ip in img_paths]
file_paths = label_paths + img_paths
else:
raise ValueError(f"invalid args.type: {args.type}")
missing_paths = [
Expand All @@ -50,14 +56,18 @@ def init_data_config(args):
raise ValueError('Please store all the BIDS datasets inside the same parent folder !')

# Look up the right code for the set of contrasts present
contrasts = CONTRAST_LOOKUP[tuple(sorted(set(map(fetch_contrast, img_paths))))]
contrasts = "_".join(tuple(sorted(set(map(fetch_contrast, img_paths)))))

config = {
'TYPE': args.type,
'CONTRASTS': contrasts,
'DATASETS_PATH': dataset_parent_path
}

# Add target contrast when the type CONTRAST is used
if args.type == 'CONTRAST':
config['TARGET_CONTRAST'] = args.cont

# Split into training, validation, and testing sets
split_ratio = (1 - (args.split_validation + args.split_test), args.split_validation, args.split_test) # TRAIN, VALIDATION, and TEST
config_paths = label_paths if args.type == 'LABEL' else img_paths
Expand Down Expand Up @@ -91,8 +101,10 @@ def pairwise(iterable):
## Parameters
parser.add_argument('--txt', required=True,
help='Path to TXT file that contains only image or label paths. (Required)')
parser.add_argument('--type', choices=('LABEL', 'IMAGE'),
help='Type of paths specified. Choices "LABEL" or "IMAGE". (Required)')
parser.add_argument('--type', choices=('LABEL', 'IMAGE', 'CONTRAST'),
help='Type of paths specified. Choices are "LABEL", "IMAGE" or "CONTRAST". (Required)')
parser.add_argument('--cont', type=str, default='',
help='If the type CONTRAST is selected, this variable specifies the wanted contrast for target.')
parser.add_argument('--split-validation', type=float, default=0.1,
help='Split ratio for validation. Default=0.1')
parser.add_argument('--split-test', type=float, default=0.1,
Expand Down
27 changes: 27 additions & 0 deletions scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,33 @@ def get_mask_path_from_img_path(img_path, deriv_sub_folders, short_suffix='_seg'
mask_path.append(paths[0])
return mask_path


def get_cont_path_from_other_cont(str_path, cont):
"""
:param str_path: absolute path to the input nifti img. Example: /<path_to_BIDS_data>/sub-amuALT/anat/sub-amuALT_T1w.nii.gz
:param cont: contrast of the target output image stored in the same data folder. Example: T2w
:return: path to the output target image. Example: /<path_to_BIDS_data>/sub-amuALT/anat/sub-amuALT_T2w.nii.gz
"""
# Load path
path = Path(str_path)

# Extract file extension
ext = ''.join(path.suffixes)

# Remove input contrast from name
path_list = path.name.split('_')
suffixes_pos = [1 if len(part.split('-')) == 1 else 0 for part in path_list]
contrast_idx = suffixes_pos.index(1) # Find suffix

# New image name
img_name = '_'.join(path_list[:contrast_idx]+[cont]) + ext

# Recreate img path
img_path = os.path.join(str(path.parent), img_name)

return img_path

def get_deriv_sub_from_img_path(img_path, derivatives_folder='derivatives'):
"""
This function returns the derivatives path of the subject from an image path or an empty string if the path does not exists. Images need to be stored in a BIDS compliant dataset.
Expand Down

0 comments on commit 273fc47

Please sign in to comment.