-
Notifications
You must be signed in to change notification settings - Fork 1
/
sbi_train.py
42 lines (35 loc) · 1.32 KB
/
sbi_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
'''This script illustrates how to train an SBI model,
and generates a pickle file which is in the same format as the one used in tutorial.ipynb
'''
import os, sys
import numpy as np
import pickle
# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sbi import utils as Ut
from sbi import inference as Inference
nhidden = 500  # architecture: `hidden_features` of the MAF density estimator
nblocks = 15  # architecture: `num_transforms` (number of MAF transforms)
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_training_data():
    """Return the training set as a ``(x_train, y_train)`` pair.

    For fitting galaxy photometry: ``x_train`` holds the parameters (thetas)
    and ``y_train`` the corresponding fluxes and uncertainties. Both must
    support ``.astype(np.float32)`` (e.g. numpy arrays), because
    ``train_npe`` casts them that way.

    This is a template hook: replace the body with code that loads your own
    training data.

    Raises
    ------
    NotImplementedError
        Always, until the hook is filled in.
    """
    raise NotImplementedError(
        "load_training_data() must be edited to load your training set "
        "(x = thetas; y = fluxes and uncertainties)")


def train_npe(x_train, y_train, fanpe, fsumm,
              nhidden=nhidden, nblocks=nblocks, device=device):
    """Train an amortized NPE with a MAF density estimator and save the results.

    Parameters
    ----------
    x_train, y_train : array-like
        Parameters (thetas) and observations (fluxes and uncertainties).
        Each is cast to float32 and wrapped in a CPU torch tensor before being
        appended to the sbi inference object.
    fanpe : str or path-like
        Destination of the trained density estimator's ``state_dict``
        (written with ``torch.save``; conventionally a ``.pt`` file).
    fsumm : str or path-like
        Destination of the pickled training summary (conventionally a ``.p``
        file); useful for checking convergence.
    nhidden, nblocks : int
        ``hidden_features`` and ``num_transforms`` of the MAF.
    device : str
        Device passed to ``SNPE`` for training.

    Returns
    -------
    The trained density estimator returned by ``anpe.train()``.
    """
    anpe = Inference.SNPE(
        density_estimator=Ut.posterior_nn(
            'maf', hidden_features=nhidden, num_transforms=nblocks),
        device=device)

    # The simulations are appended as-is (no proposal is given), so the
    # training set itself effectively plays the role of the prior.
    # Tensors are built on the CPU; presumably sbi moves them to `device`
    # during training -- confirm for the installed sbi version.
    anpe.append_simulations(
        torch.as_tensor(x_train.astype(np.float32), device='cpu'),
        torch.as_tensor(y_train.astype(np.float32), device='cpu'))
    p_x_y_estimator = anpe.train()

    # save trained ANPE
    torch.save(p_x_y_estimator.state_dict(), fanpe)

    # save training summary; `with` guarantees the file is flushed and closed
    with open(fsumm, 'wb') as fh:
        pickle.dump(anpe._summary, fh)
    print(anpe._summary)
    return p_x_y_estimator


if __name__ == '__main__':
    # usage: python sbi_train.py <model.pt> <summary.p>
    if len(sys.argv) != 3:
        sys.exit('usage: python sbi_train.py <model.pt> <summary.p>')
    fanpe, fsumm = sys.argv[1], sys.argv[2]
    x_train, y_train = load_training_data()
    train_npe(x_train, y_train, fanpe, fsumm)