Skip to content

Commit

Permalink
Merge pull request #208 from drbenvincent/dev
Browse files Browse the repository at this point in the history
WAIC model comparison
  • Loading branch information
drbenvincent authored Mar 7, 2018
2 parents 79b320f + 752ad73 commit fade821
Show file tree
Hide file tree
Showing 41 changed files with 532 additions and 176 deletions.
138 changes: 138 additions & 0 deletions ddToolbox/WAIC.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
classdef WAIC
	%WAIC Widely Applicable Information Criterion for one fitted model
	%   w = WAIC(log_lik) computes WAIC statistics from a matrix of
	%   pointwise log-likelihood values, where rows are observations
	%   (cases) and columns are MCMC samples.
	%
	%   The object stores:
	%     lppd                - log pointwise predictive density (summed over cases)
	%     pWAIC               - effective number of parameters (summed over cases)
	%     WAIC_value          - -2*lppd + 2*pWAIC (deviance scale)
	%     WAIC_standard_error - sqrt(nCases) * std of the pointwise WAIC values
	%     nSamples, nCases    - size of the log_lik matrix
	%
	%   An array of WAIC objects (one per model, all computed on the same
	%   observations) can be compared with compare() and visualised with
	%   plot().
	%
	% References
	% Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A.,
	% & Rubin, D. B. (2013). Bayesian Data Analysis, Third Edition.
	% CRC Press.
	properties (SetAccess = protected)
		lppd, pWAIC, WAIC_value, WAIC_standard_error
		nSamples, nCases
	end

	properties (Hidden = true)
		log_lik                        % [nCases x nSamples] pointwise log-likelihoods
		lppd_vec, pWAIC_vec, WAIC_vec  % per-observation contributions (column vectors)
		modelName                      % label used in comparison tables/plots; set by the caller
	end

	methods

		function obj = WAIC(log_lik)
			%WAIC Construct a WAIC object from pointwise log-likelihoods.
			%   log_lik : [nCases x nSamples] matrix; rows are observations,
			%             columns are MCMC samples.

			[obj.nCases, obj.nSamples] = size(log_lik);
			obj.log_lik = log_lik;

			% Calculate lppd
			% Equation 7.5 from Gelman et al (2013)
			% Computed as a log-mean-exp with the row maximum factored out.
			% The naive log(mean(exp(x))) underflows to log(0) = -Inf as soon
			% as the log-likelihoods are very negative (e.g. below about -745).
			rowMax = max(obj.log_lik, [], 2);
			% If a row maximum is not finite (all -Inf, +Inf or NaN), shifting
			% by it would create Inf-Inf = NaN; use no shift so the result
			% matches the naive formula in those degenerate cases.
			rowMax(~isfinite(rowMax)) = 0;
			shifted = bsxfun(@minus, obj.log_lik, rowMax);
			obj.lppd_vec = rowMax + log( mean( exp(shifted), 2) );
			obj.lppd = sum(obj.lppd_vec);

			% Calculate effective number of parameters, pWAIC
			% Equation 7.12 from Gelman et al (2013)
			% (sample variance of the log-likelihood across MCMC samples)
			obj.pWAIC_vec = var(obj.log_lik, 0, 2);
			obj.pWAIC = sum(obj.pWAIC_vec);

			% Calculate WAIC (deviance scale)
			obj.WAIC_value = -2 * obj.lppd + 2 * obj.pWAIC;

			% Calculate WAIC standard error from the pointwise WAIC values
			obj.WAIC_vec = -2 * obj.lppd_vec + 2 * obj.pWAIC_vec;
			obj.WAIC_standard_error = sqrt(obj.nCases) * std(obj.WAIC_vec);

		end

		function comparisonTable = compare(obj)
			%COMPARE Compare WAIC info from multiple models.
			%   comparisonTable = compare(obj), where obj is an array of more
			%   than one WAIC object, returns a table with one row per model,
			%   sorted so the lowest (best) WAIC is at the top. Columns:
			%     model  - model name
			%     WAIC   - WAIC value
			%     pWAIC  - effective number of parameters
			%     dWAIC  - WAIC minus the lowest WAIC
			%     weight - exp(-dWAIC/2), normalised to sum to 1
			%     SE     - standard error of WAIC
			%     dSE    - standard error of the WAIC difference to the best
			%              model (NaN for the best model itself)
			%     lppd   - log pointwise predictive density
			assert(numel(obj)>1, 'expecting an array of >1 WAIC object')
			% Pointwise differences are only meaningful (and only well
			% defined) when every model was evaluated on the same cases.
			assert(numel(unique([obj.nCases])) == 1, ...
				'all WAIC objects must have the same number of cases')

			% Build a table of values. NOTE: these local variable names
			% become the column names of the returned table.
			model = {obj.modelName}';
			WAIC = [obj.WAIC_value]';
			pWAIC = [obj.pWAIC]';
			lppd = [obj.lppd]';
			SE = [obj.WAIC_standard_error]';
			dWAIC = WAIC - min(WAIC);
			weight = exp(-0.5.*dWAIC) ./ sum(exp(-0.5.*dWAIC));

			% dSE is the SE of the difference in WAIC (not SE!) between
			% each model and the top ranked model. It stays NaN for the
			% top ranked model itself.
			[~, i_best_model] = min(WAIC);
			dSE = nan(numel(obj), 1);
			for m = 1:numel(obj)
				if m ~= i_best_model
					% Calculate SE of difference (of WAIC values) between
					% model m and i_best_model
					WAIC_diff = obj(i_best_model).WAIC_vec - obj(m).WAIC_vec;
					dSE(m,1) = sqrt(obj(m).nCases)*std(WAIC_diff);
				end
			end

			% create table
			comparisonTable = table(model, WAIC, pWAIC, dWAIC, weight, SE, dSE, lppd);
			% sort so best models (lowest WAIC) values are at top of table
			comparisonTable = sortrows(comparisonTable,{'WAIC'},{'ascend'});
		end

		function plot(obj)
			%PLOT Produce a WAIC comparison plot for an array of WAIC objects.
			%   Clears the current figure and draws one row per model (best
			%   model at the top) showing in-sample deviance, WAIC +/- SE and,
			%   for all but the best model, WAIC +/- SE of the difference to
			%   the best model.

			comparisonTable = obj.compare();

			% define y-value positions for each model
			y = 1:size(comparisonTable,1);

			ms = 6; % marker size shared by all series
			clf
			hold on

			% in-sample deviance as solid circles
			in_sample_deviance = -2*comparisonTable.lppd;
			isd = plot(in_sample_deviance, y, 'ko',...
				'MarkerFaceColor','k',...
				'MarkerSize', ms);

			% WAIC as empty circles, with SE errorbars
			waic_eb = errorbar(comparisonTable.WAIC,y,comparisonTable.SE,...
				'horizontal',...
				'o',...
				'LineStyle', 'none',...
				'Color', 'k',...
				'MarkerFaceColor','w',...
				'MarkerSize', ms);

			% plot dSE for every model except the best one (row 1 after
			% sorting), drawn slightly offset from that model's WAIC marker
			waic_diff = errorbar(comparisonTable.dWAIC(2:end)+min(comparisonTable.WAIC),...
				y(2:end)-0.2, comparisonTable.dSE(2:end),...
				'horizontal',...
				'^',...
				'LineStyle', 'none',...
				'Color', [0.5 0.5 0.5],...
				'MarkerFaceColor','w',...
				'MarkerSize', ms);

			% formatting
			xlabel('deviance');
			set(gca,...
				'YTick', y,...
				'YTickLabel', comparisonTable.model,...
				'YDir','reverse');
			ylim([min(y)-1, max(y)+1]);

			% vertical reference line at the best (lowest) WAIC;
			% vline is a helper defined outside this file
			vline(min(comparisonTable.WAIC), 'Color',[0.5 0.5 0.5]);

			legend([isd, waic_eb, waic_diff],...
				{'in-sample deviance', 'WAIC (+/- SE)', 'SE of WAIC difference (+/- SE)'},...
				'location', 'eastoutside');

			title('WAIC Model Comparison')

		end

	end

end
17 changes: 16 additions & 1 deletion ddToolbox/models/Model.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
coda % handle to coda object
data % handle to Data class
end

% Publicly readable results; only settable by this class and its subclasses
properties (Hidden = false, SetAccess = protected, GetAccess = public)
WAIC_stats % WAIC object built from the model's 'log_lik' MCMC samples
end

%% Private properties
properties (SetAccess = protected, GetAccess = protected)
Expand Down Expand Up @@ -255,7 +259,18 @@ function export(obj)
%% Calculate AUC
MAX_DELAY = 365;
obj.auc = obj.calcAreaUnderCurveForAll(MAX_DELAY);


%% Calculate WAIC
% first, prepare a table of log likelihood values. Rows are
% observations, columns are MCMC samples.
samples = obj.coda.getSamples({'log_lik'});
% collapse over chains
[chains, samples_per_chain, N] = size(samples.log_lik);
log_lik = reshape(samples.log_lik, chains*samples_per_chain, N)';
% second, create WAIC object
obj.WAIC_stats = WAIC(log_lik);

obj.WAIC_stats.modelName = class(obj);
end

function auc = calcAreaUnderCurveForAll(obj, MAX_DELAY)
Expand Down
Loading

0 comments on commit fade821

Please sign in to comment.