forked from haofeixu/aanet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
metric.py
47 lines (35 loc) · 1.02 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numbers

import numpy as np
import torch
# Small positive constant, presumably intended to guard divisions by zero.
# NOTE(review): none of the metric functions in this module reference it; it
# may be imported by other modules — confirm before removing or changing it.
EPSILON = 1e-8
def epe_metric(d_est, d_gt, mask, use_np=False):
    """Return the end-point error (mean absolute error) over ``mask``.

    Args:
        d_est: Estimated values, a NumPy array or a torch tensor.
        d_gt: Ground-truth values of the same kind as ``d_est``.
        mask: Index applied to both inputs (typically a boolean validity
            mask; assumed to match their shape — confirm with callers).
        use_np: Evaluate with NumPy when True, with torch otherwise.

    Returns:
        ``mean(|d_est - d_gt|)`` over the selected entries. This is a NumPy
        scalar when ``use_np`` is True, otherwise a 0-dim tensor. An empty
        selection yields NaN.
    """
    # NumPy and torch expose ``abs``/``mean`` under the same names, so pick
    # the library once instead of branching on ``use_np``.
    lib = np if use_np else torch
    residual = d_est[mask] - d_gt[mask]
    return lib.mean(lib.abs(residual))
def d1_metric(d_est, d_gt, mask, use_np=False, *, abs_thres=3, rel_thres=0.05):
    """Return the D1 outlier rate over the entries selected by ``mask``.

    An entry is an outlier when its absolute error exceeds ``abs_thres`` AND
    its error relative to the ground truth exceeds ``rel_thres``. The default
    thresholds (3 px, 5 %) reproduce the KITTI 2015 D1 outlier rule.

    Args:
        d_est: Estimated values, a NumPy array or a torch tensor.
        d_gt: Ground-truth values of the same kind as ``d_est``.
        mask: Index applied to both inputs (typically a boolean validity
            mask; assumed to match their shape — confirm with callers).
        use_np: Evaluate with NumPy when True, with torch otherwise.
        abs_thres: Absolute-error threshold. Defaults to 3.
        rel_thres: Relative-error threshold, as a fraction of ``d_gt``.
            Defaults to 0.05.

    Returns:
        Fraction of selected entries that are outliers. This is a NumPy
        scalar when ``use_np`` is True, otherwise a 0-dim float32 tensor.
        An empty selection yields NaN.
    """
    d_est, d_gt = d_est[mask], d_gt[mask]
    e = np.abs(d_gt - d_est) if use_np else torch.abs(d_gt - d_est)
    # The relative term divides by the ground truth. A zero in the selected
    # d_gt gives inf/nan (NumPy also warns), so callers are presumably
    # expected to mask such entries out.
    err_mask = (e > abs_thres) & (e / d_gt > rel_thres)
    if use_np:
        return np.mean(err_mask.astype('float'))
    return torch.mean(err_mask.float())
def thres_metric(d_est, d_gt, mask, thres, use_np=False):
    """Return the fraction of selected entries whose absolute error exceeds ``thres``.

    Args:
        d_est: Estimated values, a NumPy array or a torch tensor.
        d_gt: Ground-truth values of the same kind as ``d_est``.
        mask: Index applied to both inputs (typically a boolean validity
            mask; assumed to match their shape — confirm with callers).
        thres: Error threshold; any real number, including NumPy scalars.
        use_np: Evaluate with NumPy when True, with torch otherwise.

    Returns:
        ``mean(|d_gt - d_est| > thres)`` over the selected entries. This is a
        NumPy scalar when ``use_np`` is True, otherwise a 0-dim float32
        tensor. An empty selection yields NaN.

    Raises:
        TypeError: If ``thres`` is not a real number.
    """
    # Validate with an explicit raise, because ``assert`` is stripped under
    # ``python -O``. ``numbers.Real`` also admits NumPy scalars such as
    # np.float32 and np.int64, which an (int, float) check rejects.
    if not isinstance(thres, numbers.Real):
        raise TypeError(
            f"thres must be a real number, got {type(thres).__name__}")
    d_est, d_gt = d_est[mask], d_gt[mask]
    if use_np:
        err_mask = np.abs(d_gt - d_est) > thres
        return np.mean(err_mask.astype('float'))
    err_mask = torch.abs(d_gt - d_est) > thres
    return torch.mean(err_mask.float())