diff --git a/turing/12-hierarchical_varying_intercept_slope-cheese.jl b/turing/12-hierarchical_varying_intercept_slope-cheese.jl index 4efbb0d..c3573db 100644 --- a/turing/12-hierarchical_varying_intercept_slope-cheese.jl +++ b/turing/12-hierarchical_varying_intercept_slope-cheese.jl @@ -1,8 +1,3 @@ -# TODO: This model currently does not sample. -# It needs either https://github.com/TuringLang/Turing.jl/issues/1629 -# or https://github.com/TuringLang/Bijectors.jl/issues/134 -# or https://github.com/JuliaStats/PDMats.jl/issues/132 -# Currently it uses a workaround with TransformVariables.jl and PDMats.jl thanks to @sethaxen using Turing using CSV using DataFrames @@ -42,45 +37,12 @@ y = standardize(ZScoreTransform, y; dims=1) idx = cheese[:, :background_int] # define the model -# @model function correlated_varying_intercept_slope_regression(X, idx, y; -# predictors=size(X, 2), -# N=size(X, 1), -# n_gr=length(unique(idx))) -# # priors -# Ω ~ LKJCholesky(predictors, 2.0) -# σ ~ Exponential(1) - -# # prior for variance of random intercepts and slopes -# # usually requires thoughtful specification -# τ ~ filldist(truncated(Cauchy(0, 2); lower=0), predictors) # group-level SDs -# γ ~ filldist(Normal(0, 5), predictors, n_gr) # matrix of group coefficients -# Z ~ filldist(Normal(0, 1), predictors, n_gr) # matrix of non-centered group coefficients - -# # reconstruct β from Ω and τ -# β = γ + τ .* Ω.L * Z - -# # likelihood -# for i in 1:N -# y[i] ~ Normal(X[i, :] ⋅ β[:, idx[i]], σ) -# end -# return(; y, β, σ, Ω, τ, γ, Z) -# end - -# workaround with TransformVariables.jl and PDMats.jl -using TransformVariables -using PDMats - @model function correlated_varying_intercept_slope_regression(X, idx, y; predictors=size(X, 2), N=size(X, 1), n_gr=length(unique(idx))) - # workaround - trans = CorrCholeskyFactor(predictors) - L_tilde ~ filldist(Flat(), dimension(trans)) - L_U, logdetJ = transform_and_logjac(trans, L_tilde) - Turing.@addlogprob! logpdf(LKJCholesky(predictors, 2.0), Cholesky(L_U)) + logdetJ - # priors + Ω ~ LKJCholesky(predictors, 2.0) σ ~ Exponential(1) # prior for variance of random intercepts and slopes @@ -90,9 +52,7 @@ using PDMats Z ~ filldist(Normal(0, 1), predictors, n_gr) # matrix of non-centered group coefficients # reconstruct β from Ω and τ - Ω_L = LowerTriangular(collect(τ .* L_U')) # collect is necessary for ReverseDiff for some reason - Ω = PDMat(Cholesky(Ω_L)) - β = γ + Ω * Z + β = γ + τ .* Ω.L * Z # likelihood for i in 1:N @@ -108,32 +68,47 @@ model = correlated_varying_intercept_slope_regression(X, idx, y) chn = sample(model, NUTS(1_000, 0.8), MCMCThreads(), 1_000, 4) # results: -# parameters mean std naive_se mcse ess rhat ess_per_sec -# Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64 -# L_tilde[1] 0.0134 0.8729 0.0138 0.0219 1978.9056 1.0014 1.8512 -# L_tilde[2] -0.0077 0.8509 0.0135 0.0178 2059.1537 1.0008 1.9263 -# L_tilde[3] 0.0763 1.0059 0.0159 0.0255 1659.5684 1.0005 1.5525 -# L_tilde[4] 0.0210 0.9055 0.0143 0.0275 1040.8284 1.0023 0.9737 -# L_tilde[5] -0.0190 0.9752 0.0154 0.0228 1674.3909 1.0004 1.5664 -# L_tilde[6] 0.0632 1.1176 0.0177 0.0291 1439.6238 1.0039 1.3468 -# σ 0.6084 0.0350 0.0006 0.0009 1801.6933 1.0029 1.6855 -# τ[1] 0.9784 0.7305 0.0116 0.0272 579.6022 1.0057 0.5422 -# τ[2] 0.8940 0.6539 0.0103 0.0184 1150.8453 1.0074 1.0766 -# τ[3] 0.9722 0.7015 0.0111 0.0284 590.7927 1.0149 0.5527 -# τ[4] 0.9602 0.7077 0.0112 0.0260 736.0099 1.0144 0.6885 -# γ[1,1] 0.6740 1.4735 0.0233 0.0536 706.3475 1.0062 0.6608 -# γ[2,1] -0.9898 1.2884 0.0204 0.0452 723.9490 1.0027 0.6772 -# γ[3,1] 1.0575 1.5200 0.0240 0.0749 381.3226 1.0023 0.3567 -# γ[4,1] 0.2115 1.5138 0.0239 0.0522 654.1878 1.0014 0.6120 -# γ[1,2] -0.0924 1.4938 0.0236 0.0603 586.5729 1.0047 0.5487 -# γ[2,2] -1.3776 1.2600 0.0199 0.0451 694.1231 1.0086 0.6493 -# γ[3,2] 0.4357 1.4541 0.0230 0.0639 425.2658 1.0057 0.3978 -# γ[4,2] -0.1971 1.3396 0.0212 0.0362 656.2398 1.0019 0.6139 -# Z[1,1] 0.0057 0.9390 0.0148 0.0251 1210.1329 1.0013 1.1321 -# Z[2,1] 0.0383 0.9775 0.0155 0.0328 916.9673 1.0009 0.8578 -# Z[3,1] 0.0385 0.9365 0.0148 0.0264 1536.4146 0.9996 1.4373 -# Z[4,1] 0.0779 0.9458 0.0150 0.0232 1283.2695 1.0020 1.2005 -# Z[1,2] 0.0255 0.9185 0.0145 0.0264 1241.2293 1.0004 1.1612 -# Z[2,2] -0.0534 0.9356 0.0148 0.0289 1346.0092 1.0020 1.2592 -# Z[3,2] -0.0185 0.9483 0.0150 0.0266 1271.8442 1.0008 1.1898 -# Z[4,2] 0.0281 0.9088 0.0144 0.0182 1510.2904 1.0004 1.4129 \ No newline at end of file +# Summary Statistics +# parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec +# Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64 +# Ω.L[1,1] 1.0000 0.0000 NaN NaN NaN NaN NaN +# Ω.L[2,1] 0.0077 0.3529 0.0055 3960.8548 2002.1197 1.0017 5.1527 +# Ω.L[3,1] -0.0030 0.3525 0.0073 2318.7503 2318.8227 1.0047 3.0165 +# Ω.L[4,1] 0.0057 0.3630 0.0062 3418.7403 2347.2750 1.0004 4.4474 +# Ω.L[5,1] 0.0033 0.3536 0.0056 4052.7561 2195.1157 1.0021 5.2722 +# Ω.L[2,2] 0.9318 0.0848 0.0022 1140.1517 503.9699 1.0019 1.4832 +# Ω.L[3,2] -0.0051 0.3544 0.0061 3350.9388 2521.8811 1.0010 4.3592 +# Ω.L[4,2] 0.0069 0.3598 0.0064 3169.0739 2424.1733 1.0013 4.1226 +# Ω.L[5,2] -0.0015 0.3436 0.0058 3500.7859 2434.4279 1.0007 4.5542 +# Ω.L[3,3] 0.8577 0.1206 0.0031 1675.2394 2014.9484 1.0014 2.1793 +# Ω.L[4,3] -0.0044 0.3626 0.0092 1342.8179 324.2139 1.0049 1.7469 +# Ω.L[5,3] 0.0153 0.3593 0.0093 1308.5168 377.1732 1.0046 1.7022 +# Ω.L[4,4] 0.7640 0.1534 0.0042 1427.1434 1922.1280 1.0024 1.8566 +# Ω.L[5,4] 0.0076 0.3580 0.0058 3683.6774 2416.5623 1.0029 4.7921 +# Ω.L[5,5] 0.6856 0.1714 0.0044 1611.1157 2222.5532 1.0035 2.0959 +# σ 0.6100 0.0353 0.0006 3299.5373 2580.0330 1.0005 4.2924 +# τ[1] 1.6321 1.5013 0.0324 1325.2906 921.5658 1.0030 1.7241 +# τ[2] 2.0360 2.1504 0.0499 2028.7166 1857.7307 1.0008 2.6392 +# τ[3] 2.0166 2.0165 0.0490 1813.5635 1616.3720 1.0006 2.3593 +# τ[4] 2.1234 2.5838 0.2042 761.5893 328.3545 1.0060 0.9908 +# τ[5] 1.9468 1.9703 0.0457 1507.1912 1439.6925 1.0021 1.9607 +# γ[1,1] 0.3104 2.5783 0.0773 1114.6309 1516.6261 1.0006 1.4500 +# γ[2,1] 0.3607 2.7544 0.0747 1360.3213 1571.1077 1.0013 1.7696 +# γ[3,1] -1.0665 2.8623 0.0814 1239.2979 1593.4171 1.0019 1.6122 +# γ[4,1] 0.8003 2.7856 0.1029 791.0495 344.1204 1.0032 1.0291 +# γ[5,1] 0.1438 2.8079 0.0955 909.5711 1596.1574 1.0048 1.1833 +# γ[1,2] -0.4093 2.7175 0.0822 1103.6350 557.1815 1.0028 1.4357 +# γ[2,2] 0.3843 2.8844 0.0805 1280.2732 1453.5727 1.0035 1.6655 +# γ[3,2] -0.9120 2.7927 0.0766 1336.0124 1892.9783 1.0058 1.7380 +# γ[4,2] 0.6153 2.7966 0.0780 1287.4641 1591.0980 1.0036 1.6749 +# γ[5,2] 0.3028 2.8947 0.0821 1238.1865 1650.1950 1.0028 1.6108 +# Z[1,1] -0.0134 0.8883 0.0179 2460.2676 2453.7872 1.0020 3.2006 +# Z[2,1] 0.0463 0.9159 0.0181 2563.1459 2738.2363 1.0059 3.3344 +# Z[3,1] -0.0250 0.9320 0.0188 2455.9229 2419.3012 1.0008 3.1949 +# Z[4,1] 0.0409 0.9215 0.0186 2465.3055 2295.0078 1.0009 3.2071 +# Z[5,1] -0.0045 0.9716 0.0174 3107.9059 2600.5361 1.0013 4.0431 +# Z[1,2] 0.0096 0.9085 0.0203 2002.6466 2463.5179 1.0022 2.6052 +# Z[2,2] 0.0172 0.9168 0.0179 2621.0574 2284.6429 1.0022 3.4097 +# Z[3,2] -0.0564 0.9068 0.0171 2829.5749 2393.1926 1.0006 3.6810 +# Z[4,2] 0.0241 0.9430 0.0185 2591.3177 2463.7692 1.0009 3.3710 +# Z[5,2] 0.0144 0.9475 0.0169 3169.3612 2689.0769 1.0012 4.1230