Add new scripts to analyze BIDS compliant datasets #40

Open · wants to merge 20 commits into base: main

Changes from 12 commits
179 changes: 179 additions & 0 deletions scripts/analyze_data.py
@@ -0,0 +1,179 @@
import os
import argparse
import json
import glob
from progress.bar import Bar  # third-party package: https://pypi.org/project/progress/
import csv

from utils import get_img_path_from_mask_path, get_mask_path_from_img_path, edit_metric_dict, save_graphs, change_mask_suffix, get_deriv_sub_from_img_path, str_to_float_list, str_to_str_list, mergedict

def run_analysis(args):
Member:

This function is slightly "heavy". Would it be possible to split it into several smaller functions (10-20 lines per function), each with a self-explaining docstring?
Also, some nested for loops and if-else conditions are hard to follow; adding comments would make them easier to follow.
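A minimal sketch of one possible split (the helper name load_config_data is hypothetical, not part of this PR; it covers only the config-loading branch):

def load_config_data(args, derivatives_folder='derivatives'):
    """Return (data_form, config_data): a dict mapping dataset/split names to lists of image paths."""
    if args.config:
        with open(args.config, "r") as file:
            config_data = json.load(file)
        return 'split', config_data
    if args.paths_to_bids:
        config_data = {}
        for path_bids in args.paths_to_bids:
            files = glob.glob(path_bids + "/**/*.nii.gz", recursive=True)  # all NIfTI files
            config_data[os.path.basename(os.path.normpath(path_bids))] = [f for f in files if derivatives_folder not in f]
        return 'dataset', config_data
    if args.paths_to_csv:
        return 'dataset', {}
    raise ValueError("Need to specify either args.paths_to_bids, args.config or args.paths_to_csv!")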

"""
Run analysis on a config file
"""

short_suffix_disc = '_label'
short_suffix_seg = '_seg'
derivatives_folder = 'derivatives'
output_folder = 'results'

if not os.path.exists(output_folder):
os.makedirs(output_folder)

if args.config:
data_form = 'split'
# Read json file and create a dictionary
with open(args.config, "r") as file:
config_data = json.load(file)

if config_data['TYPE'] == 'LABEL':
isImage = False
elif config_data['TYPE'] == 'IMAGE':
isImage = True
else:
raise ValueError(f'config with unknown TYPE {config_data['TYPE']}')

# Remove keys that are not lists of paths
keys = list(config_data.keys())
for key in keys:
if key not in ['TRAINING', 'VALIDATION', 'TESTING']:
del config_data[key]

elif args.paths_to_bids:
data_form = 'dataset'
config_data = {}
for path_bids in args.paths_to_bids:
files = glob.glob(path_bids + "/**/" + "*.nii.gz", recursive=True) # Get all niftii files
config_data[os.path.basename(os.path.normpath(path_bids))] = [f for f in files if derivatives_folder not in f] # Remove masks from derivatives folder
isImage = True

elif args.paths_to_csv:
data_form = 'dataset'
config_data = {}
else:
raise ValueError(f"Need to specify either args.paths_to_bids, args.config or args.paths_to_csv !")

    # Initialize metrics dictionary
    metrics_dict = dict()

    if args.paths_to_csv:
        for path_csv in args.paths_to_csv:
            dataset_name = os.path.basename(path_csv).split('_')[-1].split('.csv')[0]  # e.g. 'computed_metrics_<dataset>.csv' -> '<dataset>'
            metrics_dict[dataset_name] = {}
            with open(path_csv) as csv_file:
                reader = csv.reader(csv_file)
                # Each row is a (key, value) pair; keys shaped '<name>_<value>' rebuild the
                # nested count dicts that --create-csv writes below
                for k, v in dict(reader).items():
                    metric = k.split('_')
                    if len(metric) == 2:
                        metric_name, metric_value = metric
                        if metric_name not in metrics_dict[dataset_name].keys():
                            metrics_dict[dataset_name][metric_name] = {metric_value: int(v)}
                        else:
                            metrics_dict[dataset_name][metric_name][metric_value] = int(v)
                    else:
                        if k.startswith('mismatch'):
                            metrics_dict[dataset_name][k] = int(v)
                        else:
                            metrics_dict[dataset_name][k] = str_to_str_list(v)

    # Initialize data fingerprint
    fprint_dict = dict()

    if config_data.keys():
        missing_data = []
        # Extract information from the data
        for key in config_data.keys():
            metrics_dict[key] = dict()
            fprint_dict[key] = dict()

            # Init progress bar
            bar = Bar(f'Analyze data {key} ', max=len(config_data[key]))

            for i, path in enumerate(config_data[key]):
                if isImage:
                    img_path = path  # str
                    deriv_sub_folders = get_deriv_sub_from_img_path(img_path=img_path, derivatives_folder=derivatives_folder)  # list of str
                    seg_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_seg, deriv_sub_folders=deriv_sub_folders, counterexample=['lesion', 'GM', 'WM'])  # list of str
                    discs_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_disc, deriv_sub_folders=deriv_sub_folders, counterexample=['compression', 'SC_mask', 'seg', 'lesion', 'GM', 'WM'])  # list of str
                else:
                    img_path = get_img_path_from_mask_path(path, derivatives_folder=derivatives_folder)
                    deriv_sub_folders = [os.path.dirname(path)]
                    # Extract field-of-view information using the disc labels:
                    # derive the matching segmentation/disc paths from the given mask path
                    if short_suffix_disc in path:
                        discs_paths = [path]
                        seg_paths = [change_mask_suffix(path, short_suffix=short_suffix_seg)]
                    elif short_suffix_seg in path:
                        seg_paths = [path]
                        discs_paths = [change_mask_suffix(path, short_suffix=short_suffix_disc)]
                    else:
                        seg_paths = [change_mask_suffix(path, short_suffix=short_suffix_seg)]
                        discs_paths = [change_mask_suffix(path, short_suffix=short_suffix_disc)]

                # Extract data
                if os.path.exists(img_path):
                    metrics_dict[key], fprint_dict[key] = edit_metric_dict(metrics_dict[key], fprint_dict[key], img_path, seg_paths, discs_paths, deriv_sub_folders)
                else:
                    missing_data.append(img_path)

                # Plot progress
                bar.suffix = f'{i+1}/{len(config_data[key])}'
                bar.next()
            bar.finish()

            # Store CSV with computed metrics
            if args.create_csv:
                # Based on https://stackoverflow.com/questions/8685809/writing-a-dictionary-to-a-csv-file-with-one-line-for-every-key-value
                out_csv_folder = os.path.join(output_folder, 'files')
                if not os.path.exists(out_csv_folder):
                    os.makedirs(out_csv_folder)
                csv_path_sum = os.path.join(out_csv_folder, f'computed_metrics_{key}.csv')
                with open(csv_path_sum, 'w') as csv_file:
                    writer = csv.writer(csv_file)
                    for metric_name, metric in sorted(metrics_dict[key].items()):
                        if isinstance(metric, dict):
                            for metric_value, count in sorted(metric.items()):
                                k = f'{metric_name}_{metric_value}'
                                writer.writerow([k, count])
                        else:
                            writer.writerow([metric_name, metric])

                # Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark
                csv_path_fprint = os.path.join(out_csv_folder, f'fprint_{key}.csv')
                sub_list = [sub for sub in fprint_dict[key].keys() if sub.startswith('sub')]
                fields = ['subject'] + [k for k in fprint_dict[key][sub_list[0]].keys()]
                with open(csv_path_fprint, 'w') as f:
                    w = csv.DictWriter(f, fields)
                    w.writeheader()
                    for k, v in fprint_dict[key].items():
                        w.writerow(mergedict({'subject': k}, v))


        if missing_data:
            print("missing files:\n" + '\n'.join(missing_data))

    # Plot data information
    save_graphs(output_folder=output_folder, metrics_dict=metrics_dict, data_form=data_form)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Analyze config file')

    ## Parameters
    parser.add_argument('--paths-to-bids', default='', nargs='+',
                        help='Paths to BIDS compliant datasets (You can add multiple paths using spaces)')
    parser.add_argument('--config', default='',
                        help='Path to JSON config file that contains all the training splits')
    parser.add_argument('--paths-to-csv', default='', nargs='+',
                        help='Paths to CSV files with already computed metrics (You can add multiple paths using spaces)')
    parser.add_argument('--split', default='ALL', choices=('TRAINING', 'VALIDATION', 'TESTING', 'ALL'),
                        help='Split of the data that will be analyzed (default="ALL")')
    parser.add_argument('--create-csv', action=argparse.BooleanOptionalAction, default=True,
                        help='Store computed metrics in a CSV file in results/files; disable with --no-create-csv (default=True)')
Member, commenting on lines +168 to +180:

I believe it would be more "pythonic" to include these lines under a separate function, for example get_parser (example here and here).

    # Start analysis
    run_analysis(parser.parse_args())
Member:

I would add a main() function (example here) and call run_analysis from it.
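Combining the two suggestions, the entry point could look like this (a sketch, not part of the PR as submitted):

def main():
    """CLI entry point: parse arguments and run the analysis."""
    run_analysis(get_parser().parse_args())


if __name__ == '__main__':
    main()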

86 changes: 86 additions & 0 deletions scripts/init_data_config.py
@@ -0,0 +1,86 @@
import os
import argparse
import random
import json
import itertools

from utils import CONTRAST, get_img_path_from_label_path, fetch_contrast


CONTRAST_LOOKUP = {tuple(sorted(value)): key for key, value in CONTRAST.items()}


# Determine specified contrasts
def init_data_config(args):
"""
Create a JSON configuration file from a TXT file where images paths are specified
"""
if (args.split_validation + args.split_test) >= 1:
raise ValueError("The sum of the ratio between testing and validation cannot exceed 1")

# Get input paths, could be label files or image files,
# and make sure they all exist.
file_paths = [os.path.abspath(path.replace('\n', '')) for path in open(args.txt)]
if args.type == 'LABEL':
label_paths = file_paths
img_paths = [get_img_path_from_label_path(lp) for lp in label_paths]
file_paths = label_paths + img_paths
elif args.type == 'IMAGE':
img_paths = file_paths
else:
raise ValueError(f"invalid args.type: {args.type}")
missing_paths = [
path for path in file_paths
if not os.path.isfile(path)
]
if missing_paths:
raise ValueError("missing files:\n" + '\n'.join(missing_paths))

    # Look up the right code for the set of contrasts present
    contrasts = CONTRAST_LOOKUP[tuple(sorted(set(map(fetch_contrast, img_paths))))]

    config = {
        'TYPE': args.type,
        'CONTRASTS': contrasts,
    }

    # Split into training, validation, and testing sets
    split_ratio = (1 - (args.split_validation + args.split_test), args.split_validation, args.split_test)  # TRAIN, VALIDATION, and TEST
    config_paths = label_paths if args.type == 'LABEL' else img_paths
    random.shuffle(config_paths)
    splits = [0] + [
        int(len(config_paths) * ratio)
        for ratio in itertools.accumulate(split_ratio)
    ]
    for key, (begin, end) in zip(
        ['TRAINING', 'VALIDATION', 'TESTING'],
        pairwise(splits),
    ):
        config[key] = config_paths[begin:end]

    # Save the config
    config_path = args.txt.replace('.txt', '') + '.json'
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=4)

def pairwise(iterable):
    # pairwise('ABCDEFG') --> AB BC CD DE EF FG
    # based on https://docs.python.org/3.11/library/itertools.html
    # (itertools.pairwise is built in from Python 3.10; this backport keeps older versions working)
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)
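# Worked example of the split logic above (a sketch, assuming 10 paths and the default 0.1/0.1 ratios):
#   split_ratio = (0.8, 0.1, 0.1)
#   itertools.accumulate(split_ratio) -> 0.8, 0.9, 1.0
#   splits = [0, 8, 9, 10]  after int(10 * ratio)
#   pairwise(splits) -> (0, 8), (8, 9), (9, 10)
#   so TRAINING = config_paths[0:8], VALIDATION = config_paths[8:9], TESTING = config_paths[9:10]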


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Create config JSON from a TXT file which contains a list of paths')

    ## Parameters
    parser.add_argument('--txt', required=True,
                        help='Path to TXT file that contains only image or label paths. (Required)')
    parser.add_argument('--type', required=True, choices=('LABEL', 'IMAGE'),
                        help='Type of paths specified. Choices "LABEL" or "IMAGE". (Required)')
    parser.add_argument('--split-validation', type=float, default=0.1,
                        help='Split ratio for validation. Default=0.1')
    parser.add_argument('--split-test', type=float, default=0.1,
                        help='Split ratio for testing. Default=0.1')

    init_data_config(parser.parse_args())
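Taken together, a typical workflow with the two scripts could be (file names illustrative; the config JSON is written next to the TXT file):

python scripts/init_data_config.py --txt paths.txt --type IMAGE
python scripts/analyze_data.py --config paths.json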