-
Notifications
You must be signed in to change notification settings - Fork 2
/
estimator.py
29 lines (29 loc) · 867 Bytes
/
estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def EM_estimate(dic_prob, tol=1e-3, max_iter=None):
    """Iteratively re-estimate one mixing weight per label.

    Each label starts with a weight equal to the mean of its ``"g"`` entry.
    On every pass the weights are rescaled by the label's ``"p"`` entry,
    normalised across labels, and averaged; the loop stops once the summed
    absolute change of all weights is no larger than ``tol``.

    NOTE(review): the name suggests an expectation-maximisation update of
    class proportions, with ``"g"`` and ``"p"`` presumably NumPy arrays of
    per-sample probabilities -- confirm against the caller.

    Args:
        dic_prob: Mapping ``label -> {"g": ..., "p": ...}``. ``"g"`` must
            provide ``.mean()``; ``"p"`` must support multiplication and
            division by that mean, and the normalised result must provide
            ``.mean()``. Other keys (such as ``"n"``) are ignored.
        tol: Convergence threshold on the summed absolute weight change.
        max_iter: Optional cap on the number of passes. ``None`` (default)
            iterates until the ``tol`` criterion is met. At least one pass
            is always performed.

    Returns:
        Dict mapping each label (in ``dic_prob`` order) to its final weight.
        An empty ``dic_prob`` yields an empty dict.
    """
    labels = list(dic_prob)
    scores = [dic_prob[label]["p"] for label in labels]
    # The mean of "g" serves both as the initial weight and as the fixed
    # normaliser in every pass; "g" never changes, so compute it only once.
    baselines = [dic_prob[label]["g"].mean() for label in labels]
    weights = list(baselines)

    iteration = 0
    while True:
        # Unnormalised contribution of each label under the current weights.
        numes = [
            weight * score / baseline
            for weight, score, baseline in zip(weights, scores, baselines)
        ]
        denom = sum(numes)
        updated = [(nume / denom).mean() for nume in numes]
        delta = sum(abs(new - old) for new, old in zip(updated, weights))
        weights = updated
        iteration += 1

        # Written as "not >" so that a NaN delta also ends the loop.
        if not delta > tol:
            break
        if max_iter is not None and iteration >= max_iter:
            break

    return dict(zip(labels, weights))