From c79fba87d7879838c8cb60fccfee8b6a797c6385 Mon Sep 17 00:00:00 2001
From: Jakob Heinzle
Date: Thu, 2 Sep 2021 14:00:39 +0000
Subject: [PATCH] Bugfix waic202108

---
 tools/tapas_logsumexp.m | 42 +++++++++++++++++++++++++++++++++++++++++
 tools/tapas_waic.m      | 20 ++++++++++++++------
 2 files changed, 56 insertions(+), 6 deletions(-)
 create mode 100644 tools/tapas_logsumexp.m

diff --git a/tools/tapas_logsumexp.m b/tools/tapas_logsumexp.m
new file mode 100644
index 00000000..1c0a4204
--- /dev/null
+++ b/tools/tapas_logsumexp.m
@@ -0,0 +1,42 @@
+function [y_sum,y_mean] = tapas_logsumexp(x)
+%% -------------------------------------------------------------------------------------------
+% [y_sum,y_mean] = tapas_logsumexp(x) exponentiates the values in x, sums
+% them down each column, and finally applies the natural logarithm.
+% The calculation uses the "log-sum-exp" trick: See e.g. http://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
+% The function also returns the log-mean-exp.
+%---------------------------------------------------------------------------------------------
+% INPUT:
+%   x      - A column vector or matrix of values. All computations are made
+%            along the direction of a column (dimension 1), even if x has one row.
+%
+%--------------------------------------------------------------------------------------------
+% OUTPUT:
+%   y_sum  - Row vector holding the log-sum-exp of each column of x.
+%   y_mean - Row vector holding the log-mean-exp of each column of x.
+%
+% Author: Jakob Heinzle, TNU, UZH & ETHZ - April, 2021
+%
+% REVISION LOG:
+%
+% Jakob Heinzle, 2021/04/16: new function
+%
+%%
+
+sz = size(x);
+
+if numel(sz)~=2
+    error('Input x needs to be a matrix of 2 dimensions');
+end
+
+% Shift each column by its maximum so that exp() cannot overflow. The
+% dimension argument is explicit because max(x) and sum(x) would reduce
+% along the row instead whenever x has a single row.
+max_x = max(x, [], 1);
+% Leave columns with a non-finite maximum unshifted (Inf - Inf is NaN).
+max_x(~isfinite(max_x)) = 0;
+y_sum = max_x + log(sum(exp(x-ones(sz(1),1)*max_x), 1));
+
+if nargout==2
+    y_mean = y_sum-log(sz(1)); % log-mean-exp: subtract log(number of rows)
+end
+return;
\ No newline at end of file
diff --git a/tools/tapas_waic.m b/tools/tapas_waic.m
index eb70d21c..af32f41b 100644
--- a/tools/tapas_waic.m
+++ b/tools/tapas_waic.m
@@ -1,4 +1,4 @@
-function [waic, accuracy] = tapas_waic(llh)
+function [waic, lppd] = tapas_waic(llh)
 %% Computes the Watanabe-Akaike Information criterion
 %
 % Input
@@ -7,10 +7,10 @@
 % the number of samples.
 %
 % Output
-% waic -- The Watanaba AIC usin the methods in [1].
-% accuracy -- The expected log likelihood of the model.
+% waic -- The Watanabe AIC using the methods in [1].
+% lppd -- The log pointwise predictive density of the model.
 %
-% The WAIC can be computed as the accuracy - penalization, where the
+% The WAIC can be computed as the lppd - penalization, where the
 % penalization is the sum of the variance of the log likelihood, i.e., the
 % gradient of the Free energy.
 %
@@ -21,13 +21,21 @@
 % aponteeduardo@gmail.com
 % copyright (C) 2019
 %
+% REVISION LOG:
+%
+% Jakob Heinzle, 2021/04/16: changed computation of accuracy to be exactly log
+% pointwise predictive density as in Gelman et al.
+%
+%%
+
+s = size(llh,2);
 
 % Estimator of the accuracy
-accuracy = sum(mean(llh, 2));
+lppd = sum(tapas_logsumexp(llh')-log(s));
 
 % Estimator of the variance
 penalization = sum(var(llh, [], 2));
 
-waic = accuracy - penalization;
+waic = lppd - penalization;
 
 end