-
Notifications
You must be signed in to change notification settings - Fork 22
/
run.py
174 lines (135 loc) · 5.06 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from multiprocessing import Process
import cfar
import math
import matplotlib.pyplot as plt
import numpy as np
from skimage.draw import ellipse
from skimage.measure import label, regionprops
from skimage.transform import rotate
import glob
import sys
import shutil
import os
# Root of the SAR-Ship-Dataset checkout. This script reads images from
# JPEGImages/ and reference boxes from ground-truth/, and (re)creates
# detection-results/ and boxes_drawn/ underneath it.
# The default is a machine-specific path; set the SAR_DATA_DIR environment
# variable to point the script at a different checkout.
data_dir = os.environ.get('SAR_DATA_DIR', '/media/nasir/Drive1/code/SAR/python_cfar/SAR-Ship-Dataset')
def union_area(a, b):
    """Return the area of the union of two axis-aligned boxes.

    Boxes are 4-element sequences; index 0/1 is treated as the top-left
    corner and index 2/3 as the extent along each axis (i.e. ``x, y, w, h``).
    NOTE(review): the box files are written by ``cfar.ca_cfar``, which is not
    visible here -- confirm they store width/height rather than x2/y2.

    The union is ``area(a) + area(b) - area(a & b)``. This is deliberately
    NOT the area of the smallest rectangle enclosing both boxes: that value is
    larger than the true union for most box pairs and would under-estimate
    IoU = intersection / union.
    """
    # Overlap extent along each axis, clamped to 0 for disjoint boxes.
    overlap_w = max(0, min(a[0] + a[2], b[0] + b[2]) - max(a[0], b[0]))
    overlap_h = max(0, min(a[1] + a[3], b[1] + b[3]) - max(a[1], b[1]))
    return a[2] * a[3] + b[2] * b[3] - overlap_w * overlap_h
def intersection_area(a, b):
    """Return the overlapping area of two ``(x, y, w, h)`` boxes.

    Index 0/1 of each box is the top-left corner and index 2/3 its extent.
    If the boxes are separated along either axis the overlap is negative
    there, and 0 is returned.
    """
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right = min(a[0] + a[2], b[0] + b[2])
    bottom = min(a[1] + a[3], b[1] + b[3])
    width, height = right - left, bottom - top
    return 0 if (width < 0 or height < 0) else width * height
def str2int(a):
    """Apply ``int`` to every element of the iterable *a*; return the results as a list."""
    return list(map(int, a))
def extract_boxes(fname):
    """Read one box per line from the text file *fname*.

    Each non-blank line is split on whitespace and its last four fields are
    parsed as integers, so any leading columns (e.g. a label or a confidence
    score) are ignored.

    Args:
        fname: path of the box file to read.

    Returns:
        A list with one ``[v0, v1, v2, v3]`` list of ints per non-blank line.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if one of the last four fields is not an integer literal.
    """
    boxes = []
    with open(fname) as f:
        for line in f:
            # split() (no argument) tolerates repeated spaces and tabs, which
            # split(' ') would turn into empty tokens and crash int('').
            fields = line.split()
            # Skip blank lines instead of failing on them.
            if not fields:
                continue
            boxes.append([int(v) for v in fields[-4:]])
    return boxes
def get_precision_recall(threshold):
    """Compute dataset-level precision and recall of the written detections.

    For every ``*.txt`` file in ``{data_dir}/detection-results``, the file
    with the same name in ``{data_dir}/ground-truth`` is loaded.
    - A ground-truth box is a true positive when at least one predicted box
      has ``intersection_area / union_area`` strictly greater than
      *threshold*; otherwise it is a false negative.
    - Every predicted box that matched no ground-truth box is a false
      positive.

    Progress and the final counts are written to stdout.

    Args:
        threshold: overlap ratio a prediction must exceed to match a GT box.

    Returns:
        ``(precision, recall)``. Each is ``0.0`` when its denominator is zero
        (no predictions at all / no ground-truth boxes at all) instead of
        raising ``ZeroDivisionError``.
    """
    paths = glob.glob(f"{data_dir}/detection-results/*.txt")
    false_negative = 0
    true_positive = 0
    false_positive = 0
    for index, path in enumerate(paths):
        pred_bboxes = extract_boxes(path)
        # Build the GT path from the basename so a 'detection-results'
        # substring elsewhere in data_dir is not rewritten too.
        gt_path = os.path.join(data_dir, 'ground-truth', os.path.basename(path))
        gt_bboxes = extract_boxes(gt_path)
        # Indices of predictions that matched at least one GT box.
        matched_preds = set()
        for gt_box in gt_bboxes:
            matched = False
            for index_p, pred_box in enumerate(pred_bboxes):
                union = union_area(gt_box, pred_box)
                # Two zero-area boxes would otherwise divide by zero.
                iou = intersection_area(gt_box, pred_box) / union if union > 0 else 0.0
                if iou > threshold:
                    matched_preds.add(index_p)
                    matched = True
            # NOTE(review): matching is not one-to-one. A single prediction
            # overlapping several GT boxes counts a true positive for each.
            if matched:
                true_positive += 1
            else:
                false_negative += 1
        false_positive += len(pred_bboxes) - len(matched_preds)
        sys.stdout.write(f"\r {index + 1} / {len(paths)}")
        sys.stdout.flush()
    print(f"\n\nfalsePositives: {false_positive} , truePositives: {true_positive} , falseNegatives: {false_negative}")
    gt_total = true_positive + false_negative
    pred_matched_total = true_positive + false_positive
    recall = true_positive / gt_total if gt_total else 0.0
    precision = true_positive / pred_matched_total if pred_matched_total else 0.0
    return precision, recall
def copy(path, source, dest='land'):
    """Recursively copy *path* to the path obtained by replacing *source* with *dest*.

    Runs ``cp -r`` with an argument list rather than a shell command string.
    This means paths containing spaces or shell metacharacters are passed
    verbatim and cannot inject commands. ``--`` stops ``cp`` from reading a
    path that starts with ``-`` as an option.

    A non-zero exit status from ``cp`` is not turned into an exception.
    Every occurrence of *source* in *path* is replaced, not just the
    directory component.
    """
    import subprocess  # local import: only this helper spawns a process

    subprocess.run(["cp", "-r", "--", path, path.replace(source, dest)], check=False)
def predict(paths, data_dir, source, dest, i):
    """Run ``cfar.ca_cfar`` on every image in *paths* with fixed parameters.

    For each image, three related paths are derived by substituting the
    *source* folder name in the image path:
    - the rendered output image (*source* -> *dest*);
    - the detection box file (*source* -> ``detection-results``, ``.jpg`` ->
      ``.txt``);
    - the ground-truth box file (``detection-results`` -> ``ground-truth`` in
      the detection box path).

    Progress is written to stdout, prefixed with the worker id *i*.
    *data_dir* is accepted for call compatibility but is not used here.
    """
    detector = cfar.ca_cfar
    # Window sizes and threshold factor forwarded to cfar.ca_cfar unchanged.
    background_size, guard_size, cut_size = 100, 40, 30
    threshold_factor = 1.55
    total = len(paths)
    for done, image_path in enumerate(paths, start=1):
        rendered_path = image_path.replace(source, dest)
        det_path = image_path.replace(source, 'detection-results').replace('.jpg', '.txt')
        truth_path = det_path.replace('detection-results', 'ground-truth')
        detector(image_path, rendered_path, det_path, truth_path,
                 background_size, guard_size, cut_size, threshold_factor)
        sys.stdout.write(f'\r {i}: {done} / {total}')
        sys.stdout.flush()
    sys.stdout.write(f"\r {i}: Done\n")
    sys.stdout.flush()
if __name__ == "__main__":
    # Image folder to process ("subset" is a smaller alternative folder).
    # source = "subset"
    source = "JPEGImages"
    dest = "boxes_drawn"
    num_of_process = 20
    paths = glob.glob(f"{data_dir}/{source}/*.jpg")

    # Start from empty output folders so results from a previous run are not
    # mixed into this one.
    for folder in ('detection-results', dest):
        folder_path = f'{data_dir}/{folder}'
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        os.mkdir(folder_path)
        print(f"Directory {folder} Created ")

    # Split the paths into at most num_of_process contiguous chunks.
    # Ceiling division makes the chunks cover every path, so no remainder can
    # be dropped. The previous `end + 1 < len(paths)` check lost a single
    # leftover image.
    chunk_size = max(1, math.ceil(len(paths) / num_of_process))
    processes = []
    for worker_id, start in enumerate(range(0, len(paths), chunk_size)):
        p = Process(target=predict,
                    args=(paths[start:start + chunk_size], data_dir, source, dest, worker_id))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()

    # Single-process alternative:
    # predict(paths, data_dir, source, dest, 0)

    # Evaluate the detections written above (uncomment to run):
    # thresholds = [0.4]
    # precisions = []
    # recalls = []
    # for threshold in thresholds:
    #     precision, recall = get_precision_recall(threshold)
    #     precisions.append(precision)
    #     recalls.append(recall)
    #     print(f"\nthreshold: {threshold} recall: {round(recall * 100, 2)}% precision: {round(precision*100, 2)}% \n")