import ssm
# Who is the audience?
# - us: great to have a standardized research tool
# - people like us
# - technically savvy experimentalists
# - experimentalists: consumers and users of SSMs
# method developers vs. model developers (**)
# Question: what is a model?
# - an object with parameters as class properties, or
# - a collection of functions (dynamics, emissions) and a dictionary (params)
# what are the signatures?
# dynamics(state, params, covariates) -> p(next_state)
# emissions(state, params, covariates) -> p(observations)
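# A minimal sketch of the functions-plus-dict reading above (all names and the
# linear-Gaussian choice are illustrative assumptions, not settled API):
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd

example_params = dict(A=0.99 * jnp.eye(2),   # dynamics matrix
                      Q=0.10 * jnp.eye(2),   # dynamics noise covariance
                      C=jnp.ones((3, 2)),    # emission matrix
                      R=jnp.eye(3))          # emission noise covariance

def dynamics(state, params, covariates=None):
    # p(next_state | state): a linear-Gaussian step
    return tfd.MultivariateNormalFullCovariance(params["A"] @ state, params["Q"])

def emissions(state, params, covariates=None):
    # p(observations | state): a linear-Gaussian readout
    return tfd.MultivariateNormalFullCovariance(params["C"] @ state, params["R"])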
model = ssm.models.HMM(...)
data = ssm.sample(model, covariates, hypers)
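# What sample might do under the hood, in terms of the dynamics/emissions
# sketched above (ancestral sampling; a real version would swap the Python
# loop for lax.scan):
import jax.random as jr

def ancestral_sample(params, key, T=100):
    state, states, observations = jnp.zeros(2), [], []
    for _ in range(T):
        key, k1, k2 = jr.split(key, 3)
        state = dynamics(state, params).sample(seed=k1)  # x_t ~ p(x_t | x_{t-1})
        states.append(state)
        observations.append(emissions(state, params).sample(seed=k2))  # y_t ~ p(y_t | x_t)
    return jnp.stack(states), jnp.stack(observations)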
# library of common inference subroutines
expectations = ssm.inference.core.hmm_message_passing(model, data)
expectations = ssm.inference.core.lds_message_passing(model, data)
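# For concreteness, a sketch of what hmm_message_passing could compute: the
# standard forward-backward recursions for a dense K-state HMM, assuming the
# per-timestep observation log likelihoods are precomputed:
from jax import lax
from jax.scipy.special import logsumexp

def hmm_forward_backward(log_pi, log_A, log_lkhds):
    # log_pi: (K,) initial log probs; log_A[i, j] = log p(z_{t+1}=j | z_t=i);
    # log_lkhds: (T, K) observation log likelihoods
    def fwd(log_alpha, ll):
        out = logsumexp(log_alpha[:, None] + log_A, axis=0) + ll
        return out, out
    init = log_pi + log_lkhds[0]
    _, log_alphas = lax.scan(fwd, init, log_lkhds[1:])
    log_alphas = jnp.vstack([init, log_alphas])

    def bwd(log_beta, ll):
        out = logsumexp(log_A + ll + log_beta, axis=1)
        return out, out
    _, log_betas = lax.scan(bwd, jnp.zeros_like(log_pi), log_lkhds[1:], reverse=True)
    log_betas = jnp.vstack([log_betas, jnp.zeros_like(log_pi)])

    log_Z = logsumexp(log_alphas[-1])                      # log p(y_{1:T})
    return jnp.exp(log_alphas + log_betas - log_Z), log_Z  # E[z_t = k | y_{1:T}]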
# inference_alg(model, dataset, [proposal/approx_family], [hypers], [callback]) -> posterior and trace
# what is a dataset?
# - an iterable whose length is the number of time series
# - each entry is a dict(data=[T, N_obs], covariates=[T, N_in], metadata=..., mask=...)
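# One conforming dataset with the shapes above (contents are placeholder noise):
import numpy as np

T, N_obs, N_in = 100, 3, 2
dataset = [dict(data=np.random.randn(T, N_obs),
                covariates=np.random.randn(T, N_in),
                metadata=dict(session=i),
                mask=np.ones((T, N_obs), dtype=bool))  # True = observed
           for i in range(5)]                          # five time series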
posterior, trace = ssm.inference.em(model, dataset)
posterior, trace = ssm.inference.gibbs(model, dataset)
posterior, trace = ssm.inference.smc(model, dataset, propose_and_weight)
posterior, trace = ssm.inference.variational_em(model, dataset, variational_posterior_family)
posterior, trace = ssm.inference.svae(model, dataset, recognition_network)
posterior, trace = ssm.inference.fivo(model, dataset, propose_and_weight)
# Possibility: all inference algorithms split into
#   E step: (approximately) compute the posterior distribution for each time series in the batch
# M step: update model parameters given posteriors for that batch
# or... smart decision tree for guessing the right algorithm
posterior, trace = ssm.inference.fit(model, dataset, init_params)
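# A sketch of fit under the E-step/M-step split above; e_step and m_step are
# hypothetical helpers each algorithm would have to supply:
def generic_fit(model, dataset, init_params, num_iters=50):
    params = init_params
    for _ in range(num_iters):
        posteriors = [e_step(model, params, series) for series in dataset]  # E step
        params = m_step(model, dataset, posteriors)                         # M step
    return params, posteriors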
# evaluation
posterior.marginal_likelihood
# cross validation
posterior, trace = ssm.inference.approximate_posterior(model, data)
# plot posterior summaries
# plot inferred parameters
# evaluate inferred model parameters
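# A hedged sketch of cross validation with the calls above: fit on a training
# split, then score held-out series by their approximate marginal likelihood
# (init_params and the train/test split are assumptions):
train, test = dataset[:4], dataset[4:]
posterior, trace = ssm.inference.fit(model, train, init_params)
held_out_scores = [ssm.inference.approximate_posterior(model, series)[0].marginal_likelihood
                   for series in test]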
# What models?
# - HMM, LDS, SLDS, rSLDS
# - stochastic RNNs (nonlinear dynamical system)
# What algorithms?
# - EM // now you need some extra methods
# - SVAE
# - stochastic RNN
# Question: how does the inference algorithm determine whether or not it applies to a given model?
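# One possibility (hypothetical capability flags, not settled API): each
# algorithm tests for the extra methods it needs and defers otherwise
def em_applies(model):
    return hasattr(model, "e_step") and hasattr(model, "m_step")

def choose_algorithm(model):
    return ssm.inference.em if em_applies(model) else ssm.inference.variational_em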
# Wish list
# - operate on a single time series (vmap over batch)
# - jit at as high a level as possible, e.g. jit(ssm.inference)
# - no for loops
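# The wish list as one toy pattern: write the single-series function with a
# lax.scan inside, then vmap over the batch and jit at the top (the exponential
# smoother here is a stand-in for a real inference routine):
import jax

def smooth_single(params, data):
    def step(carry, y):
        carry = params["decay"] * carry + (1.0 - params["decay"]) * y
        return carry, carry
    _, out = lax.scan(step, jnp.zeros(data.shape[-1]), data)
    return out

smooth_batch = jax.jit(jax.vmap(smooth_single, in_axes=(None, 0)))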
TODO:
- stochastic EM for HMM
- ARHMM
- GLM distribution object like the GaussianLinearRegression object
- Posterior: multivariate Gaussian with block-tridiagonal precision
    linear potential is h; precision stored as block diagonals and block lower diagonals
    log_normalizer_constants ==> log constant terms bundled into the marginal likelihood:
    the posterior is log p(x | y), and log p(y) = log p(x, y) - log p(x | y)
forward filtering is equivalent to the Thomas algorithm (block-tridiagonal solve) for a linear dynamical system
- investigate this more
how to get the determinant of a block-tridiagonal matrix?
- some recurrence relation (see the sketch at the end of these notes)
whether or not the posterior is good determines how we want to compute the ELBO?
- entropy versus subtracting off the marginal likelihood
moving toward a class-based architecture
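# Sketch of the determinant recurrence mentioned above: block LU/Cholesky gives
# log|J| = sum_t log|S_t|, with S_1 = J_1 and S_t = J_t - L_{t-1} S_{t-1}^{-1} L_{t-1}^T
# (J_t = diagonal blocks, L_t = lower off-diagonal blocks, symmetric case):
def block_tridiag_logdet(J_diag, J_lower):
    # J_diag: (T, D, D) diagonal blocks; J_lower: (T-1, D, D) lower blocks
    def step(S_prev, inputs):
        J_t, L_prev = inputs
        S_t = J_t - L_prev @ jnp.linalg.solve(S_prev, L_prev.T)
        return S_t, jnp.linalg.slogdet(S_t)[1]
    _, logdets = lax.scan(step, J_diag[0], (J_diag[1:], J_lower))
    return jnp.linalg.slogdet(J_diag[0])[1] + logdets.sum()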