-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunion_find.py
96 lines (74 loc) · 2.34 KB
/
union_find.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
# https://www.cs.princeton.edu/~rs/AlgsDS07/01UnionFind.pdf
Union Find algorithms
"""
class UnionFindBase(object):
    """Shared state for the union-find variants in this module.

    ``self.id`` is a list of length ``n``. It is initialised so that every
    site ``0 .. n-1`` starts in its own component (``id[i] == i``).
    Subclasses implement ``find`` (connectivity query) and ``unite`` (merge).
    """

    def __init__(self, n):
        # Builds [0, 1, ..., n-1] directly instead of allocating and filling.
        self.id = list(range(n))

    # Abstract method, defined by convention only
    def find(self, p, q):
        """Report whether sites ``p`` and ``q`` are connected.

        The base class always raises NotImplementedError; subclasses override.
        """
        raise NotImplementedError("Subclass must implement abstract method")

    # Abstract method, defined by convention only
    def unite(self, p, q):
        """Merge the components containing ``p`` and ``q``.

        The base class always raises NotImplementedError; subclasses override.
        """
        raise NotImplementedError("Subclass must implement abstract method")
class QuickFind(UnionFindBase):
    """Quick-find: ``id[i]`` is the component label of site ``i``.

    ``find`` is a single comparison. ``unite`` relabels every site in
    ``p``'s component, which scans the whole array (O(N) per call).
    """

    def __init__(self, n):
        super(QuickFind, self).__init__(n)

    # 1 operation
    def find(self, p, q):
        """Return True if ``p`` and ``q`` carry the same component label."""
        return self.id[p] == self.id[q]

    # N operations
    def unite(self, p, q):
        """Merge p's component into q's by giving p's sites q's label."""
        pid = self.id[p]
        # Read q's label once; it cannot change during the relabel pass.
        qid = self.id[q]
        if pid == qid:
            # Already connected: relabelling would be a no-op, so skip the scan.
            return
        for i, label in enumerate(self.id):
            if label == pid:
                self.id[i] = qid
class QuickUnion(UnionFindBase):
    """Quick-union: ``id`` is a parent-pointer forest.

    A site is a root exactly when it is its own parent. Two sites are
    connected when following parent links from each reaches the same root.
    """

    def __init__(self, n):
        super(QuickUnion, self).__init__(n)

    def root(self, i):
        """Follow parent links starting at ``i`` and return the root reached."""
        parent = self.id
        node = i
        while parent[node] != node:
            node = parent[node]
        return node

    def find(self, p, q):
        """Return True when ``p`` and ``q`` lead to the same root."""
        # root(p) is evaluated before root(q), as in the original;
        # subclasses may override root() with a mutating version.
        p_root = self.root(p)
        return p_root == self.root(q)

    def unite(self, p, q):
        """Connect p and q by hanging p's root beneath q's root."""
        p_root, q_root = self.root(p), self.root(q)
        self.id[p_root] = q_root
class WeightedQuickUnion(QuickUnion):
    """Quick-union that always links the smaller tree under the larger one.

    ``sz[r]`` is the number of sites in the tree rooted at ``r``. It is
    only updated, and therefore only meaningful, for roots.
    """

    def __init__(self, n):
        super(WeightedQuickUnion, self).__init__(n)
        # Every site starts as a singleton tree of size 1.
        self.sz = [1] * n

    def unite(self, p, q):
        """Connect p and q, attaching the smaller tree's root to the larger's.

        On equal sizes, q's root is placed under p's root.
        """
        i = self.root(p)
        j = self.root(q)
        if i == j:
            # Already in the same tree. Without this guard the else branch
            # would add sz[i] to itself, doubling the recorded size and
            # skewing later weighting decisions.
            return
        if self.sz[i] < self.sz[j]:
            self.id[i] = j
            self.sz[j] += self.sz[i]
        else:
            self.id[j] = i
            self.sz[i] += self.sz[j]
class WeightedQuickUnionWithPathCompression(WeightedQuickUnion):
    """Weighted quick-union whose root lookup also flattens the tree.

    While walking toward the root, each visited node is re-pointed at its
    grandparent (path halving), so later lookups traverse shorter paths.
    """

    def __init__(self, n):
        super(WeightedQuickUnionWithPathCompression, self).__init__(n)

    def root(self, i):
        """Return the root of ``i``, halving the path along the way."""
        parent = self.id
        while parent[i] != i:
            grandparent = parent[parent[i]]
            parent[i] = grandparent
            # Continue from the grandparent, skipping one level.
            i = grandparent
        return i
if __name__ == '__main__':
    # Smoke-run each simpler variant on 10 sites by connecting sites 1 and 3.
    for impl in (QuickFind, QuickUnion, WeightedQuickUnion):
        impl(10).unite(1, 3)

    # Only the path-compression variant's results are reported.
    wqnpc = WeightedQuickUnionWithPathCompression(10)
    wqnpc.unite(1, 3)
    print(wqnpc.find(1, 3), wqnpc.find(1, 4))