-
Notifications
You must be signed in to change notification settings - Fork 0
/
ens_model_predict.m
87 lines (74 loc) · 3.09 KB
/
ens_model_predict.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
function [y_mean, y_sigma, y_int] = ens_model_predict(models, x, vars, ...
    params)
% [y_mean, y_sigma, y_int] = ens_model_predict(models, x, vars, params)
% Make predictions with an ensemble of models.
%
% Each sub-model in MODELS is evaluated at X with its own prediction
% function (called through the built-in feval), and the per-model means
% and standard deviations are then combined point by point with
% mix_gaussians.
%
% Arguments
%   models  struct
%       One field per fitted sub-model; the field names identify the
%       sub-models and are used to look up their entries in VARS (and,
%       for stacking, in PARAMS.models).
%   x       (n, nx) double
%       Input points, one per row.
%   vars    struct
%       One field per sub-model name (passed to that sub-model's
%       prediction function), plus a field 'significance' which is passed
%       to mix_gaussians as its alpha argument.
%   params  struct
%       params.method selects the ensemble type:
%         "bagging"  - params.base_models must hold exactly one entry,
%                      which is used for every sub-model.
%         "stacking" - params.models.(model_name) is used for each
%                      sub-model (heterogeneous models).
%         "boosting" - not implemented (raises an error).
%       Each entry must provide the fields predictFcn and params.
%
% Returns
%   y_mean  (n, 1) double
%       Combined expected value of y at each x(i,:), i = 1, 2, ... n.
%   y_sigma (n, 1) double
%       Standard deviation of the uncertainty of the combined prediction
%       y_mean(i) at each x(i,:).
%   y_int   (n, 2) double
%       Confidence interval of each combined prediction y_mean(i):
%       column 1 is the lower bound, column 2 the upper bound.
%
% Each sub-model's predictFcn is called as
%   [mean, sigma, int] = predictFcn(model, x, model_vars, model_params)
% and must return an (n, 1) mean, an (n, 1) sigma and an (n, 2) interval.

n = size(x, 1);
model_names = fieldnames(models);
n_models = numel(model_names);
y_means = nan(n, n_models);
y_sigmas = nan(n, n_models);
y_ints = nan(n, 2, n_models);

switch params.method
    case "bagging"
        % Base model name - only one type allowed currently
        base_model_names = string(fieldnames(params.base_models));
        assert(numel(base_model_names) == 1)
        base_model = params.base_models.(base_model_names(1));
        for i = 1:n_models
            model_name = model_names{i};
            [y_means(:, i), y_sigmas(:, i), y_ints(:, :, i)] = ...
                predict_submodel(base_model, models.(model_name), x, ...
                    vars.(model_name));
        end
    case "boosting"
        error("NotImplementedError")
    case "stacking"  % for heterogeneous models
        for i = 1:n_models
            model_name = model_names{i};
            [y_means(:, i), y_sigmas(:, i), y_ints(:, :, i)] = ...
                predict_submodel(params.models.(model_name), ...
                    models.(model_name), x, vars.(model_name));
        end
    otherwise
        % Without this branch an unknown method would silently return
        % all-NaN predictions.
        error("ens_model_predict:unknownMethod", ...
            "Unknown ensemble method '%s'.", string(params.method))
end

% Make combined predictions, std. dev., and conf. interval by mixing the
% sub-model Gaussians at each point. The per-model intervals y_ints are
% collected but not used by this combination. Alternative considered:
%   y_sigma = mean(y_sigmas, 2); %TODO: is this the right way?
%   y_int = [min(y_ints(:, 1, :), [], 3) max(y_ints(:, 2, :), [], 3)];
y_mean = nan(n, 1);
y_sigma = nan(n, 1);
y_int = nan(n, 2);
alpha = vars.significance;
% Iterate over rows of x only; length(x) would return nx when nx > n.
for j = 1:n
    mus = y_means(j, :)';
    sigmas = y_sigmas(j, :)';
    [y_mean(j), y_int(j, :), y_sigma(j)] = ...
        mix_gaussians(mus, sigmas, alpha);
end
end

function [y_mean, y_sigma, y_int] = predict_submodel(spec, model, x, ...
    model_vars)
% Evaluate one sub-model with spec.predictFcn and spec.params.
% Note: builtin() is needed here because other code in the MATLAB
% workspace overrides the built-in feval function.
[y_mean, y_sigma, y_int] = builtin("feval", spec.predictFcn, model, x, ...
    model_vars, spec.params);
end