-
Notifications
You must be signed in to change notification settings - Fork 0
/
segment_tree.go
88 lines (69 loc) · 1.7 KB
/
segment_tree.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package segmenttree
import (
"errors"
"fmt"
"math"
)
// SegmentTree answers range-sum queries over a fixed-length sequence of ints
// and supports point updates.
type SegmentTree interface {
	// GetSum returns the sum of the elements in the half-open index range
	// [findL, findR). Bounds outside [0, length) are clamped; an empty or
	// inverted range yields 0.
	GetSum(findL, findR int) int
	// Set replaces the element at index with value. It returns an error
	// wrapping ErrIndexOutOfRange when index is not in [0, length).
	Set(index int, value int) error
}

// ErrIndexOutOfRange is wrapped by the error Set returns for an index outside
// the tree's bounds; test for it with errors.Is.
var ErrIndexOutOfRange = errors.New("index out of range error")

// Compile-time check that *segmentTree implements SegmentTree.
var _ SegmentTree = (*segmentTree)(nil)

// segmentTree is an array-backed binary segment tree. Node i has children
// 2i+1 and 2i+2; the root (node 0) covers the half-open range [0, count), and
// every internal node covering [l, r) splits it at mid = l + (r-l)/2.
type segmentTree struct {
	tree  []int // per-node sums, laid out as described above
	count int   // number of leaf elements
}

// NewSegmentTree builds a SegmentTree holding the elements of values. The
// values are copied into the tree, so later changes to the slice do not
// affect it. An empty or nil slice yields an empty tree whose sums are 0 and
// on which every Set fails.
func NewSegmentTree(values []int) SegmentTree {
	n := len(values)
	if n == 0 {
		// Log2(0) is -Inf, which would yield a negative slice length below
		// (and makeTree would never terminate on the empty range [0, 0)).
		return &segmentTree{}
	}
	// Midpoint splitting gives a tree of depth ceil(log2 n), so
	// 2*2^ceil(log2 n) - 1 nodes cover every index the recursion touches.
	height := math.Ceil(math.Log2(float64(n)))
	treeLength := int(math.Pow(2, height))*2 - 1
	t := &segmentTree{tree: make([]int, treeLength), count: n}
	t.makeTree(0, 0, n, values)
	return t
}

// makeTree fills node, which covers [l, r), and its whole subtree from
// values. It requires l < r.
func (t *segmentTree) makeTree(node int, l, r int, values []int) {
	if l+1 == r {
		t.tree[node] = values[l]
		return
	}
	mid := l + (r-l)/2
	t.makeTree(node*2+1, l, mid, values)
	t.makeTree(node*2+2, mid, r, values)
	t.tree[node] = t.tree[node*2+1] + t.tree[node*2+2]
}

// GetSum returns the sum of the elements in the half-open range
// [findL, findR). Bounds are clamped to [0, count); an empty or inverted range
// returns 0.
func (t *segmentTree) GetSum(findL, findR int) int {
	if findL < 0 {
		findL = 0
	}
	// Clamp against the element count, not len(t.tree): the backing slice
	// also stores internal nodes and says nothing about the logical length.
	if findR > t.count {
		findR = t.count
	}
	if findL >= findR {
		return 0
	}
	return t.getSum(0, 0, t.count, findL, findR)
}

// getSum returns the sum of the part of [findL, findR) that lies inside the
// range [l, r) covered by node.
func (t *segmentTree) getSum(node int, l, r int, findL, findR int) int {
	if findL >= r || findR <= l {
		return 0 // query and node range are disjoint
	}
	if findL <= l && r <= findR {
		return t.tree[node] // node range fully inside the query
	}
	mid := l + (r-l)/2
	return t.getSum(node*2+1, l, mid, findL, findR) +
		t.getSum(node*2+2, mid, r, findL, findR)
}

// Set replaces the element at index with value and refreshes the sums of all
// ancestor nodes. It returns an error wrapping ErrIndexOutOfRange when index
// is not in [0, count).
func (t *segmentTree) Set(index int, value int) error {
	if index < 0 || index >= t.count {
		return fmt.Errorf("%w: %d of %d", ErrIndexOutOfRange, index, t.count)
	}
	t.set(0, 0, t.count, index, value)
	return nil
}

// set writes value at index within node's range [l, r) and recomputes the
// sums on the path back up. It requires l <= index < r, which Set guarantees,
// so only the child containing index needs to be visited.
func (t *segmentTree) set(node int, l, r int, index int, value int) {
	if l+1 == r {
		t.tree[node] = value
		return
	}
	mid := l + (r-l)/2
	if index < mid {
		t.set(node*2+1, l, mid, index, value)
	} else {
		t.set(node*2+2, mid, r, index, value)
	}
	t.tree[node] = t.tree[node*2+1] + t.tree[node*2+2]
}