-
Notifications
You must be signed in to change notification settings - Fork 1
/
priors.py
36 lines (26 loc) · 889 Bytes
/
priors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!/usr/bin/env python3
from typing import Callable, Iterator
import torch
import torch.nn as nn
from torch.nn import Parameter
# Callable producing the parameters a prior applies to; it has the same shape
# as ``nn.Module.parameters(recurse: bool)``. The priors below call it with no
# argument, so the callable must also accept being called without one.
ParamFn = Callable[[bool], Iterator[Parameter]]


class Prior(nn.Module):
    """Base class for priors over a set of network parameters.

    Subclasses override :meth:`log_prob` and/or :meth:`nn_loss`; this base
    class only stores the callable that yields the parameters.
    """

    def __init__(self, params: ParamFn):
        """Create the prior.

        Args:
            params: callable returning an iterator over the parameters the
                prior acts on (e.g. a bound ``model.parameters``).
        """
        super().__init__()
        # A plain callable is neither a Parameter nor a Module, so nn.Module
        # stores it as an ordinary attribute: the prior does not register
        # the referenced parameters as its own.
        self.params = params

    def log_prob(self):
        """Return the log-density of the current parameters under the prior.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def nn_loss(self):
        """Return the prior's penalty term on the current parameters.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError


class Gaussian(Prior):
    """Isotropic Gaussian prior with a scalar precision.

    ``nn_loss`` is ``prior_precision / 2 * sum(theta ** 2)``, which equals the
    negative log-density of ``N(0, I / prior_precision)`` up to an additive
    constant (i.e. L2 weight decay).
    """

    def __init__(self, params: ParamFn, prior_precision: float = 0.001):
        """Create the prior.

        Args:
            params: callable returning an iterator over the parameters.
            prior_precision: scalar precision (inverse variance) multiplying
                the squared-norm penalty.
        """
        super().__init__(params=params)
        self.prior_precision = prior_precision

    def log_prob(self):
        """Not implemented for this prior.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def nn_loss(self):
        """Return ``0.5 * prior_precision * sum of all squared parameter entries``.

        Returns:
            A 0-dim tensor. If ``params()`` yields no parameters the result is
            a zero tensor instead of an error.
        """
        # Reduce each parameter on its own rather than flattening and
        # concatenating everything. This avoids copying the whole parameter
        # vector and also works for non-contiguous tensors, where
        # ``.view(-1)`` raises.
        squared_norms = [
            torch.sum(torch.square(param)) for param in self.params()
        ]
        if squared_norms:
            total = torch.stack(squared_norms).sum()
        else:
            # torch.cat/torch.stack reject an empty list; no parameters
            # means no penalty.
            total = torch.zeros(())
        l2r = 0.5 * total
        return self.prior_precision * l2r