-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsumtree.py
51 lines (44 loc) · 1.63 KB
/
sumtree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Node:
    """One node of a binary sum tree.

    An internal node stores the sum of its two children's ``value`` and
    registers itself as their ``parent``. A leaf stores its own ``value`` and
    the position ``idx`` of the input it represents.
    """

    def __init__(self, left, right, is_leaf: bool = False, idx: "int | None" = None):
        """Create a node and link ``left``/``right`` (when given) to it.

        Args:
            left: Left child node, or ``None`` for a leaf.
            right: Right child node, or ``None`` for a leaf.
            is_leaf: Whether this node is a leaf. Internal nodes require both
                children, because their value is computed from them.
            idx: Index of the input value a leaf represents; ``None`` for
                internal nodes built by the tree.
        """
        self.left = left
        self.right = right
        self.is_leaf = is_leaf
        # Leaves start at 0.0 so ``value`` always exists; create_leaf overwrites
        # it. Internal nodes cache the sum of their children.
        self.value = 0.0 if is_leaf else left.value + right.value
        self.parent = None
        self.idx = idx
        # Back-links let updates walk from a leaf up to the root.
        for child in (left, right):
            if child is not None:
                child.parent = self

    @classmethod
    def create_leaf(cls, value: float, idx: int):
        """Return a parentless leaf holding ``value`` for input position ``idx``."""
        leaf = cls(None, None, is_leaf=True, idx=idx)
        leaf.value = value
        return leaf

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(idx={self.idx!r}, value={self.value!r}, "
            f"is_leaf={self.is_leaf!r})"
        )
class SumTree:
    """Binary sum tree built over a list of weights.

    Each leaf holds one input value and every internal node holds the sum of
    its two children, so ``root_node.value`` is the total of all leaves.
    ``get_node`` descends from the root to the leaf whose cumulative-sum
    interval contains a query value. This interval search assumes the weights
    are non-negative (TODO confirm with callers). ``update_node`` changes a
    leaf and applies the difference to all of its ancestors.

    Attributes:
        root_node: Top of the tree; its value is the sum of every leaf.
        leaf_nodes: Leaves in input order (``leaf_nodes[i].idx == i``).
    """

    def __init__(self, inputs: list):
        self.root_node, self.leaf_nodes = self.create_tree(inputs)

    @staticmethod
    def create_tree(input: list):
        """Build the tree bottom-up and return ``(root, leaves)``.

        Adjacent nodes are paired level by level. When a level has an odd
        number of nodes, the last one is carried up unchanged to the next
        level, so every leaf stays reachable from the root and contributes to
        its sum.

        Raises:
            ValueError: If ``input`` is empty.
        """
        leaves = [Node.create_leaf(v, i) for i, v in enumerate(input)]
        if not leaves:
            raise ValueError("SumTree requires at least one input value")
        level = leaves
        while len(level) > 1:
            parents = [Node(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
            if len(level) % 2:
                # Pairing with zip() would silently drop this node (and every
                # leaf under it); carry it up to the next level instead.
                parents.append(level[-1])
            level = parents
        return level[0], leaves

    def get_node(self, value: float, node: "Node | None" = None):
        """Return the leaf whose cumulative-sum interval contains ``value``.

        Starting at ``node`` (the root when omitted), the search goes left while
        ``value`` is at most the left subtree's sum (ties go left). Otherwise it
        subtracts that sum and goes right. A value above the total ends at the
        rightmost leaf.
        """
        if node is None:
            node = self.root_node
        while not node.is_leaf:
            if node.left.value >= value:
                node = node.left
            else:
                value -= node.left.value
                node = node.right
        return node

    def update_node(self, node: Node, new_value: float):
        """Set ``node.value`` to ``new_value`` and fix every ancestor's sum.

        Intended for leaves. Updating an internal node directly makes its
        value disagree with the sum of its children.
        """
        change = new_value - node.value
        node.value = new_value
        self.propagate_changes(change, node.parent)

    def propagate_changes(self, change: float, node: "Node | None"):
        """Add ``change`` to ``node`` and to each of its ancestors.

        ``node`` may be ``None`` (e.g. the parent of a single-leaf root); in
        that case nothing happens.
        """
        while node is not None:
            node.value += change
            node = node.parent