-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_transformation.py
114 lines (85 loc) · 5.72 KB
/
evaluate_transformation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys
import argparse
import os
import json
import csv
import numpy as np
from utils.filemanager import get_points_paths
from utils.logger import logger, pprint
from utils.landmarks import get_landmarks_from_txt, write_landmarks_to_list
from utils.metrics import compute_TRE
if __name__ == "__main__":
    # Extract the transformix output points of an experiment into
    # 'outputpoints_transformed.txt' files and, optionally (--generate_report),
    # compute the TRE against the ground-truth exhale landmarks and write the
    # per-sample and overall results as CSV files.

    # Number of landmarks every transformed points file must contain
    # (matches the '_300_' in the gt landmark file search key below).
    EXPECTED_NUM_LANDMARKS = 300

    # optional arguments from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', type=str, default='elastix_01', help='experiment name')
    parser.add_argument('--reg_params_key', type=str, default='Parameter.affine+Parameter.bsplines', help='registration parameters key generated by create_script.py')
    parser.add_argument('--output_path', type=str, default='output', help='root dir for output scripts')
    parser.add_argument("--generate_report", action='store_true', help='if set, evaluation report .csv files (per-sample and overall TRE) will be generated. If not, only the transformed keypoints txt file will be generated for each test sample.')
    parser.add_argument('--dataset_path', type=str, default='dataset/train', help='root dir for nifti data to get the gt exhale landmarks')
    args = parser.parse_args()

    # 'points' is the folder where transformix saved the transformed points
    args.exp_points_output = os.path.join(args.output_path, args.experiment_name, args.reg_params_key, 'points')

    transformed_points = get_points_paths(args.exp_points_output, "outputpoints", num_occurrences=2)
    if len(transformed_points) == 0:
        logger.error(f"No transformed points found in {args.exp_points_output} directory.")
        sys.exit(1)

    if args.generate_report:
        gt_points = get_points_paths(args.dataset_path, "_300_eBH_xyz_r1", num_occurrences=1)
        if len(gt_points) == 0:
            logger.error(f"No gt points found in {args.dataset_path} directory.")
            sys.exit(1)
        # Transformed and gt files are paired purely by listing order, and zip()
        # silently drops the surplus when the counts differ - make that visible.
        if len(gt_points) != len(transformed_points):
            logger.warning(
                f"Found {len(transformed_points)} transformed points files but {len(gt_points)} gt points files; "
                f"only the first {min(len(transformed_points), len(gt_points))} pairs (in listing order) will be evaluated."
            )

        # description.json sits in the parent of dataset_path and is keyed by the
        # split folder name (the last element of dataset_path, e.g. 'train').
        # normpath strips trailing separators so both lookups are robust.
        normalized_dataset_path = os.path.normpath(args.dataset_path)
        dataset_root = os.path.dirname(normalized_dataset_path)
        dataset_split = os.path.basename(normalized_dataset_path)
        with open(os.path.join(dataset_root, 'description.json'), 'r', encoding='utf-8') as json_file:
            split_description = json.load(json_file)[dataset_split]

        # per-sample TRE results, written to CSV after the loop
        tre_results = []
    else:
        # placeholders so zip() still yields one pair per transformed file
        gt_points = [None] * len(transformed_points)

    subject_names = [os.path.basename(os.path.dirname(path)) for path in transformed_points]
    logger.info(f"Found {len(transformed_points)} transformed points files for subjects ({subject_names})")

    # extract the transformed points from each transformix file and save them in a separate file
    for transformed_points_file, gt_point in zip(transformed_points, gt_points):
        print(f"Processing {transformed_points_file}...")
        transformed_landmarks = get_landmarks_from_txt(transformed_points_file, search_key='OutputIndexFixed')
        # explicit check instead of assert: asserts are stripped under `python -O`
        if len(transformed_landmarks) != EXPECTED_NUM_LANDMARKS:
            logger.error(f"Transformed points file {transformed_points_file} has {len(transformed_landmarks)} points instead of {EXPECTED_NUM_LANDMARKS}.")
            sys.exit(1)

        # written next to the transformix file it was extracted from
        output_landmarks_path = os.path.join(os.path.dirname(transformed_points_file), 'outputpoints_transformed.txt')
        write_landmarks_to_list(transformed_landmarks, output_landmarks_path)

        if args.generate_report:
            sample_name = os.path.basename(gt_point).split('_')[0]  # copd1, copd2, ...
            file_information = split_description[sample_name]
            print(file_information)
            TRE_mean, TRE_std = compute_TRE(output_landmarks_path, gt_point, tuple(file_information['voxel_dim']))
            print("TRE (After Registration):- ", f"(Mean TRE: {TRE_mean})", f"(STD TRE: {TRE_std}). \n")
            tre_results.append({'sample_name': sample_name, 'TRE_mean': TRE_mean, 'TRE_std': TRE_std})

    # the report is only produced when the ground truth exhale files are available
    if args.generate_report:
        # per-sample TRE results
        sample_csv_path = os.path.join(args.exp_points_output, 'TRE_sample_results.csv')
        with open(sample_csv_path, 'w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['sample_name', 'TRE_mean', 'TRE_std'])
            writer.writeheader()
            writer.writerows(tre_results)

        # mean over samples of the per-sample TRE mean and of the per-sample TRE std
        overall_csv_path = os.path.join(args.exp_points_output, 'TRE_overall_results.csv')
        with open(overall_csv_path, 'w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['Overall mean (TRE_mean)', 'Overall mean (TRE_std)'])
            writer.writeheader()
            writer.writerow({
                'Overall mean (TRE_mean)': np.mean([result['TRE_mean'] for result in tre_results]),
                'Overall mean (TRE_std)': np.mean([result['TRE_std'] for result in tre_results]),
            })