-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.go
135 lines (116 loc) · 2.88 KB
/
training.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package bbn
import (
"fmt"
"math"
"github.com/mlange-42/bbn/ve"
)
// Trainer is a utility type to train a [Network].
//
// It accumulates observation counts (and utility sums) per CPT row and
// writes normalized tables back into the network via UpdateNetwork.
type Trainer struct {
	network *Network   // the network being trained; its factor tables are overwritten on UpdateNetwork
	data    [][][]float64 // per node: [row][outcome] accumulated counts, or [row][0] summed utility for utility nodes
	counter [][]int    // per node: number of samples counted for each CPT row (normalization denominator)
	indices [][]int    // per node: indices of its parents in the network's variable order
	sample  []int      // scratch buffer for a node's parent outcomes (reused across AddSample calls)
	utility []float64  // scratch buffer for a node's parent utilities (reused across AddSample calls)
}
// NewTrainer creates a new [Trainer] for the given [Network].
//
// It allocates accumulation tables matching the shape of each node's factor
// table (rows = parent configurations, columns = outcomes) and resolves each
// node's parents to their index in the network's variable order.
// Panics if a node references a parent that is not present in the network.
func NewTrainer(net *Network) Trainer {
	nodes := net.Variables()

	data := make([][][]float64, len(nodes))
	counter := make([][]int, len(nodes))
	indices := make([][]int, len(nodes))
	nodeIndices := make(map[string]int, len(nodes))

	// First pass: map node names to their position, and find the widest
	// parent set to size the scratch buffers used by AddSample.
	maxColumns := 0
	for i, node := range nodes {
		nodeIndices[node.Name] = i
		if len(node.Factor.Given) > maxColumns {
			maxColumns = len(node.Factor.Given)
		}
	}

	// Second pass: allocate per-node accumulators and resolve parent indices.
	for i, node := range nodes {
		columns := len(node.Outcomes)
		rows := len(node.Factor.Table) / columns

		d := make([][]float64, rows)
		for j := 0; j < rows; j++ {
			d[j] = make([]float64, columns)
		}
		data[i] = d
		counter[i] = make([]int, rows)

		idx := make([]int, len(node.Factor.Given))
		// Distinct loop variable g: the original shadowed the outer i here,
		// which is error-prone even though the outer i was not used inside.
		for g, n := range node.Factor.Given {
			var ok bool
			idx[g], ok = nodeIndices[n]
			if !ok {
				panic(fmt.Sprintf("parent node %s for %s not found", n, node.Name))
			}
		}
		indices[i] = idx
	}

	return Trainer{
		network: net,
		data:    data,
		counter: counter,
		indices: indices,
		sample:  make([]int, 0, maxColumns),
		utility: make([]float64, 0, maxColumns),
	}
}
// AddSample adds a training sample.
// Order of values in the sample is the same as the order in which nodes were
// passed into the [Network] constructor.
//
// Negative values in sample mark unobserved outcomes and are skipped for the
// affected node. For utility nodes, NaN utility values are skipped.
// utility may be nil if the network contains no utility nodes — TODO confirm
// callers never pass nil when utility nodes are present (utility[i] would panic).
func (t *Trainer) AddSample(sample []int, utility []float64) {
	nodes := t.network.Variables()
	for i, node := range nodes {
		// Decision nodes have no table to train.
		if node.NodeType == ve.DecisionNode {
			continue
		}

		// Gather this node's parent outcomes into the reusable scratch buffer.
		indices := t.indices[i]
		t.sample = t.sample[:0]
		for _, idx := range indices {
			t.sample = append(t.sample, sample[idx])
		}
		if utility != nil {
			t.utility = t.utility[:0]
			for _, idx := range indices {
				t.utility = append(t.utility, utility[idx])
			}
		}

		// Locate the table row for this parent configuration; skip the node
		// if any parent outcome is missing.
		idx, ok := node.Factor.rowIndex(t.sample)
		if !ok {
			continue
		}

		if node.NodeType == ve.UtilityNode {
			u := utility[i]
			if math.IsNaN(u) {
				continue
			}
			t.data[i][idx][0] += u
		} else {
			s := sample[i]
			if s < 0 {
				continue
			}
			t.data[i][idx][s]++
		}
		// Count the sample exactly once per row. The removed extra increment
		// inside the utility branch double-counted utility samples, which
		// halved their mean in UpdateNetwork.
		t.counter[i][idx]++
	}
}
// UpdateNetwork applies the training to the network, and returns a pointer
// to the original network.
//
// Each factor-table row is normalized by the number of samples counted for
// it: probability rows become relative frequencies, utility rows become mean
// utilities. Returns an error if any table row received no samples, since
// its entries would be undefined (division by zero).
func (t *Trainer) UpdateNetwork() (*Network, error) {
	nodes := t.network.Variables()
	for i, node := range nodes {
		data := t.data[i]
		cols := node.Factor.columns
		rows := len(node.Factor.Table) / cols
		for j := 0; j < rows; j++ {
			cnt := t.counter[i][j]
			if cnt == 0 {
				return nil, fmt.Errorf("no samples for node '%s', table row %d", node.Name, j)
			}
			for k := 0; k < cols; k++ {
				// data is already float64; only the count needs converting.
				node.Factor.Table[j*cols+k] = data[j][k] / float64(cnt)
			}
		}
	}
	return t.network, nil
}