---
title: "PLN: Python version"
lang: en
execute:
  freeze: auto
output:
  html_document:
    toc: true
    toc_float: true
---
## PLN with PyTorch

As you will see, the {torch} and PyTorch syntaxes are very similar.
### Preliminaries

We load the `oaks` dataset shipped with the Python package `pyPLNmodels`:
```{python load-plnmodels}
import pyPLNmodels
import numpy as np
import matplotlib.pyplot as plt
from pyPLNmodels.models import PlnPCAcollection, Pln
from pyPLNmodels.oaks import load_oaks
oaks = load_oaks()
Y = oaks['counts']            # count matrix (n x p)
O = np.log(oaks['offsets'])   # offsets, on the log scale
X = np.ones([Y.shape[0], 1])  # design matrix: intercept only
```
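
As a quick sanity check (purely illustrative, not part of the original analysis), we can look at the dimensions of these objects:

```{python check-dims}
print(Y.shape)  # n observations x p count variables
print(O.shape)  # same dimensions as the counts
print(X.shape)  # n observations x d covariates (here d = 1, intercept only)
```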
For reference, we first fit the model with the dedicated package (which relies on PyTorch and the Rprop optimizer):
```{python oaks-pyplnmodels}
pln = Pln.from_formula("counts ~ 1", data = oaks, take_log_offsets = True)
%timeit pln.fit()
```
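
As a reminder, the Poisson log-normal model fitted here couples a latent Gaussian layer with conditionally independent Poisson observations (writing $(XB)_i$ for the $i$-th row of the matrix of means):

$$
Z_i \sim \mathcal{N}\big((XB)_i, \Sigma\big), \qquad
Y_{ij} \mid Z_{ij} \sim \mathcal{P}\big(\exp(O_{ij} + Z_{ij})\big),
$$

where $O$ is the matrix of log-offsets, $X$ the design matrix (here a single intercept) and $B$ the matrix of regression coefficients.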
### A simple implementation in PyTorch
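
The class below maximizes the following evidence lower bound (ELBO), written here exactly as it is computed in the `get_ELBO` method. The variational approximation is Gaussian, with mean matrix $M$ and standard deviation matrix $S$, and we write $A = \exp(O + M + XB + S^2/2)$ entrywise:

$$
J(B, \Omega, M, S) = \frac{n}{2} \log|\Omega|
+ \sum_{i,j} \left[ Y_{ij} \,(O + M + XB)_{ij} - A_{ij} + \frac{1}{2} \log S_{ij}^2 \right]
- \frac{1}{2} \operatorname{tr}\!\left[ \Big( M^\top M + \operatorname{diag}\big(\textstyle\sum_i S_{ij}^2\big) \Big)\, \Omega \right]
+ \frac{np}{2} - \sum_{i,j} \log(Y_{ij}!),
$$

where $\operatorname{diag}\big(\sum_i S_{ij}^2\big)$ is the $p \times p$ diagonal matrix of the column sums of $S^2$, and $\log(Y_{ij}!)$ is approximated by Stirling's formula.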
```{python simple-plnclass}
import torch
import numpy as np
import math

def _log_stirling(integer: torch.Tensor) -> torch.Tensor:
    """Stirling approximation of log(n!)."""
    integer_ = integer + (integer == 0)  # replace 0 with 1 since 0! = 1!
    return torch.log(torch.sqrt(2 * np.pi * integer_)) + integer_ * torch.log(integer_ / math.exp(1))

class PLN:
    Y: torch.Tensor       # counts (n x p)
    O: torch.Tensor       # log-offsets (n x p)
    X: torch.Tensor       # design matrix (n x d)
    n: int
    p: int
    d: int
    M: torch.Tensor       # variational means (n x p)
    S: torch.Tensor       # variational standard deviations (n x p)
    B: torch.Tensor       # regression coefficients (d x p)
    Sigma: torch.Tensor   # latent covariance matrix (p x p)
    Omega: torch.Tensor   # latent precision matrix (p x p)
    ELBO: np.ndarray      # ELBO path along the optimization

    ## Constructor
    def __init__(self, Y: np.ndarray, O: np.ndarray, X: np.ndarray):
        ## store everything in double precision so that all dtypes are consistent
        self.Y = torch.tensor(Y, dtype=torch.double)
        self.O = torch.tensor(O, dtype=torch.double)
        self.X = torch.tensor(X, dtype=torch.double)
        self.n, self.p = Y.shape
        self.d = X.shape[1]
        ## Variational parameters
        self.M = torch.full(Y.shape, 0.0, dtype=torch.double, requires_grad=True)
        self.S = torch.full(Y.shape, 1.0, dtype=torch.double, requires_grad=True)
        ## Model parameters
        self.B = torch.zeros(self.d, self.p, dtype=torch.double, requires_grad=True)
        self.Sigma = torch.eye(self.p, dtype=torch.double)
        self.Omega = torch.eye(self.p, dtype=torch.double)

    def get_Sigma(self):
        ## closed-form update of Sigma given the variational parameters
        return 1 / self.n * (self.M.T @ self.M + torch.diag(torch.sum(self.S**2, dim=0)))

    def get_ELBO(self):
        S2 = torch.square(self.S)
        XB = self.X @ self.B
        A = torch.exp(self.O + self.M + XB + S2 / 2)
        elbo = (
            self.n / 2 * torch.logdet(self.Omega)
            + torch.sum(-A + self.Y * (self.O + self.M + XB) + 0.5 * torch.log(S2))
            - 0.5 * torch.trace((self.M.T @ self.M + torch.diag(torch.sum(S2, dim=0))) @ self.Omega)
            + 0.5 * self.n * self.p
            - torch.sum(_log_stirling(self.Y))
        )
        return elbo

    def fit(self, N_iter, lr, tol=1e-8):
        self.ELBO = np.zeros(N_iter)
        optimizer = torch.optim.Rprop([self.B, self.M, self.S], lr=lr)
        objective0 = np.inf
        for i in range(N_iter):
            ## reinitialize gradients
            optimizer.zero_grad()
            ## compute current ELBO
            loss = -self.get_ELBO()
            ## backward propagation and optimization
            loss.backward()
            optimizer.step()
            ## update Sigma and Omega with their closed forms
            self.Sigma = self.get_Sigma()
            self.Omega = torch.inverse(self.Sigma)
            objective = -loss.item()
            self.ELBO[i] = objective
            ## stop as soon as the relative change of the objective falls below tol
            if abs(objective0 - objective) / abs(objective) < tol:
                self.ELBO = self.ELBO[0:i + 1]
                break
            else:
                objective0 = objective
```
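
Before running it on the oaks data, here is a minimal smoke test on simulated counts (this simulation scheme is purely illustrative and not part of the original analysis):

```{python smoke-test}
rng = np.random.default_rng(0)
n_toy, p_toy = 60, 10
Z = rng.normal(size=(n_toy, p_toy))                 # latent Gaussian layer
Y_toy = rng.poisson(np.exp(1.0 + Z)).astype(float)  # Poisson counts
O_toy = np.zeros((n_toy, p_toy))                    # no offsets
X_toy = np.ones((n_toy, 1))                         # intercept only
toy = PLN(Y_toy, O_toy, X_toy)
toy.fit(100, lr = 0.1)
print(toy.ELBO[-1])  # final value of the ELBO
```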
### Computation time

Let us now time our simple PLN implementation:
```{python run-simplepln}
myPLN = PLN(Y, O, X)
%timeit myPLN.fit(50, lr = 0.1, tol = 1e-8)
## ELBO path of the last fit, on the log scale (the ELBO is negative)
plt.plot(np.log(-myPLN.ELBO))
```