-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathminimum_steiner_tree.py
121 lines (99 loc) · 3.73 KB
/
minimum_steiner_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import itertools
import numpy as np
from graph_tool import Graph, GraphView
# from graph_tool.topology import shortest_distance
from graph_tool.topology import min_spanning_tree, shortest_distance
from graph_helpers import extract_edges_from_pred, edge2tuple
def _finite_distance_mask(dists):
    """Return a boolean array marking which entries of ``dists`` are reachable.

    ``shortest_distance`` reports unreachable vertices with ``+inf`` for
    floating-point distances and with the maximum value of the integer type
    for integer distances. ``-1`` is also treated as unreachable (for
    floating-point and signed-integer distances) for compatibility.
    """
    dists = np.asarray(dists)
    if np.issubdtype(dists.dtype, np.floating):
        return np.isfinite(dists) & (dists != -1)
    mask = dists != np.iinfo(dists.dtype).max
    if np.issubdtype(dists.dtype, np.signedinteger):
        mask &= dists != -1
    return mask


def build_closure(g, terminals,
                  p=None,
                  debug=False,
                  verbose=False):
    """Build the transitive (shortest-path) closure of ``g`` on ``terminals``.

    For every terminal ``r`` a single-source shortest-path search is run on
    ``g`` (weighted by ``p`` when given) and an edge ``(r, t)`` carrying the
    shortest-path distance is recorded for every other terminal ``t`` that is
    reachable from ``r``. Unreachable terminal pairs get no closure edge.

    Parameters
    ----------
    g : graph_tool.Graph
        Input graph.
    terminals : iterable
        Terminal vertices (presumably integer vertex indices — the caller
        uses them as ``r2pred`` keys after ``int()`` conversion).
    p : edge property map of ``g``, optional
        Edge weights forwarded to ``shortest_distance``; ``None`` means
        unweighted.
    debug : bool
        Print per-root diagnostics.
    verbose : bool
        Unused; kept for interface compatibility.

    Returns
    -------
    gc : graph_tool.Graph
        Undirected graph with as many vertices as ``g``; its vertex filter
        keeps only the terminals and its edges are the closure edges
        (each terminal pair may appear in both orientations).
    eweight : edge property map of ``gc``
        Closure edge weights: ``'int'`` when ``p`` is None, ``'double'``
        otherwise so fractional distances are not truncated.
    r2pred : dict
        Maps each terminal to the predecessor map of the search rooted at it.
    """
    terminals = list(terminals)
    terminal_set = set(terminals)
    edges_with_weight = set()
    r2pred = {}  # root terminal -> predecessor map of its search
    for r in terminals:
        if debug:
            print('root {}'.format(r))
        targets = list(terminal_set - {r})
        if not targets:
            # A lone terminal has nothing to connect to; still record its
            # predecessor map so every terminal is a key of r2pred.
            _, pred_map = shortest_distance(g, source=r, weights=p,
                                            pred_map=True)
            r2pred[r] = pred_map
            continue
        dists, pred_map = shortest_distance(
            g,
            source=r,
            target=targets,
            weights=p,
            pred_map=True)
        dists = np.atleast_1d(dists)
        new_edges = {(r, t, d)
                     for t, d, ok in zip(targets, dists,
                                         _finite_distance_mask(dists))
                     if ok}
        if debug:
            print('new edges {}'.format(new_edges))
        edges_with_weight |= new_edges
        r2pred[r] = pred_map

    gc = Graph(directed=False)
    gc.add_vertex(g.num_vertices())
    # Weighted distances may be fractional; an 'int' map would truncate them
    # and could make the spanning tree built on gc pick the wrong edges.
    eweight = gc.new_edge_property('int' if p is None else 'double')
    for u, v, c in edges_with_weight:
        # Assign per edge so weights cannot get out of step with edge order.
        eweight[gc.add_edge(u, v)] = c

    # Hide every non-terminal vertex of gc.
    vfilt = gc.new_vertex_property('bool')
    vfilt.a = False
    for v in terminals:
        vfilt[v] = True
    gc.set_vertex_filter(vfilt)
    return gc, eweight, r2pred
def min_steiner_tree(g, obs_nodes, p=None, return_type='tree', debug=False, verbose=False):
    """Compute a Steiner tree of ``g`` spanning the terminals ``obs_nodes``.

    Metric-closure heuristic (the result is not guaranteed to be optimal):

    1. build the shortest-path closure on the terminals (``build_closure``);
    2. take a minimum spanning tree of the closure;
    3. expand each closure-tree edge back into graph edges via the stored
       predecessor maps (``extract_edges_from_pred``);
    4. for ``return_type='tree'``, take a minimum spanning tree of the union
       of those edges to remove any cycles.

    Parameters
    ----------
    g : graph_tool.Graph
        Input graph.
    obs_nodes : sized iterable
        Terminal vertices (must be non-empty).
    p : edge property map of ``g``, optional
        Edge weights; ``None`` means unweighted.
    return_type : {'nodes', 'edges', 'tree'}
        ``'nodes'`` returns a list of vertex ids, ``'edges'`` a list of
        ``edge2tuple`` results, ``'tree'`` a filtered undirected
        ``GraphView`` of ``g``.
    debug, verbose : bool
        Diagnostic output flags (also forwarded to ``build_closure``).

    Raises
    ------
    AssertionError
        If ``obs_nodes`` is empty, or a closure edge expands to no edges.
    ValueError
        If ``return_type`` is not one of the supported values.
    """
    assert len(obs_nodes) > 0, 'no terminals'
    if return_type not in ('nodes', 'edges', 'tree'):
        # Fail fast instead of doing all the work and returning None.
        raise ValueError(
            "return_type must be 'nodes', 'edges' or 'tree', got {!r}".format(
                return_type))
    if verbose and g.num_vertices() == len(obs_nodes):
        print('it\'s a minimum spanning tree problem')

    gc, eweight, r2pred = build_closure(g, obs_nodes, p=p,
                                        debug=debug, verbose=verbose)
    closure_tree = GraphView(gc, directed=False,
                             efilt=min_spanning_tree(gc, eweight, root=None))

    tree_edges = set()
    for e in closure_tree.edges():
        u, v = map(int, e)
        # NOTE(review): presumably the edges of the u -> v shortest path
        # recorded in u's predecessor map; confirm in graph_helpers.
        recovered_edges = extract_edges_from_pred(u, v, r2pred[u])
        assert recovered_edges, 'empty!'
        tree_edges.update((i, j) for i, j in recovered_edges)

    # Include the terminals explicitly: a lone (or isolated) terminal has no
    # incident tree edge but still belongs to the Steiner tree.
    tree_nodes = list(set(itertools.chain.from_iterable(tree_edges))
                      | {int(n) for n in obs_nodes})

    if return_type == 'nodes':
        return tree_nodes
    if return_type == 'edges':
        return list(map(edge2tuple, tree_edges))

    # return_type == 'tree'
    vfilt = g.new_vertex_property('bool')
    vfilt.set_value(False)
    for n in tree_nodes:
        vfilt[n] = True
    efilt = g.new_edge_property('bool')
    for i, j in tree_edges:
        efilt[g.edge(i, j)] = True
    subg = GraphView(g, efilt=efilt, vfilt=vfilt, directed=False)
    if p is not None:
        weights = subg.new_edge_property('float')
        for e in subg.edges():
            weights[e] = p[e]
    else:
        weights = None
    # The union of shortest paths may contain cycles; keep an MST of it.
    tree_map = min_spanning_tree(subg, weights, root=None)
    return GraphView(g, directed=False, vfilt=vfilt, efilt=tree_map)