-
Notifications
You must be signed in to change notification settings - Fork 0
/
global_settings.py
117 lines (99 loc) · 3.34 KB
/
global_settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
global settings
"""
import sys
import os
from collections import OrderedDict, defaultdict
import union_find
def get_chattering_species(data_dir, atom_followed="C"):
    """
    Return the chattering-species table from the data directory's local settings.

    Adds ``<data_dir>/input`` to ``sys.path`` (only once) and delegates to
    ``local_settings.get_chattering_species(atom_followed)``.

    Parameters
    ----------
    data_dir : str
        Root data directory; ``local_settings.py`` is looked up in its
        ``input`` sub-directory.
    atom_followed : str, optional
        Passed through unchanged to ``local_settings.get_chattering_species``.

    Returns
    -------
    Whatever ``local_settings.get_chattering_species`` returns, or an empty
    ``OrderedDict`` when ``local_settings`` cannot be imported or raises
    ``IOError``.
    """
    input_dir = os.path.join(data_dir, "input")
    # Avoid growing sys.path with a duplicate entry on every call.
    if input_dir not in sys.path:
        sys.path.append(input_dir)
    try:
        import local_settings
    except (ImportError, IOError):
        # A missing local_settings.py raises ImportError (not IOError), so the
        # fallback must catch it for the "no local settings" case to work.
        return OrderedDict()
    # NOTE(review): Python caches the first ``local_settings`` import, so a
    # later call with a different data_dir reuses that module — confirm callers
    # only use one data_dir per process.
    try:
        return local_settings.get_chattering_species(atom_followed)
    except IOError:
        return OrderedDict()
def get_union_find_group(data_dir, atom_followed="C"):
    """
    Group chattering species into connected components.

    Each value of ``get_chattering_species(data_dir, atom_followed)`` is read
    as a pair of int-convertible species indices (elements ``[0]`` and
    ``[1]``) that belong together. Pairs that share a species are merged with
    a union-find structure.

    Parameters
    ----------
    data_dir : str
        Data directory forwarded to ``get_chattering_species``.
    atom_followed : str, optional
        Forwarded to ``get_chattering_species``.

    Returns
    -------
    collections.defaultdict(set)
        Maps ``str(species_index)`` to the set of integer species indices in
        that species' group. All members of one group share the same set
        object.
    """
    chattering_species = get_chattering_species(data_dir, atom_followed)
    pairs = [(int(pair[0]), int(pair[1]))
             for pair in chattering_species.values()]

    # The union-find structure works on contiguous labels 0..n-1, so give each
    # distinct species a dense label in first-seen order.
    spe_idx_label = {}
    label_spe_idx = []
    for pair in pairs:
        for spe_idx in pair:
            if spe_idx not in spe_idx_label:
                spe_idx_label[spe_idx] = len(label_spe_idx)
                label_spe_idx.append(spe_idx)

    wqnpc = union_find.WeightedQuickUnionWithPathCompression(len(label_spe_idx))
    for idx1, idx2 in pairs:
        wqnpc.unite(spe_idx_label[idx1], spe_idx_label[idx2])

    # root label -> set of species indices in that component
    unique_labels_group = defaultdict(set)
    for label, spe_idx in enumerate(label_spe_idx):
        unique_labels_group[wqnpc.root(label)].add(spe_idx)

    # species index (as str) -> the group (shared set) it belongs to
    idx_group = defaultdict(set)
    for label, spe_idx in enumerate(label_spe_idx):
        idx_group[str(spe_idx)] = unique_labels_group[wqnpc.root(label)]
    return idx_group
def get_setting(data_dir):
    """
    Return the global settings from the data directory's local settings.

    Adds ``<data_dir>/input`` to ``sys.path`` (only once) and delegates to
    ``local_settings.get_local_settings()``.

    Parameters
    ----------
    data_dir : str
        Root data directory; ``local_settings.py`` is looked up in its
        ``input`` sub-directory.

    Returns
    -------
    Whatever ``local_settings.get_local_settings`` returns, or an empty dict
    when ``local_settings`` cannot be imported or raises ``IOError``.
    """
    input_dir = os.path.join(data_dir, "input")
    # Avoid growing sys.path with a duplicate entry on every call.
    if input_dir not in sys.path:
        sys.path.append(input_dir)
    try:
        import local_settings
    except (ImportError, IOError):
        # A missing local_settings.py raises ImportError, not IOError.
        return {}
    try:
        return local_settings.get_local_settings()
    except IOError:
        return {}
def get_s_a_setting(data_dir):
    """
    Return the sensitivity-analysis settings from the data directory's local settings.

    Adds ``<data_dir>/input`` to ``sys.path`` (only once) and delegates to
    ``local_settings.get_s_a_setting()``.

    Parameters
    ----------
    data_dir : str
        Root data directory; ``local_settings.py`` is looked up in its
        ``input`` sub-directory.

    Returns
    -------
    Whatever ``local_settings.get_s_a_setting`` returns, or an empty dict
    when ``local_settings`` cannot be imported or raises ``IOError``.
    """
    input_dir = os.path.join(data_dir, "input")
    # Avoid growing sys.path with a duplicate entry on every call.
    if input_dir not in sys.path:
        sys.path.append(input_dir)
    try:
        import local_settings
    except (ImportError, IOError):
        # A missing local_settings.py raises ImportError, not IOError.
        return {}
    try:
        return local_settings.get_s_a_setting()
    except IOError:
        return {}
if __name__ == '__main__':
    # Smoke run: locate SOHR_DATA by stepping up four '..' components from the
    # real path of the running script (the first '..' strips the file name),
    # then build the chattering-species groups for carbon.
    DATA_DIR = os.path.abspath(os.path.join(
        os.path.realpath(sys.argv[0]),
        *([os.pardir] * 4 + ["SOHR_DATA"])))
    get_union_find_group(DATA_DIR, atom_followed="C")
    print("test")