-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
155 lines (113 loc) · 4.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from STM import SpeedTransitionMatrix
from misc import database, config
from misc.misc import plot_heatmap
import numpy as np
from scipy.spatial import distance
import math
import pandas as pd
import tensorly as ty
from tensorly.decomposition import non_negative_parafac
def get_mass_center(m, threshold=0.2):
    """Return the integer (row, column) center of mass of a 2-D matrix.

    Entries smaller than ``threshold * max(m)`` are zeroed first, the
    remaining entries are normalized to sum to one, and the expected row and
    column indices are computed from the two marginal distributions.

    Args:
        m: Non-empty 2-D array-like of weights.
        threshold: Fraction of the maximal value below which entries are
            discarded. Defaults to 0.2, i.e. 20 % of the maximum.

    Returns:
        Tuple ``(cx, cy)`` with the expected row index and expected column
        index, each truncated towards zero by ``int()``.

    Raises:
        ValueError: If ``m`` is not a non-empty 2-D array, or if no positive
            finite mass is left after filtering (e.g. an all-zero matrix).
    """
    m = np.asarray(m)
    if m.ndim != 2 or m.size == 0:
        raise ValueError("get_mass_center expects a non-empty 2-D matrix")

    # Filter: drop everything below `threshold` of the maximal value.
    cutoff = threshold * np.max(m)
    m = np.where(m < cutoff, 0, m)

    # Guard the normalization: a zero or non-finite total would otherwise turn
    # into NaN and surface later as an opaque int(NaN) conversion error.
    total = np.sum(m)
    if not np.isfinite(total) or not total > 0:
        raise ValueError("matrix has no positive finite mass after filtering")
    m = m / total

    # Marginal distributions over rows (dx) and columns (dy).
    dx = np.sum(m, 1)
    dy = np.sum(m, 0)

    # Expected values of the row and column index.
    n_rows, n_cols = m.shape
    cx = np.sum(dx * np.arange(n_rows))
    cy = np.sum(dy * np.arange(n_cols))
    return int(cx), int(cy)
def diag_dist(point):
    """Return the relative distance from ``point`` to the nearest diagonal cell.

    The smallest Euclidean distance between ``point`` and any location in
    ``config.DIAG_LOCS`` is divided by ``config.MAX_INDEX * sqrt(2) / 2`` and
    expressed as a percentage rounded to two decimals.

    Args:
        point: Coordinates accepted by ``scipy.spatial.distance.euclidean``.

    Returns:
        Relative distance in percent, rounded to two decimal places.
    """
    # For a square m x m matrix the farthest a cell can be from the diagonal
    # is half of the diagonal length.
    # NOTE(review): assumes config.MAX_INDEX is the matrix side length - confirm.
    half_diagonal = (config.MAX_INDEX * math.sqrt(2)) / 2

    to_diagonal = [distance.euclidean(loc, point) for loc in config.DIAG_LOCS]
    nearest = min(to_diagonal)
    return round(nearest / half_diagonal * 100, 2)  # Relative distance.
def get_link_ids_square(points, link_info):
    """Return the ids of links whose (x_b, y_b) lies strictly inside a rectangle.

    Args:
        points: Pair ``(p1, p2)`` of ``(x, y)`` corners, where ``p1`` is the
            upper-left corner and ``p2`` the lower-right corner.
        link_info: Table (e.g. a pandas DataFrame) exposing the columns
            ``x_b``, ``y_b`` and ``link_id``.

    Returns:
        ``link_id`` values of all rows with ``p1.x < x_b < p2.x`` and
        ``p2.y < y_b < p1.y``; rows exactly on the border are excluded.
    """
    upper_left, lower_right = points[0], points[1]

    # Strict inequalities on all four sides of the rectangle.
    inside = (
        (link_info.x_b > upper_left[0])
        & (link_info.y_b < upper_left[1])
        & (link_info.x_b < lower_right[0])
        & (link_info.y_b > lower_right[1])
    )
    return link_info[inside].link_id.values
def main():
    """Load and plot consecutive speed transition matrices for fixed links.

    Initializes the project configuration (paths, DB setup, STM setup), opens
    the 'SpeedTransitionDB' database, loads the consecutive data for a
    hard-coded list of link ids from the "spatialMatrixRWLNEWrel" collection
    and shows the plots for the winter / working-day dataset over all
    intervals. The database connection is closed even if loading or plotting
    raises.
    """
    config.initialize_paths()
    config.initialize_db_setup()
    config.initialize_stm_setup()

    db, client = database.init('SpeedTransitionDB')
    try:
        col_name = "spatialMatrixRWLNEWrel"

        # NOTE(review): the string literal below is a disabled experiment
        # (tensor decomposition of transition matrices inside a bounding box).
        # It is a no-op expression statement and is never executed.
        '''
        p1 = (15.942489, 45.781501) # upped left
        p2 = (15.961205, 45.774961) # lower right
        info = pd.read_csv(r'links_info.csv', sep=';')
        links_inside = get_link_ids_square(points=(p1, p2), link_info=info)
        c = 0
        frontal_slices = []
        temp = []
        try:
            n_intervals = 8
            for interval in range(0, n_intervals):
                for link in links_inside:
                    transitions = database.selectSome(db, col_name, {'origin_id': int(link)})
                    for tran in transitions:
                        matrix = np.array(tran['intervals'][interval]['winter']['working'])
                        if int(np.sum(matrix)) > 20:
                            temp.append(list(matrix.flatten()))
                            c += 1
                # temp = np.array(temp).reshape((400, len(temp)))
                frontal_slices.append(temp)
                temp = []
        except:
            print('Warning: There are no transitions with oringin_id: %s' % link)
        slices_length = [len(slice) for slice in frontal_slices]
        n_trans = min(slices_length)
        tensor = np.zeros((400, n_trans, 8))
        for f_slice_id in range(0, len(frontal_slices)):
            for matrix_id in range(0, len(frontal_slices[f_slice_id])):
                if matrix_id >= n_trans:
                    continue
                tensor[:, matrix_id, f_slice_id] = frontal_slices[f_slice_id][matrix_id]
        factors = non_negative_parafac(tensor=ty.tensor(tensor), rank=10, verbose=0)
        # xxx = factors.factors[0][:, 0].reshape(20, 20)
        # xxx = xxx / np.sum(xxx)
        # xxx = np.round(xxx, decimals=2)
        # plot_heatmap(xxx, 'ddd')
        #
        #
        # yyy = xxx.tolist()
        i = 0
        for column in range(0, factors.factors[0].shape[1]):
            xxx = factors.factors[0][:, column].reshape(20, 20)
            xxx = xxx / np.sum(xxx)
            xxx = np.round(xxx, decimals=2)
            plot_heatmap(xxx, 'ddd')
            #plot_heatmap(factors.factors[0][:, column].reshape(20, 20), 'Factor: ' + str(i))
            i += 1
        # char_matrices = list([])
        # chm = {'orig': xxx,
        # 'rounded': xxx,
        # 'xy_position': [i, j],
        # 'com_position': [cx, cy],
        # 'com_diag_dist': 0,
        # 'class': 0
        # }
        '''

        links = [214697, 214696, 214695, 214694]
        stm = SpeedTransitionMatrix(db=db, client=client, collection_name=col_name)
        stm.get_consecutive_data(links=links)

        # m = np.array(stm.data[2]['intervals'][4]['winter']['working'])
        # x, y = get_mass_center(m)
        # dd = diag_dist(point=(x, y))
        # print()

        stm.plot_consecutive_data(dataset_type='winter', days_type='working', intervals='all', output='show')
    finally:
        # Release the connection on every exit path, including failures in
        # data loading or plotting.
        database.closeConnection(client)
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()