forked from leftthomas/CapsNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
24 lines (16 loc) · 746 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn.functional as F
from torch import nn
class CapsuleLoss(nn.Module):
    """Margin loss plus weighted reconstruction loss, averaged over the batch.

    For every element of ``classes`` (``v``) and ``labels`` (``T``) the margin
    term is::

        T * relu(m_plus - v)**2 + down_weight * (1 - T) * relu(v - m_minus)**2

    summed over all elements. The reconstruction term is the summed squared
    error between ``reconstructions`` and ``images``. The returned value is
    ``(margin + reconstruction_weight * reconstruction) / images.size(0)``.

    Args:
        m_plus: Target lower bound for entries where the label is 1.
        m_minus: Target upper bound for entries where the label is 0.
        down_weight: Weight applied to the label-0 (absent class) term.
        reconstruction_weight: Scale applied to the summed reconstruction error.
    """

    def __init__(self, m_plus=0.9, m_minus=0.1, down_weight=0.5,
                 reconstruction_weight=0.0005):
        super().__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.down_weight = down_weight
        self.reconstruction_weight = reconstruction_weight
        # reduction="sum" is the non-deprecated equivalent of size_average=False.
        self.reconstruction_loss = nn.MSELoss(reduction="sum")

    def forward(self, images, labels, classes, reconstructions):
        """Compute the combined loss for one batch.

        Args:
            images: Target images; dimension 0 is taken as the batch size.
            labels: Same shape as ``classes``; presumably one-hot float class
                targets (TODO confirm against the training loop).
            classes: Per-class scores; presumably capsule lengths in [0, 1].
            reconstructions: Decoder output. It must either match ``images``'
                shape or hold the same number of elements (e.g. a flattened
                ``(B, H*W)`` tensor), in which case it is reshaped to match.

        Returns:
            A scalar tensor: the total loss divided by the batch size.
        """
        left = F.relu(self.m_plus - classes) ** 2
        right = F.relu(classes - self.m_minus) ** 2
        margin_loss = (labels * left
                       + self.down_weight * (1. - labels) * right).sum()

        # A flat decoder output vs. a (B, C, H, W) image batch would otherwise
        # raise or broadcast into a meaningless sum; align the layouts when the
        # element counts agree. Same-shape inputs are passed through untouched.
        if (reconstructions.shape != images.shape
                and reconstructions.numel() == images.numel()):
            reconstructions = reconstructions.reshape(images.shape)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        total = margin_loss + self.reconstruction_weight * reconstruction_loss
        return total / images.size(0)
if __name__ == "__main__":
    # Smoke check: build the loss with its defaults and print its module repr.
    criterion = CapsuleLoss()
    print(criterion)