NB_module.py
46 lines (37 loc) · 1.5 KB
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class MeanAct(nn.Module):
    """Exponential activation for the NB/ZINB mean, clamped for numerical stability."""
    def __init__(self):
        super(MeanAct, self).__init__()

    def forward(self, x):
        return torch.clamp(torch.exp(x), min=1e-5, max=1e6)


class DispAct(nn.Module):
    """Softplus activation for the NB/ZINB dispersion, clamped for numerical stability."""
    def __init__(self):
        super(DispAct, self).__init__()

    def forward(self, x):
        return torch.clamp(F.softplus(x), min=1e-4, max=1e4)


def NB_loss(x, h_r, h_p):
    """Negative binomial negative log-likelihood.

    x   : observed counts
    h_r : log of the dispersion parameter
    h_p : logit of the success probability

    The constant lgamma(x + 1) term is omitted since it does not affect gradients.
    """
    ll = torch.lgamma(torch.exp(h_r) + x) - torch.lgamma(torch.exp(h_r))
    ll += h_p * x - torch.log(torch.exp(h_p) + 1) * (x + torch.exp(h_r))
    loss = -torch.mean(torch.sum(ll, dim=-1))
    return loss


def ZINB_loss(x, mean, disp, pi, scale_factor=1.0, ridge_lambda=0.0):
    """Zero-inflated negative binomial negative log-likelihood.

    x            : observed counts
    mean         : NB mean (e.g. output of MeanAct)
    disp         : NB dispersion (e.g. output of DispAct)
    pi           : dropout (zero-inflation) probability
    scale_factor : per-cell size factor used to rescale the mean
    ridge_lambda : ridge penalty weight on pi
    """
    eps = 1e-10
    if isinstance(scale_factor, float):
        scale_factor = np.full((len(mean),), scale_factor)
    # Convert to a tensor on the same device/dtype as mean so the broadcast below works.
    scale_factor = torch.as_tensor(scale_factor, dtype=mean.dtype, device=mean.device)
    mean = mean * scale_factor[:, None]
    # Negative binomial part of the negative log-likelihood.
    t1 = torch.lgamma(disp + eps) + torch.lgamma(x + 1.0) - torch.lgamma(x + disp + eps)
    t2 = (disp + x) * torch.log(1.0 + (mean / (disp + eps))) + (x * (torch.log(disp + eps) - torch.log(mean + eps)))
    nb_final = t1 + t2
    # Non-zero counts: NB likelihood weighted by (1 - pi).
    nb_case = nb_final - torch.log(1.0 - pi + eps)
    # Zero counts: mixture of dropout (pi) and the NB probability of a zero.
    zero_nb = torch.pow(disp / (disp + mean + eps), disp)
    zero_case = -torch.log(pi + ((1.0 - pi) * zero_nb) + eps)
    result = torch.where(torch.le(x, 1e-8), zero_case, nb_case)
    if ridge_lambda > 0:
        result += ridge_lambda * torch.square(pi)
    result = torch.mean(result)
    return result
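
A minimal usage sketch (not part of NB_module.py): assuming a decoder whose last hidden layer feeds three output heads passed through MeanAct, DispAct, and a sigmoid for the dropout probability, the ZINB loss can be evaluated on a batch of counts as below. All tensor names, shapes, and the size-factor choice here are illustrative assumptions, not defined by this module.

# Illustrative sketch; names and shapes are assumptions, not part of NB_module.py
import torch
from NB_module import MeanAct, DispAct, ZINB_loss

batch, genes, latent = 8, 100, 16
decoded = torch.randn(batch, latent)               # stand-in for a decoder's last hidden layer

mean_head = torch.nn.Sequential(torch.nn.Linear(latent, genes), MeanAct())
disp_head = torch.nn.Sequential(torch.nn.Linear(latent, genes), DispAct())
pi_head   = torch.nn.Sequential(torch.nn.Linear(latent, genes), torch.nn.Sigmoid())

x = torch.randint(0, 20, (batch, genes)).float()   # fake count matrix
size_factors = x.sum(dim=1) / x.sum(dim=1).mean()  # simple per-cell size factors

loss = ZINB_loss(x, mean_head(decoded), disp_head(decoded), pi_head(decoded),
                 scale_factor=size_factors, ridge_lambda=0.1)
loss.backward()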