-
Notifications
You must be signed in to change notification settings - Fork 0
/
decision-trees.js
139 lines (102 loc) · 3.7 KB
/
decision-trees.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/**
 * ID3 decision-tree learner for two-class problems.
 *
 * Every example is a plain object. A `predicate` maps an example to its
 * class; entropy and majority voting treat the predicate result as a boolean
 * (truthy = positive class), while purity is decided by strict equality of
 * the raw predicate results.
 */
class ID3 {
  /**
   * Builds a decision tree for `dataset`.
   *
   * @param {Object[]} dataset - Non-empty array of examples. The keys of the
   *   first example (minus `targetAttribute`) are used as candidate attributes.
   * @param {string} targetAttribute - Key holding the class of each example.
   * @param {function(Object): *} [predicate] - Maps an example to its class;
   *   defaults to reading `example[targetAttribute]`.
   * @returns {Tree} Tree whose `root` has been replaced by the grown node.
   * @throws {TypeError} If `dataset` is not a non-empty array.
   */
  static createTree(dataset, targetAttribute, predicate) {
    if (!Array.isArray(dataset) || dataset.length === 0) {
      throw new TypeError('ID3.createTree: dataset must be a non-empty array of examples');
    }
    const classOf = predicate || (example => example[targetAttribute]);
    const attributes = Object.keys(dataset[0]).filter(a => a !== targetAttribute);
    const tree = new Tree(dataset);
    tree.root = this.grow(tree.root, attributes, classOf);
    return tree;
  }

  /**
   * Recursively turns `node` into either a labelled leaf or a split node.
   *
   * A leaf is the incoming node with `label` set. A split node is a new
   * object `{ splitAttribute, nodes, informationGain }` whose `nodes` values
   * have themselves been grown.
   *
   * @param {{dataset: Object[]}} node - Node holding the examples to separate.
   * @param {string[]} predictingAttributes - Attributes still available.
   * @param {function(Object): *} predicate - Maps an example to its class.
   * @returns {Object} The leaf or split node replacing `node`.
   */
  static grow(node, predictingAttributes, predicate) {
    const dataset = node.dataset;
    if (this.isPure(dataset, predicate)) {
      node.label = predicate(dataset[0]);
      return node;
    }
    if (predictingAttributes.length === 0) {
      node.label = this.getCommonValue(dataset, predicate);
      return node;
    }

    // Kept for callers that hand in a node which already carries a
    // splitAttribute; nodes produced by split() never do.
    const candidates = predictingAttributes.filter(a => a !== node.splitAttribute);
    const bestSplit = this.getBestSplit(dataset, candidates, predicate);
    if (bestSplit === null) {
      // No attribute separates the classes (e.g. contradictory examples):
      // fall back to a majority-vote leaf instead of dereferencing null.
      node.label = this.getCommonValue(dataset, predicate);
      return node;
    }

    // An attribute is constant inside each of its own partitions, so it can
    // never be useful again further down this branch.
    const remaining = candidates.filter(a => a !== bestSplit.splitAttribute);
    for (const key of Object.keys(bestSplit.nodes)) {
      bestSplit.nodes[key] = this.grow(bestSplit.nodes[key], remaining, predicate);
    }
    return bestSplit;
  }

  /**
   * Picks the attribute whose split yields the highest information gain.
   *
   * @param {Object[]} dataset - Examples to split.
   * @param {string[]} attributes - Candidate attributes.
   * @param {function(Object): *} predicate - Maps an example to its class.
   * @returns {?{splitAttribute: string, nodes: Object, informationGain: number}}
   *   The best split, or `null` when no candidate has a gain strictly
   *   greater than zero (including when `attributes` is empty).
   */
  static getBestSplit(dataset, attributes, predicate) {
    const setEntropy = this.entropy(dataset, predicate);
    const setSize = dataset.length;
    let bestSplit = null;
    for (const attribute of attributes) {
      const nodes = this.split(dataset, attribute);
      const candidate = {
        splitAttribute: attribute,
        nodes: nodes,
        informationGain: this.informationGain(nodes, setEntropy, setSize, predicate),
      };
      const gainToBeat = bestSplit === null ? 0 : bestSplit.informationGain;
      // Strict comparison: on ties the earliest attribute wins.
      if (candidate.informationGain > gainToBeat) {
        bestSplit = candidate;
      }
    }
    return bestSplit;
  }

  /**
   * Information gain of a split: parent entropy minus the size-weighted
   * entropy of the partitions.
   *
   * @param {Object} nodes - Map of attribute value to `{ dataset }`.
   * @param {number} setEntropy - Entropy of the unsplit set.
   * @param {number} setSize - Number of examples in the unsplit set.
   * @param {function(Object): *} predicate - Maps an example to its class.
   * @returns {number} The information gain.
   */
  static informationGain(nodes, setEntropy, setSize, predicate) {
    let weightedEntropy = 0;
    for (const key of Object.keys(nodes)) {
      const subset = nodes[key].dataset;
      weightedEntropy += (subset.length / setSize) * this.entropy(subset, predicate);
    }
    return setEntropy - weightedEntropy;
  }

  /**
   * Binary Shannon entropy (in bits) of `subset`, where the positive class is
   * the set of examples for which `predicate` is truthy.
   *
   * @param {Object[]} subset - Examples to measure.
   * @param {function(Object): *} predicate - Maps an example to its class.
   * @returns {number} Entropy in [0, 1]; 0 for an empty or single-class set.
   */
  static entropy(subset, predicate) {
    if (subset.length === 0) {
      // Avoid 0 / 0 = NaN leaking into gain comparisons.
      return 0;
    }
    const p = subset.filter(predicate).length / subset.length;
    if (p === 1 || p === 0) {
      return 0;
    }
    return -p * Math.log2(p) - (1 - p) * Math.log2(1 - p);
  }

  /**
   * Partitions `dataset` by the value each example has for `splitAttribute`.
   *
   * @param {Object[]} dataset - Examples to partition.
   * @param {string} splitAttribute - Attribute to group by.
   * @returns {Object} Prototype-less map from attribute value (as a property
   *   key) to `{ dataset: Object[] }`.
   */
  static split(dataset, splitAttribute) {
    // No prototype, so values such as "constructor" or "__proto__" are stored
    // as ordinary keys instead of colliding with Object.prototype members.
    const partitions = Object.create(null);
    for (const example of dataset) {
      const value = example[splitAttribute];
      if (!partitions[value]) {
        partitions[value] = { dataset: [] };
      }
      partitions[value].dataset.push(example);
    }
    return partitions;
  }

  /**
   * Tells whether every example in `subset` maps to the same predicate value
   * (compared with `!==`).
   *
   * @param {Object[]} subset - Examples to inspect.
   * @param {function(Object): *} predicate - Maps an example to its class.
   * @returns {boolean} `true` for an empty or single-class subset.
   */
  static isPure(subset, predicate) {
    if (subset.length === 0) {
      return true;
    }
    // Compare against the first value rather than a null sentinel, so a
    // predicate that legitimately returns null is handled correctly.
    const first = predicate(subset[0]);
    for (let i = 1; i < subset.length; i++) {
      if (predicate(subset[i]) !== first) {
        return false;
      }
    }
    return true;
  }

  /**
   * Majority vote over `dataset`.
   *
   * @param {Object[]} dataset - Examples to vote over.
   * @param {function(Object): *} predicate - Truthy for positive examples.
   * @returns {boolean} `true` when positives are at least as numerous as
   *   negatives (ties resolve to `true`).
   */
  static getCommonValue(dataset, predicate) {
    const positiveCount = dataset.filter(predicate).length;
    const negativeCount = dataset.length - positiveCount;
    return positiveCount >= negativeCount;
  }
}
/**
 * Decision tree container.
 *
 * A node is either a leaf (no `nodes` property, result stored in `label`) or
 * a split node (`splitAttribute` plus a `nodes` map keyed by attribute value).
 */
class Tree {
  /**
   * @param {Object[]} dataset - Examples stored on the initial root node.
   */
  constructor(dataset) {
    this.root = {
      dataset: dataset
    };
  }

  /**
   * Walks the tree from `node` (or the root) following the attribute values
   * of `kase` until a leaf is reached.
   *
   * @param {Object} kase - Example to classify.
   * @param {Object} [node] - Node to start from; defaults to `this.root`.
   * @returns {*} The `label` of the leaf reached, or `undefined` when `kase`
   *   has an attribute value for which the tree has no branch.
   */
  classify(kase, node) {
    let current = node || this.root;
    while (current.nodes) {
      const value = kase[current.splitAttribute];
      // Own-property check: an unknown value must not resolve to an inherited
      // member, and must not fall back to the root (which would loop forever).
      if (!Object.prototype.hasOwnProperty.call(current.nodes, value)) {
        return undefined;
      }
      current = current.nodes[value];
    }
    return current.label;
  }
}
// CommonJS export: the ID3 class is the only value exported by this statement.
module.exports = ID3;