
Commit

misc refactorings
Ben Vincent committed Feb 18, 2017
1 parent 8f6da7d commit d2fe438
Showing 6 changed files with 120 additions and 104 deletions.
6 changes: 0 additions & 6 deletions ddToolbox/CODA/CODA.m
@@ -173,12 +173,6 @@ function densityplot(obj, targetAxisHandle, samples)
if ~isempty(targetAxisHandle)
subplot(targetAxisHandle)
end


% % using my plot tools package
% mcmc.UnivariateDistribution(samples',...
% 'plotStyle','hist',...
% 'plotHDI',false);

univariateObject = Stochastic('name_here');
univariateObject.addSamples(samples);
184 changes: 100 additions & 84 deletions ddToolbox/CODA/computeStats.m
@@ -7,105 +7,121 @@
disp('CODA: Calculating statistics')
assert(isstruct(all_samples))

variable_names = fieldnames(all_samples);
variable_list = asRowVector(fieldnames(all_samples));

stats = struct('Rhat',[], 'mean', [], 'median', [], 'std', [],...
'ci_low' , [] , 'ci_high' , [],...
'hdi_low', [] , 'hdi_high' , []);

for v=1:length(variable_names)
var_name = variable_names{v};
var_samples = all_samples.(var_name);

sz = size(var_samples);
Nchains = sz(1);
Nsamples = sz(2);
dims = ndims(var_samples);

% Calculate stats
switch dims
case{2} % scalar
stats.mode.(var_name) = calcMode( var_samples(:) );
stats.median.(var_name) = median( var_samples(:) );
stats.mean.(var_name) = mean( var_samples(:) );
stats.std.(var_name) = std( var_samples(:) );
[stats.ci_low.(var_name), stats.ci_high.(var_name)] = calcCI(var_samples(:));
[stats.hdi_low.(var_name), stats.hdi_high.(var_name)] = calcHDI(var_samples(:));

case{3} % vector
for n=1:sz(3)
stats.mode.(var_name)(n) = calcMode( vec(var_samples(:,:,n)) );
stats.median.(var_name)(n) = median( vec(var_samples(:,:,n)) );
stats.mean.(var_name)(n) = mean( vec(var_samples(:,:,n)) );
stats.std.(var_name)(n) = std( vec(var_samples(:,:,n)) );
[stats.ci_low.(var_name)(n),...
stats.ci_high.(var_name)(n)] = calcCI( vec(var_samples(:,:,n)) );
[stats.hdi_low.(var_name)(n),...
stats.hdi_high.(var_name)(n)] = calcHDI( vec(var_samples(:,:,n)) );
end
case{4} % 2D matrix
for a=1:sz(3)
for b=1:sz(4)
stats.mode.(var_name)(a,b) = calcMode( vec(var_samples(:,:,a,b)) );
stats.median.(var_name)(a,b) = median( vec(var_samples(:,:,a,b)) );
stats.mean.(var_name)(a,b) = mean( vec(var_samples(:,:,a,b)) );
stats.std.(var_name)(a,b) = std( vec(var_samples(:,:,a,b)) );
[stats.ci_low.(var_name)(a,b),...
stats.ci_high.(var_name)(a,b)] = calcCI( vec(var_samples(:,:,a,b)) );
[stats.hdi_low.(var_name)(a,b),...
stats.hdi_high.(var_name(a,b))] = calcHDI( vec(var_samples(:,:,a,b)) );
end
end
for variable_name = variable_list

variable_samples = all_samples.(variable_name{:});

switch ndims(variable_samples)
case{2}
stats = calcStatsScalar(stats, variable_samples, variable_name{:});
case{3}
stats = calcStatsVector(stats, variable_samples, variable_name{:});
case{4}
stats = calcStatsMatrix(stats, variable_samples, variable_name{:});
otherwise
warning('calculation of stats not supported for >2D variables. You could implement it and send a pull request.')
stats.mode.(var_name) = [];
stats.median.(var_name) = [];
stats.mean.(var_name) = [];
stats.std.(var_name) = [];
stats.ci_low.(var_name) = [];
stats.ci_high.(var_name) = [];
stats.hdi_low.(var_name) = [];
stats.hdi_high.(var_name) = [];
stats = calcStatsTensor3(stats, variable_samples, variable_name{:});
end

%% "estimated potential scale reduction" statistics due to Gelman and Rubin.
Rhat = calcRhat();

Rhat = calcRhat(variable_samples);
if ~isnan(Rhat)
stats.Rhat.(var_name) = squeeze(Rhat);
stats.Rhat.(variable_name{:}) = squeeze(Rhat);
end

end

end
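
For orientation, the struct assembled in the loop above is keyed first by statistic and then by variable name, so downstream code reads point estimates as stats.mean.(name) and intervals as stats.hdi_low.(name) / stats.hdi_high.(name). A minimal usage sketch, assuming the file defines computeStats(all_samples) as the filename suggests and that the toolbox helpers it calls (calcMode, vec, HDIofSamples) are on the path; the variable name logk is purely illustrative:

% Hypothetical usage sketch -- 'logk' is an illustrative variable name
all_samples.logk = randn(4, 1000);                    % 4 chains x 1000 samples per chain
stats = computeStats(all_samples);
pointEstimate = stats.mean.logk;                      % scalar summary of the samples
hdi = [stats.hdi_low.logk, stats.hdi_high.logk];      % 95% highest density interval
rhat = stats.Rhat.logk;                               % Gelman-Rubin convergence diagnostic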


function Rhat = calcRhat()
st_mean_per_chain = mean(var_samples, 2);
st_mean_overall = mean(st_mean_per_chain, 1);
function stats = calcStatsScalar(stats, var_samples, var_name)
stats.mode.(var_name) = calcMode( var_samples(:) );
stats.median.(var_name) = median( var_samples(:) );
stats.mean.(var_name) = mean( var_samples(:) );
stats.std.(var_name) = std( var_samples(:) );
[stats.ci_low.(var_name), stats.ci_high.(var_name)] = calcCI(var_samples(:));
[stats.hdi_low.(var_name), stats.hdi_high.(var_name)] = calcHDI(var_samples(:));
end

if Nchains > 1
B = (Nsamples/Nchains-1) * ...
sum((st_mean_per_chain - repmat(st_mean_overall, [Nchains,1])).^2);
varPerChain = var(var_samples, 0, 2);
W = (1/Nchains) * sum(varPerChain);
vhat = ((Nsamples-1)/Nsamples) * W + (1/Nsamples) * B;
Rhat = sqrt(vhat./(W+eps));
else
Rhat = nan;
end
end
function stats = calcStatsVector(stats, var_samples, var_name)
sz = size(var_samples);
for n=1:sz(3)
stats.mode.(var_name)(n) = calcMode( vec(var_samples(:,:,n)) );
stats.median.(var_name)(n) = median( vec(var_samples(:,:,n)) );
stats.mean.(var_name)(n) = mean( vec(var_samples(:,:,n)) );
stats.std.(var_name)(n) = std( vec(var_samples(:,:,n)) );
[stats.ci_low.(var_name)(n),...
stats.ci_high.(var_name)(n)] = calcCI( vec(var_samples(:,:,n)) );
[stats.hdi_low.(var_name)(n),...
stats.hdi_high.(var_name)(n)] = calcHDI( vec(var_samples(:,:,n)) );
end
end

function [low, high] = calcCI(reshaped_samples)
% get the 95% interval of the posterior
ci_samples_overall = prctile( reshaped_samples , [ 2.5 97.5 ] , 1 );
ci_samples_overall_low = ci_samples_overall( 1,: );
ci_samples_overall_high = ci_samples_overall( 2,: );
low = squeeze(ci_samples_overall_low);
high = squeeze(ci_samples_overall_high);
function stats = calcStatsMatrix(stats, var_samples, var_name)
sz = size(var_samples);
for a=1:sz(3)
for b=1:sz(4)
stats.mode.(var_name)(a,b) = calcMode( vec(var_samples(:,:,a,b)) );
stats.median.(var_name)(a,b) = median( vec(var_samples(:,:,a,b)) );
stats.mean.(var_name)(a,b) = mean( vec(var_samples(:,:,a,b)) );
stats.std.(var_name)(a,b) = std( vec(var_samples(:,:,a,b)) );
[stats.ci_low.(var_name)(a,b),...
stats.ci_high.(var_name)(a,b)] = calcCI( vec(var_samples(:,:,a,b)) );
[stats.hdi_low.(var_name)(a,b),...
stats.hdi_high.(var_name)(a,b)] = calcHDI( vec(var_samples(:,:,a,b)) );
end
end
end

function [low, high] = calcHDI(reshaped_samples)
% get the 95% highest density intervals of the posterior
[hdi] = HDIofSamples(reshaped_samples, 0.95);
low = squeeze(hdi(1));
high = squeeze(hdi(2));
end
function stats = calcStatsTensor3(stats, var_samples, var_name)
warning('calculation of stats not supported for >2D matrices. You could implement it and send a pull request.')
stats.mode.(var_name) = [];
stats.median.(var_name) = [];
stats.mean.(var_name) = [];
stats.std.(var_name) = [];
stats.ci_low.(var_name) = [];
stats.ci_high.(var_name) = [];
stats.hdi_low.(var_name) = [];
stats.hdi_high.(var_name) = [];
end

function Rhat = calcRhat(var_samples)
% "estimated potential scale reduction" statistics due to Gelman and Rubin
sz = size(var_samples);
Nchains = sz(1);
Nsamples = sz(2);

st_mean_per_chain = mean(var_samples, 2);
st_mean_overall = mean(st_mean_per_chain, 1);

if Nchains > 1
B = (Nsamples/Nchains-1) * ...
sum((st_mean_per_chain - repmat(st_mean_overall, [Nchains,1])).^2);
varPerChain = var(var_samples, 0, 2);
W = (1/Nchains) * sum(varPerChain);
vhat = ((Nsamples-1)/Nsamples) * W + (1/Nsamples) * B;
Rhat = sqrt(vhat./(W+eps));
else
Rhat = nan;
end
end
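
Written out, with M = Nchains, N = Nsamples per chain, $\bar{\theta}_m$ the mean of chain m, $\bar{\theta}$ the grand mean of those chain means, $s_m^2$ the within-chain sample variance, and $\varepsilon$ the machine epsilon added to avoid division by zero, the statistic returned by the code above is

$$B = \left(\frac{N}{M} - 1\right)\sum_{m=1}^{M}\left(\bar{\theta}_m - \bar{\theta}\right)^2, \qquad W = \frac{1}{M}\sum_{m=1}^{M} s_m^2,$$

$$\hat{V} = \frac{N-1}{N}\,W + \frac{1}{N}\,B, \qquad \hat{R} = \sqrt{\frac{\hat{V}}{W + \varepsilon}}.$$

With a single chain there is no between-chain information, so NaN is returned and the caller skips storing an Rhat field for that variable.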

function [low, high] = calcCI(reshaped_samples)
% get the 95% interval of the posterior
ci_samples_overall = prctile( reshaped_samples , [ 2.5 97.5 ] , 1 );
ci_samples_overall_low = ci_samples_overall( 1,: );
ci_samples_overall_high = ci_samples_overall( 2,: );
low = squeeze(ci_samples_overall_low);
high = squeeze(ci_samples_overall_high);
end

function [low, high] = calcHDI(reshaped_samples)
% get the 95% highest density intervals of the posterior
[hdi] = HDIofSamples(reshaped_samples, 0.95);
low = squeeze(hdi(1));
high = squeeze(hdi(2));
end
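
The two interval helpers summarise the same flattened sample vector in different ways: calcCI takes the equal-tailed 2.5th and 97.5th percentiles, whereas calcHDI asks HDIofSamples for the 95% highest density interval, which for a unimodal posterior is the narrowest interval containing 95% of the samples. A quick illustration of the difference on a deliberately skewed sample (assumes the Statistics Toolbox for exprnd/prctile and HDIofSamples.m from this toolbox on the path):

% Skewed stand-in for a posterior: the two 95% summaries differ
samples = exprnd(1, 1e5, 1);            % exponential draws, heavily right-skewed
ci  = prctile(samples, [2.5 97.5]);     % equal-tailed credible interval
hdi = HDIofSamples(samples, 0.95);      % highest density interval
% The HDI starts nearer zero and is narrower than the equal-tailed interval.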
12 changes: 8 additions & 4 deletions ddToolbox/Data/Data.m
@@ -14,7 +14,7 @@

properties (Dependent)
totalTrials
groupTable % table of A§, DA, B, DB, R, ID, PA, PB
groupTable % table of A, DA, B, DB, R, ID, PA, PB
end

% NOTE TO SELF: These public methods need to be seen as interfaces to
@@ -234,9 +234,13 @@ function exportGroupDataFileToDisk(obj)
end
end

function output = getEverythingAboutAnExperiment(obj, ind)
% return a structure of everything about the data file 'ind'

function [samples] = getGroupLevelSamples(obj, fieldsToGet)
if ~obj.isUnobservedPartipantPresent()
error('Looks like we don''t have group level estimates.')
else
index = obj.data.getIndexOfUnobservedParticipant();
samples = obj.coda.getSamplesAtIndex_asStruct(index, fieldsToGet);
end
end

function participantIndexList = getParticipantIndexList(obj)
15 changes: 6 additions & 9 deletions ddToolbox/models/Model.m
@@ -141,6 +141,10 @@ function export(obj)
function obj = plotMCMCchains(obj,vars)
obj.coda.plotMCMCchains(vars);
end

function [samples] = getGroupLevelSamples(obj, fieldsToGet)
[samples] = obj.data.getGroupLevelSamples(fieldsToGet);
end

end
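
After this refactoring, callers ask the model for group-level samples and the model simply forwards the request to its Data object. A hedged call sketch; model is assumed to be a fitted model object and the parameter names passed in are illustrative, since they depend on the model subclass:

% Hypothetical call -- parameter names depend on the fitted model
groupSamples = model.getGroupLevelSamples({'m', 'c'});
% expected: a struct of MCMC samples for the group-level (unobserved) participant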

@@ -152,15 +156,7 @@ function export(obj)
nChains = obj.mcmcParams.nchains;
end

function [samples] = getGroupLevelSamples(obj, fieldsToGet)
if ~obj.data.isUnobservedPartipantPresent()
% exit if we don't have any group level inference
error('Looks like we don''t have group level estimates.')
else
index = obj.data.getIndexOfUnobservedParticipant();
samples = obj.coda.getSamplesAtIndex_asStruct(index, fieldsToGet);
end
end


function [predicted_subjective_values] = get_inferred_present_subjective_values(obj)
%% calculate point estimates
Expand All @@ -170,6 +166,7 @@ function export(obj)

%% return point estimates of present subjective values...
all_data_table = obj.data.groupTable;
% add new columns for present subjective value (VA, VB)
all_data_table.VA = obj.coda.getStats(obj.plotOptions.pointEstimateType, 'VA');
all_data_table.VB = obj.coda.getStats(obj.plotOptions.pointEstimateType, 'VB');
predicted_subjective_values.point_estimates = all_data_table;
1 change: 0 additions & 1 deletion ddToolbox/models/nonparametric_models/NonParametric.m
@@ -237,7 +237,6 @@ function psychometric_plots(obj)
psycho.plot();
%% plot response data TODO: move this to Data ~~~~~~~~~
hold on
%pTable = obj.data.getRawDataTableForParticipant(ind);
AoverB = personStruct.data.A ./ personStruct.data.B;
R = personStruct.data.R;
% grab just for this delay
6 changes: 6 additions & 0 deletions ddToolbox/utils/asRowVector.m
@@ -0,0 +1,6 @@
function output = asRowVector(input)
% coerce into a row vector
if iscolumn(input)
output = input';
else
output = input; % already a row vector (or not a column); pass through unchanged
end
end
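
The reason this tiny helper exists is the loop header for variable_name = variable_list in computeStats.m: MATLAB's for iterates over the columns of the loop expression, and fieldnames returns an N-by-1 column cell array, so without the transpose the loop body would run only once, with the whole column bound to the loop variable. A small demonstration (the struct and its field names are purely illustrative):

% MATLAB's for loops over the COLUMNS of the loop expression
s.alpha = 1;
s.beta  = 2;
for name = fieldnames(s)                  % 2x1 column cell: body runs once, name is 2x1
    disp(size(name))                      % prints  2  1
end
for name = asRowVector(fieldnames(s))     % 1x2 row cell: one iteration per field
    disp(name{:})                         % prints alpha, then beta
end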
