-
Notifications
You must be signed in to change notification settings - Fork 0
/
sage_oco.py
90 lines (70 loc) · 2.35 KB
/
sage_oco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
from sampling import madowSampling
from sacred import Experiment
from config import initialise
from easydict import EasyDict as edict
# Module-level sacred experiment: a fresh Experiment is handed to the
# project's `initialise` hook (imported from config), and whatever that hook
# returns becomes `ex`.
ex = initialise(Experiment())
def _kset_marginals(R, k, eta):
    """Return per-item inclusion marginals for a k-set from rewards R.

    Computes p_i = min(1, K * exp(eta * R_i)), with the scalar K chosen so
    that sum(p) == k. Equivalently, this is the softmax of eta * R with its
    largest entries clipped at 1 and the remaining mass rescaled onto the
    rest.

    The computation is done in log-space (log-sum-exp tail sums), so large
    cumulative rewards cannot overflow exp() into inf/NaN.

    Args:
        R: 1-D sequence of cumulative rewards, one per item (length N).
        k: set size; when k >= N every item gets probability 1.
        eta: learning rate.

    Returns:
        Float array of length N with entries in [0, 1] summing to k.
    """
    z = np.asarray(R, dtype=float) * eta
    n = z.shape[0]
    if k >= n:
        return np.ones(n)

    # Work in descending order of eta * R; ties keep their original order.
    order = np.argsort(-z, kind="stable")
    zs = z[order]
    # log_tail[i] = log(sum_{j >= i} exp(zs[j])) -- a true *tail* sum.
    log_tail = np.logaddexp.accumulate(zs[::-1])[::-1]

    # Sorted entry i must be capped at 1 iff, with the i larger entries
    # already capped, its rescaled weight still reaches 1:
    #     (k - i) * exp(zs[i]) >= sum_{j >= i} exp(zs[j]).
    # For sorted weights this predicate holds on a prefix of 0..k-1, so the
    # number of capped entries is (last index satisfying it) + 1.
    idx = np.arange(k)
    capped = np.log(k - idx) + zs[:k] >= log_tail[:k]
    hits = np.flatnonzero(capped)
    n_capped = int(hits[-1]) + 1 if hits.size else 0

    p_sorted = np.ones(n)
    remaining = k - n_capped
    if remaining > 0:
        # K = remaining / tail_sum(n_capped); min() only guards rounding.
        p_sorted[n_capped:] = np.minimum(
            1.0, np.exp(np.log(remaining) + zs[n_capped:] - log_tail[n_capped])
        )
    else:
        # All k slots are taken by capped items; the rest carry no mass.
        p_sorted[n_capped:] = 0.0

    p = np.empty(n)
    p[order] = p_sorted
    return p


class sageOCO:
    """Online k-set learner whose reward feedback is the index of an item.

    Each round it returns marginals p computed from the cumulative reward
    vector R, and then adds a unit reward to R at the observed index/indices.
    """

    def __init__(self, args):
        # args is read by attribute: N (items), k (set size), R (initial
        # cumulative rewards) and eta (learning rate).
        self.args = args
        self.N = args.N
        self.k = args.k
        self.R = args.R
        self.eta = args.eta

    def initialize(self):
        """Reset the cumulative rewards to a zero vector of length N."""
        self.R = np.zeros(self.N)

    def calculate_marginals_oco(self, y, vec=None):
        """Return this round's marginals, then add 1 to ``R[y]``.

        Args:
            y: index (or index array) of the item(s) observed this round.
            vec: unused; kept for interface compatibility.

        Returns:
            Array of length N with entries in [0, 1] summing to k, computed
            from R *before* this round's update.
        """
        p = _kset_marginals(self.R, self.k, self.eta)
        # Sanity check: the marginals of a k-set must sum to k.
        if np.abs(p.sum() - self.k) > 1e-3:
            print(p.sum())
        # gradient update step
        self.R[y] += 1  # for all the experiments except monotone set function.
        return p

    def get_kset(self, y):
        """Return (marginals, k-set sampled from them via madowSampling)."""
        p = self.calculate_marginals_oco(y)
        return p, madowSampling(self.N, p, self.k)


class sageOCOMonotone:
    """Online k-set learner whose reward feedback is a full gradient vector.

    It uses the same marginals as ``sageOCO``, but each round it adds the
    supplied gradient to the cumulative reward vector R.
    """

    def __init__(self, args):
        # args is read by attribute: N, k, R and eta (see sageOCO).
        self.args = args
        self.N = args.N
        self.k = args.k
        self.R = args.R
        self.eta = args.eta

    def initialize(self):
        """Reset the cumulative rewards to a zero vector of length N."""
        self.R = np.zeros(self.N)

    def calculate_marginals_oco(self, grad):
        """Return this round's marginals, then apply ``R += grad``.

        Args:
            grad: reward gradient for this round; presumably a length-N
                vector (a scalar would shift every coordinate equally and
                leave future marginals unchanged).

        Returns:
            Array of length N with entries in [0, 1] summing to k, computed
            from R *before* this round's update.
        """
        p = _kset_marginals(self.R, self.k, self.eta)
        # Sanity check: the marginals of a k-set must sum to k.
        if np.abs(p.sum() - self.k) > 1e-3:
            print(p.sum())
        # gradient update step
        self.R += grad
        return p

    def get_kset(self, grad):
        """Return (marginals, k-set sampled from them via madowSampling)."""
        p = self.calculate_marginals_oco(grad)
        return p, madowSampling(self.N, p, self.k)
@ex.automain
def main(_run):
    """Sacred entry point: drive sageOCOMonotone with random unit rewards.

    Reads N, k, R, eta and T from the run config. In each of the T rounds it
    draws a uniformly random item y, feeds a unit reward on y as the gradient,
    and prints the round, y, the total marginal mass and the sampled k-set.
    """
    args = edict(_run.config)
    oco = sageOCOMonotone(args)
    oco.initialize()
    for t in range(args.T):
        y = np.random.randint(args.N)
        # sageOCOMonotone applies `R += grad`. A bare index would be
        # broadcast-added to every coordinate, which is a uniform shift that
        # never changes the marginals. So pass the one-hot gradient for y.
        grad = np.zeros(args.N)
        grad[y] = 1.0
        p, kset = oco.get_kset(grad)
        print(t, y, p.sum(), kset)