-
Notifications
You must be signed in to change notification settings - Fork 5
/
losses.py
80 lines (71 loc) · 3.13 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import logsumexp
def l1_regularizer(model, lambda_l1=0.1):
    """Return the L1 penalty over the weight parameters of ``model``.

    Only parameters whose name (as reported by ``model.named_parameters()``)
    ends with ``'weight'`` contribute; every other parameter is ignored.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.
        lambda_l1: Scale applied to the summed absolute values.

    Returns:
        ``lambda_l1 * sum(|w|)`` over all matching parameters.  When no
        parameter name matches, the result is the plain number
        ``lambda_l1 * 0`` rather than a tensor.
    """
    # Sum the absolute values first and scale once; the factor does not
    # depend on the parameter, so there is no need to multiply per tensor.
    weight_l1 = sum(
        param.abs().sum()
        for name, param in model.named_parameters()
        if name.endswith('weight')
    )
    return lambda_l1 * weight_l1
def _add_rollout_mse(total, preds, batch_label, args, steps):
    """Add the MSE of every rollout index in ``steps`` to ``total`` and return it."""
    loss_fn = torch.nn.MSELoss()
    for idx in steps:
        # The prediction of a rollout step is the last entry of preds[idx].
        pred = preds[idx][-1]
        label = batch_label[:, idx, :args.num_humans, :args.feat_dim]
        total = total + loss_fn(pred, label)
    return total


def calc_loss(preds, batch_label, args, model=None, dataset_size=None):
    """Accumulate the training loss for one batch and back-propagate it.

    The loss is built from up to three terms, selected by flags on ``args``:

    * ``args.kl``: a total-correlation penalty computed from ``preds[0][0]``
      and scaled by ``args.kl_coef``.
    * ``args.l1``: ``l1_regularizer(model.encoder.mlp1, args.lambda1)``.
    * The MSE between ``preds[idx][-1]`` and
      ``batch_label[:, idx, :args.num_humans, :args.feat_dim]`` summed over
      the rollout indices.

    With ``args.dis_obj`` set, the rollouts are split in half.  The first
    half is added to the loss above and back-propagated with
    ``retain_graph=True``.  The second half is back-propagated separately
    while a hook on ``preds[0][0][0]`` replaces the gradient reaching that
    tensor with zeros.  Without ``args.dis_obj`` all rollouts are added and a
    single backward pass is run.

    Args:
        preds: Indexable by rollout step; ``preds[idx][-1]`` is the
            prediction for step ``idx`` and ``preds[0][0]`` is a sequence of
            tensors (called "graphs" in the code).
        batch_label: Tensor indexed as ``[:, idx, :num_humans, :feat_dim]``.
        args: Namespace providing ``kl``, ``l1``, ``dis_obj``, ``rollouts``,
            ``num_humans``, ``feat_dim`` and, depending on the flags,
            ``kl_coef`` and ``lambda1``.
        model: Required when ``args.l1`` is set; must expose
            ``encoder.mlp1``.
        dataset_size: Required when ``args.kl`` is set (N in the
            ``log(N * M)`` normaliser).  Falls back to ``args.dataset_size``
            when that attribute exists.

    Returns:
        The accumulated loss.  In ``dis_obj`` mode this does NOT include the
        second-half ("long-term") rollout loss, which is only
        back-propagated.

    Raises:
        ValueError: If ``args.kl`` is set without a dataset size, or
            ``args.l1`` is set without ``model``.
    """
    losses = 0.

    if args.kl:
        if dataset_size is None:
            dataset_size = getattr(args, 'dataset_size', None)
        if dataset_size is None:
            raise ValueError(
                "calc_loss: args.kl is set but no dataset_size was given")

        eps = 1e-16
        graphs = preds[0][0]
        batch_size = graphs[0].shape[0]
        # log(N * M) normaliser of the minibatch estimate
        #   log q(z) ~= -log(N * M) + logsumexp_m log q(z | x_m)
        log_norm = math.log(batch_size * dataset_size)

        ##### Total correlation
        # Stack the graph tensors along a new dim 1.  Original author's note:
        # the result is (batch_size, #layers, 5, 4).
        stacked = torch.stack(graphs, dim=1)
        # Hard (one-hot) Gumbel-softmax sample over the last dim, with a
        # singleton dim inserted so it broadcasts against every batch input.
        # NOTE(review): gumbel_softmax interprets its input as logits, while
        # the log(... + eps) below treats the same tensor as probabilities.
        # Confirm which of the two the encoder actually emits.
        sample = F.gumbel_softmax(stacked, dim=-1, hard=True).unsqueeze(1)
        # pairwise[i, j] is the log probability of sample i being generated
        # by input j (author's note: batch_size, batch_size, #layers).
        pairwise = torch.log(
            (sample * stacked.unsqueeze(0)).sum(dim=-1) + eps).sum(dim=-1)
        # Product of the marginals: normalise each layer separately, then sum.
        logqz_prodmarginals = (
            logsumexp(pairwise, dim=1, keepdim=False) - log_norm).sum(1)
        # Joint: sum over layers first, then normalise once.
        logqz = logsumexp(pairwise.sum(2), dim=1, keepdim=False) - log_norm
        kl_loss = args.kl_coef * (logqz - logqz_prodmarginals).mean()
        losses = losses + kl_loss

    if args.l1:
        if model is None:
            raise ValueError(
                "calc_loss: args.l1 is set but no model was given")
        losses = losses + l1_regularizer(model.encoder.mlp1, args.lambda1)

    if args.dis_obj:
        split = args.rollouts // 2
        losses = _add_rollout_mse(
            losses, preds, batch_label, args, range(split))
        # Nothing to differentiate if no term produced a tensor.
        if isinstance(losses, torch.Tensor):
            # Keep the graph: the long-term pass below runs through it again.
            losses.backward(retain_graph=True)

        # Zero the gradient reaching the first graph tensor during the
        # long-term pass.
        hook = preds[0][0][0].register_hook(
            lambda grad: torch.zeros_like(grad))
        try:
            long_term_losses = _add_rollout_mse(
                0., preds, batch_label, args, range(split, args.rollouts))
            if isinstance(long_term_losses, torch.Tensor):
                long_term_losses.backward()
        finally:
            # Always detach the hook so it cannot leak into later passes.
            hook.remove()
    else:
        losses = _add_rollout_mse(
            losses, preds, batch_label, args, range(args.rollouts))
        if isinstance(losses, torch.Tensor):
            losses.backward()

    return losses