Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
fix: LKJ example with Turing 0.29.1+ (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
storopoli authored Oct 3, 2023
1 parent 71aec5e commit 039766b
Showing 1 changed file with 46 additions and 71 deletions.
117 changes: 46 additions & 71 deletions turing/12-hierarchical_varying_intercept_slope-cheese.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# TODO: This model currently does not sample.
# It needs either https://github.com/TuringLang/Turing.jl/issues/1629
# or https://github.com/TuringLang/Bijectors.jl/issues/134
# or https://github.com/JuliaStats/PDMats.jl/issues/132
# Currently it uses a workaround with TransformVariables.jl and PDMats.jl thanks to @sethaxen
using Turing
using CSV
using DataFrames
Expand Down Expand Up @@ -42,45 +37,12 @@ y = standardize(ZScoreTransform, y; dims=1)
idx = cheese[:, :background_int]

# define the model
# @model function correlated_varying_intercept_slope_regression(X, idx, y;
# predictors=size(X, 2),
# N=size(X, 1),
# n_gr=length(unique(idx)))
# # priors
# Ω ~ LKJCholesky(predictors, 2.0)
# σ ~ Exponential(1)

# # prior for variance of random intercepts and slopes
# # usually requires thoughtful specification
# τ ~ filldist(truncated(Cauchy(0, 2); lower=0), predictors) # group-level SDs
# γ ~ filldist(Normal(0, 5), predictors, n_gr) # matrix of group coefficients
# Z ~ filldist(Normal(0, 1), predictors, n_gr) # matrix of non-centered group coefficients

# # reconstruct β from Ω and τ
# β = γ + τ .* Ω.L * Z

# # likelihood
# for i in 1:N
# y[i] ~ Normal(X[i, :] ⋅ β[:, idx[i]], σ)
# end
# return(; y, β, σ, Ω, τ, γ, Z)
# end

# workaround with TransformVariables.jl and PDMats.jl
using TransformVariables
using PDMats

@model function correlated_varying_intercept_slope_regression(X, idx, y;
predictors=size(X, 2),
N=size(X, 1),
n_gr=length(unique(idx)))
# workaround
trans = CorrCholeskyFactor(predictors)
L_tilde ~ filldist(Flat(), dimension(trans))
L_U, logdetJ = transform_and_logjac(trans, L_tilde)
Turing.@addlogprob! logpdf(LKJCholesky(predictors, 2.0), Cholesky(L_U)) + logdetJ

# priors
Ω ~ LKJCholesky(predictors, 2.0)
σ ~ Exponential(1)

# prior for variance of random intercepts and slopes
Expand All @@ -90,9 +52,7 @@ using PDMats
Z ~ filldist(Normal(0, 1), predictors, n_gr) # matrix of non-centered group coefficients

# reconstruct β from Ω and τ
Ω_L = LowerTriangular(collect.* L_U')) # collect is necessary for ReverseDiff for some reason
Ω = PDMat(Cholesky(Ω_L))
β = γ + Ω * Z
β = γ + τ .* Ω.L * Z

# likelihood
for i in 1:N
Expand All @@ -108,32 +68,47 @@ model = correlated_varying_intercept_slope_regression(X, idx, y)
chn = sample(model, NUTS(1_000, 0.8), MCMCThreads(), 1_000, 4)

# results:
# parameters mean std naive_se mcse ess rhat ess_per_sec
# Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
# L_tilde[1] 0.0134 0.8729 0.0138 0.0219 1978.9056 1.0014 1.8512
# L_tilde[2] -0.0077 0.8509 0.0135 0.0178 2059.1537 1.0008 1.9263
# L_tilde[3] 0.0763 1.0059 0.0159 0.0255 1659.5684 1.0005 1.5525
# L_tilde[4] 0.0210 0.9055 0.0143 0.0275 1040.8284 1.0023 0.9737
# L_tilde[5] -0.0190 0.9752 0.0154 0.0228 1674.3909 1.0004 1.5664
# L_tilde[6] 0.0632 1.1176 0.0177 0.0291 1439.6238 1.0039 1.3468
# σ 0.6084 0.0350 0.0006 0.0009 1801.6933 1.0029 1.6855
# τ[1] 0.9784 0.7305 0.0116 0.0272 579.6022 1.0057 0.5422
# τ[2] 0.8940 0.6539 0.0103 0.0184 1150.8453 1.0074 1.0766
# τ[3] 0.9722 0.7015 0.0111 0.0284 590.7927 1.0149 0.5527
# τ[4] 0.9602 0.7077 0.0112 0.0260 736.0099 1.0144 0.6885
# γ[1,1] 0.6740 1.4735 0.0233 0.0536 706.3475 1.0062 0.6608
# γ[2,1] -0.9898 1.2884 0.0204 0.0452 723.9490 1.0027 0.6772
# γ[3,1] 1.0575 1.5200 0.0240 0.0749 381.3226 1.0023 0.3567
# γ[4,1] 0.2115 1.5138 0.0239 0.0522 654.1878 1.0014 0.6120
# γ[1,2] -0.0924 1.4938 0.0236 0.0603 586.5729 1.0047 0.5487
# γ[2,2] -1.3776 1.2600 0.0199 0.0451 694.1231 1.0086 0.6493
# γ[3,2] 0.4357 1.4541 0.0230 0.0639 425.2658 1.0057 0.3978
# γ[4,2] -0.1971 1.3396 0.0212 0.0362 656.2398 1.0019 0.6139
# Z[1,1] 0.0057 0.9390 0.0148 0.0251 1210.1329 1.0013 1.1321
# Z[2,1] 0.0383 0.9775 0.0155 0.0328 916.9673 1.0009 0.8578
# Z[3,1] 0.0385 0.9365 0.0148 0.0264 1536.4146 0.9996 1.4373
# Z[4,1] 0.0779 0.9458 0.0150 0.0232 1283.2695 1.0020 1.2005
# Z[1,2] 0.0255 0.9185 0.0145 0.0264 1241.2293 1.0004 1.1612
# Z[2,2] -0.0534 0.9356 0.0148 0.0289 1346.0092 1.0020 1.2592
# Z[3,2] -0.0185 0.9483 0.0150 0.0266 1271.8442 1.0008 1.1898
# Z[4,2] 0.0281 0.9088 0.0144 0.0182 1510.2904 1.0004 1.4129
# Summary Statistics
# parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
# Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
# Ω.L[1,1] 1.0000 0.0000 NaN NaN NaN NaN NaN
# Ω.L[2,1] 0.0077 0.3529 0.0055 3960.8548 2002.1197 1.0017 5.1527
# Ω.L[3,1] -0.0030 0.3525 0.0073 2318.7503 2318.8227 1.0047 3.0165
# Ω.L[4,1] 0.0057 0.3630 0.0062 3418.7403 2347.2750 1.0004 4.4474
# Ω.L[5,1] 0.0033 0.3536 0.0056 4052.7561 2195.1157 1.0021 5.2722
# Ω.L[2,2] 0.9318 0.0848 0.0022 1140.1517 503.9699 1.0019 1.4832
# Ω.L[3,2] -0.0051 0.3544 0.0061 3350.9388 2521.8811 1.0010 4.3592
# Ω.L[4,2] 0.0069 0.3598 0.0064 3169.0739 2424.1733 1.0013 4.1226
# Ω.L[5,2] -0.0015 0.3436 0.0058 3500.7859 2434.4279 1.0007 4.5542
# Ω.L[3,3] 0.8577 0.1206 0.0031 1675.2394 2014.9484 1.0014 2.1793
# Ω.L[4,3] -0.0044 0.3626 0.0092 1342.8179 324.2139 1.0049 1.7469
# Ω.L[5,3] 0.0153 0.3593 0.0093 1308.5168 377.1732 1.0046 1.7022
# Ω.L[4,4] 0.7640 0.1534 0.0042 1427.1434 1922.1280 1.0024 1.8566
# Ω.L[5,4] 0.0076 0.3580 0.0058 3683.6774 2416.5623 1.0029 4.7921
# Ω.L[5,5] 0.6856 0.1714 0.0044 1611.1157 2222.5532 1.0035 2.0959
# σ 0.6100 0.0353 0.0006 3299.5373 2580.0330 1.0005 4.2924
# τ[1] 1.6321 1.5013 0.0324 1325.2906 921.5658 1.0030 1.7241
# τ[2] 2.0360 2.1504 0.0499 2028.7166 1857.7307 1.0008 2.6392
# τ[3] 2.0166 2.0165 0.0490 1813.5635 1616.3720 1.0006 2.3593
# τ[4] 2.1234 2.5838 0.2042 761.5893 328.3545 1.0060 0.9908
# τ[5] 1.9468 1.9703 0.0457 1507.1912 1439.6925 1.0021 1.9607
# γ[1,1] 0.3104 2.5783 0.0773 1114.6309 1516.6261 1.0006 1.4500
# γ[2,1] 0.3607 2.7544 0.0747 1360.3213 1571.1077 1.0013 1.7696
# γ[3,1] -1.0665 2.8623 0.0814 1239.2979 1593.4171 1.0019 1.6122
# γ[4,1] 0.8003 2.7856 0.1029 791.0495 344.1204 1.0032 1.0291
# γ[5,1] 0.1438 2.8079 0.0955 909.5711 1596.1574 1.0048 1.1833
# γ[1,2] -0.4093 2.7175 0.0822 1103.6350 557.1815 1.0028 1.4357
# γ[2,2] 0.3843 2.8844 0.0805 1280.2732 1453.5727 1.0035 1.6655
# γ[3,2] -0.9120 2.7927 0.0766 1336.0124 1892.9783 1.0058 1.7380
# γ[4,2] 0.6153 2.7966 0.0780 1287.4641 1591.0980 1.0036 1.6749
# γ[5,2] 0.3028 2.8947 0.0821 1238.1865 1650.1950 1.0028 1.6108
# Z[1,1] -0.0134 0.8883 0.0179 2460.2676 2453.7872 1.0020 3.2006
# Z[2,1] 0.0463 0.9159 0.0181 2563.1459 2738.2363 1.0059 3.3344
# Z[3,1] -0.0250 0.9320 0.0188 2455.9229 2419.3012 1.0008 3.1949
# Z[4,1] 0.0409 0.9215 0.0186 2465.3055 2295.0078 1.0009 3.2071
# Z[5,1] -0.0045 0.9716 0.0174 3107.9059 2600.5361 1.0013 4.0431
# Z[1,2] 0.0096 0.9085 0.0203 2002.6466 2463.5179 1.0022 2.6052
# Z[2,2] 0.0172 0.9168 0.0179 2621.0574 2284.6429 1.0022 3.4097
# Z[3,2] -0.0564 0.9068 0.0171 2829.5749 2393.1926 1.0006 3.6810
# Z[4,2] 0.0241 0.9430 0.0185 2591.3177 2463.7692 1.0009 3.3710
# Z[5,2] 0.0144 0.9475 0.0169 3169.3612 2689.0769 1.0012 4.1230

0 comments on commit 039766b

Please sign in to comment.