Skip to content

Commit

Permalink
Merge pull request #206 from drbenvincent/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
drbenvincent authored Mar 7, 2018
2 parents 3b11b7b + 8c15421 commit 79b320f
Show file tree
Hide file tree
Showing 13 changed files with 65 additions and 129 deletions.
67 changes: 24 additions & 43 deletions ddToolbox/PosteriorPrediction.m
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
obj.controlModelProbChooseDelayed(p) = obj.postPred(p).proportion_chose_delayed;

% Calculate metrics
obj.postPred(p).score = obj.calcPostPredOverallScore(responses_predicted, responses_actual, obj.controlModelProbChooseDelayed(p));
obj.postPred(p).GOF_distribtion = obj.calcGoodnessOfFitDistribution(responses_inferredPB, responses_actual, obj.controlModelProbChooseDelayed(p));
obj.postPred(p).log_loss_distribution = obj.calcLogLossDistribution(responses_inferredPB, responses_actual);
obj.postPred(p).percentPredictedDistribution = obj.calcPercentResponsesCorrectlyPredicted(responses_inferredPB, responses_actual);
% Store
obj.postPred(p).responses_actual = responses_actual;
Expand All @@ -56,14 +55,14 @@ function plot(obj, plotOptions, modelFilename, model)

% PUBLIC GETTERS --------------------------------------------------

function score = getScores(obj)
    % Gather the `score` field of every element of obj.postPred by
    % horizontal concatenation, then transpose the result.
    scoresConcatenated = [obj.postPred(:).score];
    score = scoresConcatenated';
end

function pp = getPercentPredictedDistribution(obj)
    % Return a cell array holding the percentPredictedDistribution field
    % of every element of obj.postPred (one cell per element).
    entries = obj.postPred(:);
    pp = {entries.percentPredictedDistribution};
end

function pp = getLogLossDistribution(obj)
    % Return a cell array holding the log_loss_distribution field of
    % every element of obj.postPred (one cell per element).
    entries = obj.postPred(:);
    pp = {entries.log_loss_distribution};
end

function postPredTable = getPostPredTable(obj)
% Public getter: return the value stored in obj.postPredTable unchanged.
postPredTable = obj.postPredTable;
end
Expand All @@ -80,17 +79,17 @@ function posterior_prediction_figure(obj, n, plotOptions, plotDiscountFunction)
% Arrange subplots
h = layout([1 4; 2 3]);
subplot(h(1)), obj.pp_plotTrials(n)
subplot(h(2)), obj.pp_plotGOFdistribution(n, plotOptions)
subplot(h(2)), obj.pp_plotLogLossDistribution(n, plotOptions)
subplot(h(3)), obj.pp_plotPercentPredictedDistribution(n, plotOptions)
plotDiscountFunction(n, h(4))

drawnow
end

function pp_plotGOFdistribution(obj, n, plotOptions)
function pp_plotLogLossDistribution(obj, n, plotOptions)
uni = mcmc.UnivariateDistribution(...
obj.postPred(n).GOF_distribtion(:),...
'xLabel', 'goodness of fit score',...
obj.postPred(n).log_loss_distribution(:),...
'xLabel', 'Log Loss',...
'plotStyle','hist',...
'pointEstimateType', plotOptions.pointEstimateType);
end
Expand Down Expand Up @@ -139,11 +138,12 @@ function pp_plotTrials(obj, n)
end

function postPredTable = makePostPredTable(obj, data, pointEstimateType)
postPredTable = table(obj.getScores(),...
postPredTable = table(...
obj.calc_percent_predicted_point_estimate(pointEstimateType),...
obj.calc_log_loss_point_estimate(pointEstimateType),...
obj.any_percent_predicted_warnings(),...
'RowNames', data.getIDnames('experiments'),...
'VariableNames',{'ppScore' 'percentPredicted' 'warning_percent_predicted'});
'VariableNames',{'percentPredicted' 'LogLoss' 'warning_percent_predicted'});

if data.isUnobservedPartipantPresent()
% add extra row of NaN's on the bottom for the unobserved participant
Expand All @@ -163,6 +163,14 @@ function pp_plotTrials(obj, n)
obj.getPercentPredictedDistribution())';
end

function logLoss = calc_log_loss_point_estimate(obj, pointEstimateType)
    % Calculate a point estimate of the log loss for each distribution
    % returned by obj.getLogLossDistribution().
    %
    % pointEstimateType: name of the summary function to apply to each
    %   distribution, i.e. the point estimate type that the user
    %   specified. It is converted to a function handle with str2func.
    % logLoss: the cellfun result, transposed (one value per distribution).
    pointEstFunc = str2func(pointEstimateType);
    logLoss = cellfun(pointEstFunc, obj.getLogLossDistribution())';
end

function pp_warning = any_percent_predicted_warnings(obj)
% warnings when we have less than 95% confidence that we can
% predict more responses than the control model
Expand All @@ -179,15 +187,10 @@ function pp_plotTrials(obj, n)

methods (Access = private, Static)

function [score] = calcGoodnessOfFitDistribution(responses_predictedMCMC, responses_actual, proportion_chose_delayed)
% Expand the participant responses so we can do vectorised calculations below
totalSamples = size(responses_predictedMCMC,2);
responses_actual = repmat(responses_actual, [1,totalSamples]);
responses_control_model = ones(size(responses_actual)) .* proportion_chose_delayed;

score = calcLogOdds(...
calcDataLikelihood(responses_actual, responses_predictedMCMC),...
calcDataLikelihood(responses_actual, responses_control_model));
function logloss = calcLogLossDistribution(predicted, actual)
    % Log loss (binary cross-entropy) for binary variables.
    %
    % predicted: predicted response probabilities, with trials running
    %   down the rows.
    %   NOTE(review): callers appear to pass one column per MCMC sample
    %   (nTrials-by-nSamples) - confirm against the call site.
    % actual: observed binary responses (0 or 1), one row per trial.
    % logloss: mean negative log likelihood per trial, one value for each
    %   column of predicted.
    %
    % Probabilities are clipped away from exactly 0 and 1. Without this,
    % terms such as 0 .* log(0) evaluate to NaN and poison the whole sum.
    predicted = min(max(predicted, eps), 1 - eps);

    % Count trials as rows, and sum explicitly over dimension 1 so that a
    % single-trial (row vector) input is not collapsed across columns.
    nTrials = size(actual, 1);
    trialLogLik = actual .* log(predicted) ...
        + (1 - actual) .* log(1 - predicted);
    logloss = -sum(trialLogLik, 1) ./ nTrials;
end

function percentResponsesPredicted = calcPercentResponsesCorrectlyPredicted(responses_predictedMCMC, responses_actual)
Expand All @@ -202,17 +205,6 @@ function pp_plotTrials(obj, n)
percentResponsesPredicted = sum(isCorrectPrediction,1)./nQuestions;
end

function score = calcPostPredOverallScore(responses_predicted, responses_actual, proportion_chose_delayed)
    % Log odds of the data likelihood under the model's predictions
    % versus under a control model whose response probability is
    % proportion_chose_delayed on every trial.
    %
    % NOTE: This is model comparison, not posterior prediction

    % Control model: same size as the predictions, constant probability.
    controlPredictions = proportion_chose_delayed .* ones(size(responses_predicted));

    likelihoodModel = calcDataLikelihood(responses_actual, responses_predicted');
    likelihoodControl = calcDataLikelihood(responses_actual, controlPredictions');

    score = calcLogOdds(likelihoodModel, likelihoodControl);
end

function exportFigure(plotOptions, prefix_string, modelFilename)
% TODO: Exporting is not the responsibility of PosteriorPrediction class, so we need to extract this up to Model subclasses. They call it as: obj.postPred.plot(obj.plotOptions, obj.modelFilename)
if plotOptions.shouldExportPlots
Expand All @@ -227,14 +219,3 @@ function exportFigure(plotOptions, prefix_string, modelFilename)
end

end

function logOdds = calcLogOdds(a,b)
    % Natural logarithm of the element-wise ratio a./b.
    ratio = a ./ b;
    logOdds = log(ratio);
end

function dataLikelihood = calcDataLikelihood(responses, predicted)
    % Likelihood of the responses given the predicted probabilities.
    % Responses are Bernoulli distributed: a special case of the Binomial
    % with 1 event, hence the all-ones "number of trials" argument.
    singleEvent = ones(size(responses));
    perResponse = binopdf(responses, singleEvent, predicted);
    dataLikelihood = prod(perResponse);
end
11 changes: 4 additions & 7 deletions demo/demo_group_comparison_repeated_measures.m
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,23 @@
'nchains', 4,...
'nburnin', 2000);
pointEstimateType = 'median';
%% Analyse group 1

%% Analyse group 1
datapath1 = fullfile(path_of_this_mfile,'datasets','group_comparison','group1');
group1 = ModelHierarchicalLogK(...
Data(datapath1, 'files', allFilesInFolder(datapath1, 'txt')),...
'savePath', fullfile(pwd,'output','group1'),...
'pointEstimateType', pointEstimateType,...
'sampler', 'jags',...
'shouldPlot', 'no',...
'shouldExportPlots', false,...
'mcmcParams', mcmcparams);
%% Analyse group 2

%% Analyse group 2
datapath2 = fullfile(path_of_this_mfile,'datasets','group_comparison','group2');
group2 = ModelHierarchicalLogK(...
Data(datapath2, 'files', allFilesInFolder(datapath2, 'txt')),...
'savePath', fullfile(pwd,'output','group2'),...
'pointEstimateType', pointEstimateType,...
'sampler', 'jags',...
'shouldPlot', 'no',...
'shouldExportPlots', false,...
'mcmcParams', mcmcparams);
Expand Down Expand Up @@ -84,4 +82,3 @@


% METHOD 2)

3 changes: 0 additions & 3 deletions demo/run_me.m
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
% Data(datapath, 'files', allFilesInFolder(datapath, 'txt')),...
% 'savePath', 'analysis_with_hierarchical_magnitude_effect',...
% 'pointEstimateType','median',...
% 'sampler', 'jags',...
% 'shouldPlot', 'no',...
% 'mcmcParams', struct('nsamples', 10^4,...
% 'nchains', 4,...
Expand Down Expand Up @@ -109,7 +108,6 @@
Data(datapath, 'files', allFilesInFolder(datapath, 'txt')),...
'savePath', fullfile(pwd,'output','my_analysis'),...
'pointEstimateType', 'median',...
'sampler', 'jags',...
'shouldPlot', 'yes',...
'shouldExportPlots', true,...
'exportFormats', {'png'},...
Expand All @@ -122,7 +120,6 @@
% running proper analyses. I have provided small numbers here just to
% confirm the code is working without having to wait a long time.
% - you can change the point estimate type to mean, median, or mode
% - the sampler can be 'jags'


% If we didn't ask for plots when we ran the model, then we do that
Expand Down
16 changes: 0 additions & 16 deletions docs/discussion/stan.md

This file was deleted.

24 changes: 11 additions & 13 deletions docs/howto/store_raw_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ We assume that we have one data file for every experiment run. So even if we hav
Each data file should be a tab-delimited `.txt` file containing the trial data for an individual participant. Each file must contain 5 columns, with headers `A`, `DA`, `B`, `DB`, `R`. Each row corresponds to one experimental trial.

A DA B DB R
80 0 85 157 0
34 0 50 30 1
25 0 60 14 1
11 0 30 7 1
49 0 60 89 0
80 0 85 157 A
34 0 50 30 B
25 0 60 14 B
11 0 30 7 B
49 0 60 89 A
etc

Optionally, we can add columns for the probability of obtaining the reward, such as

A DA PA B DB PB R
80 0 1 85 157 1 0
34 0 1 50 30 1 1
25 0 1 60 14 1 1
11 0 1 30 7 1 1
49 0 1 60 89 1 0
80 0 1 85 157 1 A
34 0 1 50 30 1 B
25 0 1 60 14 1 B
11 0 1 30 7 1 B
49 0 1 60 89 1 A
etc

Column names mean:
Expand All @@ -37,8 +37,6 @@ Column names mean:
- `PB` probability of achieving the reward
- `R` is the participant response.

Note that responses are coded as follows:
* participant chooses delayed reward, ie option B, then `R = 1`
* participant chooses sooner reward, ie option A, then `R = 0`
The preferred way of coding participant responses `R` is with the unambiguous labels `A` or `B`. An earlier, more ambiguous coding scheme used `R = 1` for choosing the delayed reward and `R = 0` for choosing the immediate reward.

The ordering of the columns should not matter, as the data is extracted using the column header name (i.e. first row of the file).
1 change: 0 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@ Vincent, B. T. (2016) **[Hierarchical Bayesian estimation and hypothesis testing
- [Hyperpriors / parameter pooling](discussion/hyperpriors.md)
- [Level of parameter pooling](discussion/level_of_pooling.md)
- [Hypothesis testing](discussion/hypothesis_testing.md)
- [STAN](discussion/stan.md)
Binary file added docs/ref/bayes_graphical_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions docs/ref/discount_functions.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Discount functions available

A number of discounting models are available, all with the following basic form. The models differ in a) the presence or absence of hyperpriors, and/or b) the discount function used in the data generating process, which specifies the response probability.

![](bayes_graphical_model.png)

The 'data generating process' describes the response probability. This consists of the psychometric link function which captures response errors, and the discount function which computes the present subjective value of prospects. See the paper for details.

| Discount function | Equation | Model suffix | Main parameters |
| :--- | :---: | :---: | :---: |
| Exponential | `exp(-k*D)` | `*Exp1` | `k` |
Expand Down
2 changes: 0 additions & 2 deletions docs/ref/param_estimation_options.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ The rest of the arguments are optional key/value pairs
| --- | --- |
| `'timeUnits'` | `'minutes'` or `'hours'` or `'days'` [default] |
| `'pointEstimateType'` | `'mean'` or `'median'` or `'mode'` [default] |
| `'sampler'` | `'jags'` [default] or `'stan'` NOTE: stan models are in beta|
| `'shouldPlot'` | `'yes'` or `'no'` [default]|
| `'shouldExportPlots'` | `'true'` [default] or `'false'` |
| `'exportFormats'` | a cell array of output formats, e.g. `{'png', 'pdf'}` (default is `{'png'}` only) |
Expand All @@ -54,7 +53,6 @@ model = ModelHierarchicalME(Data(datapath, 'files', allFilesInFolder(datapath, '
'timeUnits', 'days',...
'savePath', fullfile(pwd,'output','my_analysis'),...
'pointEstimateType', 'mode',...
'sampler', 'jags',...
'shouldPlot', 'no',...
'shouldExportPlots', false,...
'exportFormats', {'png'},...
Expand Down
17 changes: 1 addition & 16 deletions tests/test_AllNonParametricModels.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
properties (TestParameter)
model = {'ModelSeparateNonParametric'}
pointEstimateType = {'mean','median','mode'}
sampler = {'jags'}
chains = {2,3,4}
end

Expand Down Expand Up @@ -82,19 +81,6 @@ function nChains(testCase, model, chains)
end


function specifiedSampler(testCase, model, sampler)
    % Construct the named model with an explicitly specified sampler.
    % Nothing is asserted here: the test only fails if construction errors.
    constructModel = str2func(model);
    mcmcSettings = struct('nsamples', get_numer_of_samples_for_tests(),...
        'nchains', 2,...
        'nburnin', get_burnin_for_tests());
    % make model
    model = constructModel(testCase.data,...
        'savePath', testCase.savePath,...
        'sampler', sampler,...
        'shouldPlot','no',...
        'mcmcParams', mcmcSettings);
    % TODO: DO AN ACTUAL TEST HERE !!!!!!!!!!!!!!!!!!!!!!
end

function plotting(testCase, model)
% make model
makeModelFunction = str2func(model);
Expand All @@ -109,12 +95,11 @@ function plotting(testCase, model)
% TODO: DO AN ACTUAL TEST HERE !!!!!!!!!!!!!!!!!!!!!!
end

function model_disp_function(testCase, model, sampler)
function model_disp_function(testCase, model)
% make model
makeModelFunction = str2func(model);
modelFitted = makeModelFunction(testCase.data,...
'savePath', testCase.savePath,...
'sampler', sampler,...
'mcmcParams', struct('nsamples', get_numer_of_samples_for_tests(),...
'nchains', 2,...
'nburnin', get_burnin_for_tests()),...
Expand Down
10 changes: 3 additions & 7 deletions tests/test_AllParametricModels.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
end

properties (TestParameter)
sampler = {'jags'};
model = getAllParametricModelNames();
pointEstimateType = {'mean','median','mode'}
chains = {2,3,4}
Expand Down Expand Up @@ -84,12 +83,11 @@ function nChains(testCase, model, chains)
end


function specifiedSampler(testCase, model, sampler)
function specifiedSampler(testCase, model)
% make model
makeModelFunction = str2func(model);
model = makeModelFunction(testCase.data,...
'savePath', testCase.savePath,...
'sampler', sampler,...
'mcmcParams', struct('nsamples', get_numer_of_samples_for_tests(),...
'nchains', 2,...
'nburnin', get_burnin_for_tests()),...
Expand All @@ -99,12 +97,11 @@ function specifiedSampler(testCase, model, sampler)
end


function getting_predicted_values(testCase, model, sampler)
function getting_predicted_values(testCase, model)
% make model
makeModelFunction = str2func(model);
modelFitted = makeModelFunction(testCase.data,...
'savePath', testCase.savePath,...
'sampler', sampler,...
'mcmcParams', struct('nsamples', get_numer_of_samples_for_tests(),...
'nchains', 2,...
'nburnin', get_burnin_for_tests()),...
Expand All @@ -123,12 +120,11 @@ function getting_predicted_values(testCase, model, sampler)
end


function model_disp_function(testCase, model, sampler)
function model_disp_function(testCase, model)
% make model
makeModelFunction = str2func(model);
modelFitted = makeModelFunction(testCase.data,...
'savePath', testCase.savePath,...
'sampler', sampler,...
'mcmcParams', struct('nsamples', get_numer_of_samples_for_tests(),...
'nchains', 2,...
'nburnin', get_burnin_for_tests()),...
Expand Down
Loading

0 comments on commit 79b320f

Please sign in to comment.