Skip to content

Commit

Permalink
Merge pull request #179 from drbenvincent/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
drbenvincent authored Feb 18, 2017
2 parents b7513bf + 9ebbecd commit bda0673
Show file tree
Hide file tree
Showing 25 changed files with 438 additions and 387 deletions.
16 changes: 1 addition & 15 deletions ddToolbox/CODA/CODA.m
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,6 @@ function densityplot(obj, targetAxisHandle, samples)
if ~isempty(targetAxisHandle)
subplot(targetAxisHandle)
end


% % using my plot tools package
% mcmc.UnivariateDistribution(samples',...
% 'plotStyle','hist',...
% 'plotHDI',false);

univariateObject = Stochastic('name_here');
univariateObject.addSamples(samples);
Expand Down Expand Up @@ -208,7 +202,7 @@ function plot_bivariate_distribution(obj, targetAxisHandle, x_var_name, y_var_na
'axisSquare', true);
end

function plotUnivariateSummaries(obj, variables, plotOptions, modelFilename, idNames)
function plotUnivariateSummaries(obj, variables, plotOptions, idNames)
% create a figure with multiple subplots (boxplots) for each input variable provided

% create subplots
Expand Down Expand Up @@ -238,14 +232,6 @@ function plotUnivariateSummaries(obj, variables, plotOptions, modelFilename, idN
% screen_size = get(0,'ScreenSize');
% fig_width = min(screen_size(3), 100+numel(participantNames)*20);
% set(gcf,'Position',[100 200 fig_width 1000])

% Export
if plotOptions.shouldExportPlots
myExport(plotOptions.savePath,...
'UnivariateSummary',...
'suffix', modelFilename,...
'formats', plotOptions.exportFormats)
end
end


Expand Down
184 changes: 100 additions & 84 deletions ddToolbox/CODA/computeStats.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,105 +7,121 @@
disp('CODA: Calculating statistics')
assert(isstruct(all_samples))

variable_names = fieldnames(all_samples);
variable_list = asRowVector(fieldnames(all_samples));

stats = struct('Rhat',[], 'mean', [], 'median', [], 'std', [],...
'ci_low' , [] , 'ci_high' , [],...
'hdi_low', [] , 'hdi_high' , []);

for v=1:length(variable_names)
var_name = variable_names{v};
var_samples = all_samples.(var_name);

sz = size(var_samples);
Nchains = sz(1);
Nsamples = sz(2);
dims = ndims(var_samples);

% Calculate stats
switch dims
case{2} % scalar
stats.mode.(var_name) = calcMode( var_samples(:) );
stats.median.(var_name) = median( var_samples(:) );
stats.mean.(var_name) = mean( var_samples(:) );
stats.std.(var_name) = std( var_samples(:) );
[stats.ci_low.(var_name), stats.ci_high.(var_name)] = calcCI(var_samples(:));
[stats.hdi_low.(var_name), stats.hdi_high.(var_name)] = calcHDI(var_samples(:));

case{3} % vector
for n=1:sz(3)
stats.mode.(var_name)(n) = calcMode( vec(var_samples(:,:,n)) );
stats.median.(var_name)(n) = median( vec(var_samples(:,:,n)) );
stats.mean.(var_name)(n) = mean( vec(var_samples(:,:,n)) );
stats.std.(var_name)(n) = std( vec(var_samples(:,:,n)) );
[stats.ci_low.(var_name)(n),...
stats.ci_high.(var_name)(n)] = calcCI( vec(var_samples(:,:,n)) );
[stats.hdi_low.(var_name)(n),...
stats.hdi_high.(var_name)(n)] = calcHDI( vec(var_samples(:,:,n)) );
end
case{4} % 2D matrix
for a=1:sz(3)
for b=1:sz(4)
stats.mode.(var_name)(a,b) = calcMode( vec(var_samples(:,:,a,b)) );
stats.median.(var_name)(a,b) = median( vec(var_samples(:,:,a,b)) );
stats.mean.(var_name)(a,b) = mean( vec(var_samples(:,:,a,b)) );
stats.std.(var_name)(a,b) = std( vec(var_samples(:,:,a,b)) );
[stats.ci_low.(var_name)(a,b),...
stats.ci_high.(var_name)(a,b)] = calcCI( vec(var_samples(:,:,a,b)) );
[stats.hdi_low.(var_name)(a,b),...
stats.hdi_high.(var_name(a,b))] = calcHDI( vec(var_samples(:,:,a,b)) );
end
end
for variable_name = variable_list

variable_samples = all_samples.(variable_name{:});

switch ndims(variable_samples)
case{2}
stats = calcStatsScalar(stats, variable_samples, variable_name{:});
case{3}
stats = calcStatsVector(stats, variable_samples, variable_name{:});
case{4}
stats = calcStatsMatrix(stats, variable_samples, variable_name{:});
otherwise
warning('calculation of stats not supported for >2D variables. You could implement it and send a pull request.')
stats.mode.(var_name) = [];
stats.median.(var_name) = [];
stats.mean.(var_name) = [];
stats.std.(var_name) = [];
stats.ci_low.(var_name) = [];
stats.ci_high.(var_name) = [];
stats.hdi_low.(var_name) = [];
stats.hdi_high.(var_name) = [];
stats = calcStatsTensor3(stats, variable_samples, variable_name{:});
end

%% "estimated potential scale reduction" statistics due to Gelman and Rubin.
Rhat = calcRhat();

Rhat = calcRhat(variable_samples);
if ~isnan(Rhat)
stats.Rhat.(var_name) = squeeze(Rhat);
stats.Rhat.(variable_name{:}) = squeeze(Rhat);
end

end

end


function Rhat = calcRhat()
st_mean_per_chain = mean(var_samples, 2);
st_mean_overall = mean(st_mean_per_chain, 1);
function stats = calcStatsScalar(stats, var_samples, var_name)
	% Summarise a scalar variable: pool the samples from all chains and
	% store point estimates and 95% intervals under stats.<field>.(var_name).
	pooled = var_samples(:);
	stats.mode.(var_name)	= calcMode(pooled);
	stats.median.(var_name)	= median(pooled);
	stats.mean.(var_name)	= mean(pooled);
	stats.std.(var_name)	= std(pooled);
	[stats.ci_low.(var_name), stats.ci_high.(var_name)] = calcCI(pooled);
	[stats.hdi_low.(var_name), stats.hdi_high.(var_name)] = calcHDI(pooled);
end

if Nchains > 1
B = (Nsamples/Nchains-1) * ...
sum((st_mean_per_chain - repmat(st_mean_overall, [Nchains,1])).^2);
varPerChain = var(var_samples, 0, 2);
W = (1/Nchains) * sum(varPerChain);
vhat = ((Nsamples-1)/Nsamples) * W + (1/Nsamples) * B;
Rhat = sqrt(vhat./(W+eps));
else
Rhat = nan;
end
end
function stats = calcStatsVector(stats, var_samples, var_name)
	% Summarise a vector-valued variable (chains x samples x N): for each of
	% the N elements, pool chains+samples and store stats under var_name.
	nElements = size(var_samples, 3);
	for idx = 1:nElements
		pooled = vec(var_samples(:,:,idx));
		stats.mode.(var_name)(idx)		= calcMode(pooled);
		stats.median.(var_name)(idx)	= median(pooled);
		stats.mean.(var_name)(idx)		= mean(pooled);
		stats.std.(var_name)(idx)		= std(pooled);
		[stats.ci_low.(var_name)(idx),...
			stats.ci_high.(var_name)(idx)] = calcCI(pooled);
		[stats.hdi_low.(var_name)(idx),...
			stats.hdi_high.(var_name)(idx)] = calcHDI(pooled);
	end
end

function [low, high] = calcCI(reshaped_samples)
% get the 95% interval of the posterior
ci_samples_overall = prctile( reshaped_samples , [ 2.5 97.5 ] , 1 );
ci_samples_overall_low = ci_samples_overall( 1,: );
ci_samples_overall_high = ci_samples_overall( 2,: );
low = squeeze(ci_samples_overall_low);
high = squeeze(ci_samples_overall_high);
function stats = calcStatsMatrix(stats, var_samples, var_name)
	% Summarise a matrix-valued variable (chains x samples x A x B): for
	% each (a,b) element, pool chains+samples and store stats under var_name.
	sz = size(var_samples);
	for a = 1:sz(3)
		for b = 1:sz(4)
			pooled = vec(var_samples(:,:,a,b));
			stats.mode.(var_name)(a,b)		= calcMode(pooled);
			stats.median.(var_name)(a,b)	= median(pooled);
			stats.mean.(var_name)(a,b)		= mean(pooled);
			stats.std.(var_name)(a,b)		= std(pooled);
			[stats.ci_low.(var_name)(a,b),...
				stats.ci_high.(var_name)(a,b)] = calcCI(pooled);
			% BUG FIX: was `stats.hdi_high.(var_name(a,b))`, which indexes
			% into the *name string* (dynamic field = some characters of
			% var_name) instead of element (a,b) of the hdi_high field.
			[stats.hdi_low.(var_name)(a,b),...
				stats.hdi_high.(var_name)(a,b)] = calcHDI(pooled);
		end
	end
end

function [low, high] = calcHDI(reshaped_samples)
% get the 95% highest density intervals of the posterior
[hdi] = HDIofSamples(reshaped_samples, 0.95);
low = squeeze(hdi(1));
high = squeeze(hdi(2));
end
function stats = calcStatsTensor3(stats, var_samples, var_name)
	% Fallback for variables with more than 2 parameter dimensions
	% (ndims > 4). Stats are not computed; empty placeholder fields are
	% created so downstream code reading stats.<field>.(var_name) does not
	% error on a missing field. var_samples is accepted (unused) so this
	% matches the calcStats* calling convention.
	warning('calculation of stats not supported for >2D matrices. You could implement it and send a pull request.')
	summaryFields = {'mode', 'median', 'mean', 'std',...
		'ci_low', 'ci_high', 'hdi_low', 'hdi_high'};
	for f = summaryFields
		stats.(f{:}).(var_name) = [];
	end
end

function Rhat = calcRhat(var_samples)
	% "Estimated potential scale reduction" statistic (Rhat) due to
	% Gelman and Rubin (1992). var_samples is (chains x samples x ...).
	% Returns NaN when only a single chain is present, since Rhat requires
	% a between-chain variance estimate.
	sz = size(var_samples);
	Nchains = sz(1);
	Nsamples = sz(2);

	st_mean_per_chain = mean(var_samples, 2);
	st_mean_overall = mean(st_mean_per_chain, 1);

	if Nchains > 1
		% BUG FIX: was `(Nsamples/Nchains-1)`, which parses as
		% (Nsamples/Nchains) - 1. The between-chain variance estimate is
		% B = N/(M-1) * sum_m (chain-mean_m - overall-mean)^2.
		B = (Nsamples/(Nchains-1)) * ...
			sum((st_mean_per_chain - repmat(st_mean_overall, [Nchains,1])).^2);
		% W = mean within-chain variance
		varPerChain = var(var_samples, 0, 2);
		W = (1/Nchains) * sum(varPerChain);
		% Pooled posterior variance estimate, then Rhat = sqrt(vhat/W).
		vhat = ((Nsamples-1)/Nsamples) * W + (1/Nsamples) * B;
		Rhat = sqrt(vhat./(W+eps));
	else
		Rhat = nan;
	end
end

function [low, high] = calcCI(reshaped_samples)
	% 95% equal-tailed credible interval of the posterior: the 2.5th and
	% 97.5th percentiles of the supplied samples, computed down dim 1.
	percentileBounds = prctile(reshaped_samples, [2.5 97.5], 1);
	low  = squeeze(percentileBounds(1, :));
	high = squeeze(percentileBounds(2, :));
end

function [low, high] = calcHDI(reshaped_samples)
	% 95% highest density interval of the posterior samples,
	% delegated to HDIofSamples.
	intervalEndpoints = HDIofSamples(reshaped_samples, 0.95);
	low  = squeeze(intervalEndpoints(1));
	high = squeeze(intervalEndpoints(2));
end
12 changes: 8 additions & 4 deletions ddToolbox/Data/Data.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

properties (Dependent)
totalTrials
groupTable % table of A§, DA, B, DB, R, ID, PA, PB
groupTable % table of A, DA, B, DB, R, ID, PA, PB
end

% NOTE TO SELF: These public methods need to be seen as interfaces to
Expand Down Expand Up @@ -234,9 +234,13 @@ function exportGroupDataFileToDisk(obj)
end
end

function output = getEverythingAboutAnExperiment(obj, ind)
% return a structure of everything about the data file 'ind'

function [samples] = getGroupLevelSamples(obj, fieldsToGet)
	% Return a struct of posterior samples for the group-level
	% ("unobserved participant") entry, restricted to fieldsToGet.
	% Errors when no group-level participant is present.
	if ~obj.isUnobservedPartipantPresent()
		error('Looks like we don''t have group level estimates.')
	end
	groupIndex = obj.data.getIndexOfUnobservedParticipant();
	samples = obj.coda.getSamplesAtIndex_asStruct(groupIndex, fieldsToGet);
end

function participantIndexList = getParticipantIndexList(obj)
Expand Down
7 changes: 3 additions & 4 deletions ddToolbox/Data/DataImporter.m
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
obj.path = path;
obj.fnames = fnames;
obj.nFiles = numel(fnames);
obj = obj.import();

% do importing
obj.dataArray = obj.import();
disp('The following data files were imported:')
disp(fnames')
end

function obj = import(obj)
function dataArray = import(obj)
for n=1:obj.nFiles
%% do the actual file import
experimentTable = readtable(...
Expand All @@ -47,7 +47,6 @@

dataArray(n) = DataFile(experimentTable);
end
obj.dataArray = dataArray;
end

function dataArray = getData(obj)
Expand Down
65 changes: 31 additions & 34 deletions ddToolbox/DeterministicFunction/DiscountFunction.m
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function plot(obj, pointEstimateType, dataPlotType, timeUnits)
timeUnitFunction = str2func(timeUnits);
N_SAMPLES_FROM_POSTERIOR = 100;

delays = obj.determineDelayValues();
delays = obj.getDelayValues();
if verLessThan('matlab','9.1') % backward compatability
delaysDuration = delays;
else
Expand Down Expand Up @@ -59,39 +59,36 @@ function plot(obj, pointEstimateType, dataPlotType, timeUnits)

drawnow
end

function delayValues = determineDelayValues(obj)
% TODO: remove this stupid special-case handling of group-level
% participant with no data
try
maxDelayRange = max( obj.data.getDelayRange() )*1.2;
catch
% default (happens when there is no data, ie group level
% observer).
maxDelayRange = 365;
end
delayValues = linspace(0, maxDelayRange, 1000);
end

function nansPresent = anyNaNsPresent(obj)
nansPresent = false;
for field = fields(obj.theta)'
if any(isnan(obj.theta.(field{:}).samples))
nansPresent = true;
warning('NaN''s detected in theta')
break
end
end
end


end

methods (Abstract)



end




methods (Access = private)

function delayValues = getDelayValues(obj)
	% Build a dense grid of 1000 delay values, from zero up to 120% of the
	% largest delay in the data, for evaluating/plotting discount functions.
	% TODO: remove this stupid special-case handling of group-level
	% participant with no data
	try
		upperDelay = max( obj.data.getDelayRange() ) * 1.2;
	catch
		% No data available (ie the group-level observer): fall back to a
		% default of one year.
		upperDelay = 365;
	end
	delayValues = linspace(0, upperDelay, 1000);
end

function nansPresent = anyNaNsPresent(obj)
	% Scan every field of obj.theta for NaN samples; warn and stop at the
	% first field containing any. Returns true if NaN's were found.
	nansPresent = false;
	for fieldName = fields(obj.theta)'
		fieldSamples = obj.theta.(fieldName{:}).samples;
		if any(isnan(fieldSamples))
			nansPresent = true;
			warning('NaN''s detected in theta')
			break
		end
	end
end

end

end
1 change: 1 addition & 0 deletions ddToolbox/DeterministicFunction/PsychometricFunction.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
function obj = PsychometricFunction(varargin)
obj = obj@DeterministicFunction(varargin{:});

% TODO: this violates dependency injection, so we may want to pass these Stochastic objects in
obj.theta.alpha = Stochastic('alpha');
obj.theta.epsilon = Stochastic('epsilon');

Expand Down
Loading

0 comments on commit bda0673

Please sign in to comment.