-
Notifications
You must be signed in to change notification settings - Fork 3
/
configureRandomForests.m
59 lines (53 loc) · 2.2 KB
/
configureRandomForests.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
function [learner] = configureRandomForests(getPRNG,n,v,t,d,k,seed)
%CONFIGURERANDOMFORESTS Grid-search Random Forests settings on validation data.
%   Configures the Random Forests learner for a given PRNG, k, d and seed,
%   which (together with n and v) determine the training and validation
%   data. Random Forests has three configurable settings searched here:
%   the format of the features passed to it ('s' or 'c', forwarded to
%   getPRNG), the depth of the trees (3..7) and the number of trees
%   (50..150 in steps of 25), i.e. 2*5*5 = 50 configurations. The
%   configuration with the lowest validation error is returned; on ties the
%   first one tried is kept.
%
%   INPUT:
%   getPRNG - data-generating function handle, called as
%             [X,y,Xval,yval,~,~] = getPRNG(n,v,t,d,k,featureType,labelSize,seed)
%             (the last two outputs are ignored here)
%   n       - number of training points
%   v       - number of validation points
%   t       - number of testing points (only forwarded to getPRNG)
%   d       - forwarded to getPRNG. NOTE(review): the previous header called
%             this "the random number seed", but seed is a separate
%             argument; confirm its meaning against getPRNG.
%   k       - the number of labels output by the PRNG
%   seed    - random number seed forwarded to getPRNG
%
%   OUTPUT:
%   learner - struct with fields
%             train       - @(X,y) handle that trains a forest using the
%                           selected depth and nTrees
%             featureType - selected feature format ('s' or 'c')
%             depth       - selected tree depth
%             nTrees      - selected number of trees
%             labelSize   - always 1
%             name        - 'Random Forests'
%
%   Side effect: every time a better configuration is found, the entire
%   workspace is written to configureRandomForests.mat (SAVE with no
%   variable list).
%   An error is raised if no configuration had a validation error below
%   Inf (e.g. every error evaluated to NaN).

% Random Forests always uses a label size of 1.
labelSize = 1;
bestErr = inf;
for featureType = ['s','c']
    % Generate the data for this feature format, then restore the global
    % random number state so that data generation does not alter the
    % random stream seen by the training below.
    s = rng;
    [X,y,Xval,yval,~,~] = getPRNG(n,v,t,d,k,featureType,labelSize,seed);
    rng(s);
    for depth = [3,4,5,6,7]
        for nTrees = [50,75,100,125,150]
            % Train on the training set X, validate on Xval.
            model = randomForest(X,y,depth,nTrees);
            yhat = model.predict(model,Xval);
            % Validation error = fraction of validation points predicted
            % incorrectly. It is normalised by the size of the validation
            % set itself, not by t (the number of testing points), so it is
            % correct for any t, including t == 0.
            if(labelSize == 1)
                % Column-wise comparison so a row/column orientation
                % mismatch cannot expand into a matrix.
                err = sum(yhat(:) ~= yval(:))/numel(yval);
            else
                % Multi-label case (unreachable while labelSize is fixed at
                % 1): a point counts as correct only if all of its labels
                % match. Assumes one row per validation point - confirm.
                err = 1-sum(all(yhat' == yval'))/size(yval,1);
            end
            % If we've found a configuration with better validation error,
            % save it to be returned later.
            if(err < bestErr)
                bestErr = err;
                % The anonymous function captures the current values of
                % depth and nTrees at creation time.
                learner.train = @(X,y) randomForest(X,y,depth,nTrees);
                learner.featureType = featureType;
                learner.depth = depth;
                learner.nTrees = nTrees;
                learner.labelSize = labelSize;
                learner.name = 'Random Forests';
                % Dumps the whole workspace to configureRandomForests.mat.
                save configureRandomForests;
            end
        end
    end
end
if(bestErr == inf)
    error('Something went wrong, we were unable to find any configuration with better validation error than infinity.');
end
end