losses.py
107 lines (84 loc) · 2.15 KB

import numpy as np
import numba

import prox
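
# `prox` is the repository-local module providing the proximal-operator
# kernels (`prox_G_hinge_numba`, `prox_G_abs_numba`); the `_numba` suffix
# suggests they are numba-compiled.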


class Loss:
    """Common interface: loss value, gradient, convex conjugate, prox."""
    L = None      # Lipschitz-type constant (set to 1 by every subclass)
    name = None

    @staticmethod
    def get(yhat, y):
        """Loss value at predictions `yhat` for targets `y`."""
        raise NotImplementedError

    @staticmethod
    def grad(yhat, y):
        """(Sub)gradient of the loss with respect to `yhat`."""
        raise NotImplementedError

    @staticmethod
    def conj(u, y):
        """Convex conjugate of the loss at the dual point `u`."""
        raise NotImplementedError

    @staticmethod
    def prox(u, y, gamma):
        """Proximal operator of the loss with step size `gamma`."""
        raise NotImplementedError
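
# Convention note (inferred from the code below, not stated in the source):
# conj(u, y) appears to return the conjugate of the *averaged* loss
# G(v) = (1/n) * sum_i g_i(v_i), i.e. G*(u) = (1/n) * sum_i g_i*(n * u_i),
# which explains the `n * u` scaling in HingeLoss.conj and AbsoluteLoss.conj.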


class HingeLoss(Loss):
    L = 1
    name = 'Hinge'

    @staticmethod
    # @numba.jit(nopython=True)
    def get(yhat, y):
        return np.maximum(1 - yhat * y, 0)

    @staticmethod
    # @numba.jit(nopython=True)
    def grad(yhat, y):
        # Scalar case: subgradient is 0 on the flat part, -y on the slope.
        if not isinstance(y, (tuple, np.ndarray)):
            return 0 if 1 <= yhat * y else -y
        d = -y
        d[1 <= yhat * y] = 0
        return d

    @staticmethod
    # @numba.jit(nopython=True)
    def conj(u, y):
        n = u.shape[0]
        res = n * u / y
        # The hinge conjugate is finite only on the interval [-1, 0].
        res[(res < -1) | (res > 0)] = np.inf
        res = np.mean(res)
        if res == np.inf:
            raise ValueError("infinite value in conjugate loss")
        return res

    @staticmethod
    # @numba.jit(nopython=True)
    def prox(u, y, gamma):
        return prox.prox_G_hinge_numba(u, y, gamma)
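
# For reference, a sketch (an assumption, not read from prox.py): with labels
# y in {-1, +1}, the per-sample prox of g(v) = max(0, 1 - y * v) is
#     prox_{gamma * g}(u) = u + gamma * y  if y * u <= 1 - gamma
#                         = y              if 1 - gamma < y * u < 1
#                         = u              if y * u >= 1
# prox_G_hinge_numba presumably applies this elementwise, with any
# averaged-loss scaling handled inside prox.py.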


class AbsoluteLoss(Loss):
    L = 1
    name = 'Abs'

    @staticmethod
    # @numba.jit(nopython=True)
    def get(yhat, y):
        return np.abs(yhat - y)

    @staticmethod
    # @numba.jit(nopython=True)
    def grad(yhat, y):
        return np.sign(yhat - y)

    @staticmethod
    # @numba.jit(nopython=True)
    def conj(u, y):
        n = u.shape[0]
        res = n * u * y
        # The conjugate of |.| is finite only where |n * u| <= 1.
        res[np.abs(n * u) > 1] = np.inf
        res = np.mean(res)
        if res == np.inf:
            raise ValueError("infinite value in conjugate loss")
        return res

    @staticmethod
    # @numba.jit(nopython=True)
    def prox(u, y, gamma):
        return prox.prox_G_abs_numba(u, y, gamma)
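
# Likewise a sketch (an assumption, not read from prox.py): the per-sample
# prox of g(v) = |v - y| is soft-thresholding centered at y,
#     prox_{gamma * g}(u) = y + sign(u - y) * max(|u - y| - gamma, 0).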


class MSE(Loss):
    L = 1
    name = 'MSE'

    @staticmethod
    # @numba.jit(nopython=True)
    def get(yhat, y):
        return 0.5 * ((yhat - y) ** 2)

    @staticmethod
    # @numba.jit(nopython=True)
    def grad(yhat, y):
        return yhat - y
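
# MSE does not override conj/prox; it falls back to the base-class stubs,
# which raise NotImplementedError.

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative addition, not part of the original file):
# evaluates each loss and its gradient on random data, and checks the hinge
# conjugate at a dual point inside its feasible interval [-1, 0].
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 8
    y = np.sign(rng.standard_normal(n))   # +/-1 labels for the hinge loss
    yhat = rng.standard_normal(n)
    for loss in (HingeLoss, AbsoluteLoss, MSE):
        print(loss.name,
              float(np.mean(loss.get(yhat, y))),
              float(np.mean(loss.grad(yhat, y))))
    u = -0.5 * y / n   # n * u / y = -0.5 for every sample: feasible
    print('Hinge conj:', HingeLoss.conj(u, y))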