forked from jsyoon0823/TimeGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
timegan.py
310 lines (242 loc) · 11.9 KB
/
timegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""Time-series Generative Adversarial Networks (TimeGAN) Codebase.
Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
"Time-series Generative Adversarial Networks,"
Neural Information Processing Systems (NeurIPS), 2019.
Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks
Last updated Date: April 24th 2020
Code author: Jinsung Yoon (jsyoon0823@gmail.com)
-----------------------------
timegan.py
Note: Uses the original data as the training set to generate synthetic data (time-series).
"""
# Necessary Packages
import tensorflow as tf
import numpy as np
from tf_slim.layers import layers as slim_layers
from utils import extract_time, rnn_cell, random_generator, batch_generator
def timegan (ori_data, parameters):
    """TimeGAN function.

    Use the original data as the training set to generate synthetic data
    (time-series).

    Args:
      - ori_data: original time-series data. ``np.asarray(ori_data).shape`` is
        unpacked into (no, seq_len, dim), so it must form a rectangular 3-D
        array.
      - parameters: TimeGAN network parameters; the keys read are
        'hidden_dim', 'num_layer', 'iterations', 'batch_size' and 'module'.

    Returns:
      - generated_data: generated time-series data (one sequence per original
        sample), mapped back to the value range of ori_data using the
        per-feature min/range computed during normalization.
    """
    # Initialization on the Graph
    tf.compat.v1.reset_default_graph()

    # Basic Parameters
    no, seq_len, dim = np.asarray(ori_data).shape

    # Maximum sequence length and each sequence length
    ori_time, max_seq_len = extract_time(ori_data)

    def MinMaxScaler(data):
        """Min-Max Normalizer.

        Per-feature statistics are reduced over the first two axes
        (samples, then time steps).

        Args:
          - data: raw data

        Returns:
          - norm_data: normalized data
          - min_val: per-feature minimum (for renormalization)
          - max_val: per-feature range, i.e. max - min (for renormalization)
        """
        min_val = np.min(np.min(data, axis = 0), axis = 0)
        data = data - min_val
        max_val = np.max(np.max(data, axis = 0), axis = 0)
        # Epsilon avoids division by zero for constant features.
        norm_data = data / (max_val + 1e-7)
        return norm_data, min_val, max_val

    # Normalization
    ori_data, min_val, max_val = MinMaxScaler(ori_data)

    ## Build a RNN networks
    # Network Parameters
    hidden_dim = parameters['hidden_dim']
    num_layers = parameters['num_layer']
    iterations = parameters['iterations']
    batch_size = parameters['batch_size']
    module_name = parameters['module']
    z_dim = dim
    gamma = 1

    # Input place holders (placeholders and sessions need graph mode)
    tf.compat.v1.disable_eager_execution()
    X = tf.compat.v1.placeholder(tf.float32, [None, max_seq_len, dim], name = "myinput_x")
    Z = tf.compat.v1.placeholder(tf.float32, [None, max_seq_len, z_dim], name = "myinput_z")
    T = tf.compat.v1.placeholder(tf.int32, [None], name = "myinput_t")

    def embedder (X, T):
        """Embedding network between original feature space to latent space.

        Args:
          - X: input time-series features
          - T: input time information (per-sequence lengths fed to dynamic_rnn)

        Returns:
          - H: embeddings
        """
        with tf.compat.v1.variable_scope("embedder", reuse = tf.compat.v1.AUTO_REUSE):
            e_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)])
            e_outputs, e_last_states = tf.compat.v1.nn.dynamic_rnn(e_cell, X, dtype=tf.float32, sequence_length=T)
            H = slim_layers.fully_connected(e_outputs, hidden_dim, activation_fn=tf.nn.sigmoid)
        return H

    def recovery (H, T):
        """Recovery network from latent space to original space.

        Args:
          - H: latent representation
          - T: input time information

        Returns:
          - X_tilde: recovered data
        """
        with tf.compat.v1.variable_scope("recovery", reuse = tf.compat.v1.AUTO_REUSE):
            r_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)])
            r_outputs, r_last_states = tf.compat.v1.nn.dynamic_rnn(r_cell, H, dtype=tf.float32, sequence_length = T)
            X_tilde = slim_layers.fully_connected(r_outputs, dim, activation_fn=tf.nn.sigmoid)
        return X_tilde

    def generator (Z, T):
        """Generator function: Generate time-series data in latent space.

        Args:
          - Z: random variables
          - T: input time information

        Returns:
          - E: generated embedding
        """
        with tf.compat.v1.variable_scope("generator", reuse = tf.compat.v1.AUTO_REUSE):
            e_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)])
            e_outputs, e_last_states = tf.compat.v1.nn.dynamic_rnn(e_cell, Z, dtype=tf.float32, sequence_length = T)
            E = slim_layers.fully_connected(e_outputs, hidden_dim, activation_fn=tf.nn.sigmoid)
        return E

    def supervisor (H, T):
        """Generate next sequence using the previous sequence.

        Args:
          - H: latent representation
          - T: input time information

        Returns:
          - S: generated sequence based on the latent representations generated by the generator
        """
        with tf.compat.v1.variable_scope("supervisor", reuse = tf.compat.v1.AUTO_REUSE):
            # The supervisor is one layer shallower than the other networks,
            # but MultiRNNCell needs at least one cell, so num_layer == 1
            # still gets a single layer instead of failing.
            n_sup_layers = max(num_layers - 1, 1)
            e_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(n_sup_layers)])
            e_outputs, e_last_states = tf.compat.v1.nn.dynamic_rnn(e_cell, H, dtype=tf.float32, sequence_length = T)
            S = slim_layers.fully_connected(e_outputs, hidden_dim, activation_fn=tf.nn.sigmoid)
        return S

    def discriminator (H, T):
        """Discriminate the original and synthetic time-series data.

        Args:
          - H: latent representation
          - T: input time information

        Returns:
          - Y_hat: classification logits (no activation) between original and synthetic time-series
        """
        with tf.compat.v1.variable_scope("discriminator", reuse = tf.compat.v1.AUTO_REUSE):
            d_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)])
            d_outputs, d_last_states = tf.compat.v1.nn.dynamic_rnn(d_cell, H, dtype=tf.float32, sequence_length = T)
            Y_hat = slim_layers.fully_connected(d_outputs, 1, activation_fn=None)
        return Y_hat

    # Embedder & Recovery
    H = embedder(X, T)
    X_tilde = recovery(H, T)

    # Generator
    E_hat = generator(Z, T)
    H_hat = supervisor(E_hat, T)
    H_hat_supervise = supervisor(H, T)

    # Synthetic data
    X_hat = recovery(H_hat, T)

    # Discriminator
    Y_fake = discriminator(H_hat, T)
    Y_real = discriminator(H, T)
    Y_fake_e = discriminator(E_hat, T)

    # Variables (selected by the variable_scope name prefixes above)
    e_vars = [v for v in tf.compat.v1.trainable_variables() if v.name.startswith('embedder')]
    r_vars = [v for v in tf.compat.v1.trainable_variables() if v.name.startswith('recovery')]
    g_vars = [v for v in tf.compat.v1.trainable_variables() if v.name.startswith('generator')]
    s_vars = [v for v in tf.compat.v1.trainable_variables() if v.name.startswith('supervisor')]
    d_vars = [v for v in tf.compat.v1.trainable_variables() if v.name.startswith('discriminator')]

    # Discriminator loss
    D_loss_real = tf.compat.v1.losses.sigmoid_cross_entropy(tf.ones_like(Y_real), Y_real)
    D_loss_fake = tf.compat.v1.losses.sigmoid_cross_entropy(tf.zeros_like(Y_fake), Y_fake)
    D_loss_fake_e = tf.compat.v1.losses.sigmoid_cross_entropy(tf.zeros_like(Y_fake_e), Y_fake_e)
    D_loss = D_loss_real + D_loss_fake + gamma * D_loss_fake_e

    # Generator loss
    # 1. Adversarial loss
    G_loss_U = tf.compat.v1.losses.sigmoid_cross_entropy(tf.ones_like(Y_fake), Y_fake)
    G_loss_U_e = tf.compat.v1.losses.sigmoid_cross_entropy(tf.ones_like(Y_fake_e), Y_fake_e)

    # 2. Supervised loss: the supervisor's output at step t (which has seen
    # H up to t) must predict H at step t+1. Comparing against H at the
    # same step would let it learn the identity map.
    G_loss_S = tf.compat.v1.losses.mean_squared_error(H[:,1:,:], H_hat_supervise[:,:-1,:])

    # 3. Two moments (match per-position std and mean over the batch axis)
    G_loss_V1 = tf.reduce_mean(tf.abs(tf.sqrt(tf.nn.moments(X_hat,[0])[1] + 1e-6) - tf.sqrt(tf.nn.moments(X,[0])[1] + 1e-6)))
    G_loss_V2 = tf.reduce_mean(tf.abs((tf.nn.moments(X_hat,[0])[0]) - (tf.nn.moments(X,[0])[0])))
    G_loss_V = G_loss_V1 + G_loss_V2

    # 4. Summation
    G_loss = G_loss_U + gamma * G_loss_U_e + 100 * tf.sqrt(G_loss_S) + 100*G_loss_V

    # Embedder network loss. Use the TF1-compat (scalar) MSE like every other
    # loss here; tf.losses in TF2 is the Keras loss, which returns a
    # per-timestep tensor and would make E_loss0/E_loss non-scalar.
    E_loss_T0 = tf.compat.v1.losses.mean_squared_error(X, X_tilde)
    E_loss0 = 10*tf.sqrt(E_loss_T0)
    E_loss = E_loss0 + 0.1*G_loss_S

    # optimizer
    E0_solver = tf.compat.v1.train.AdamOptimizer().minimize(E_loss0, var_list = e_vars + r_vars)
    E_solver = tf.compat.v1.train.AdamOptimizer().minimize(E_loss, var_list = e_vars + r_vars)
    D_solver = tf.compat.v1.train.AdamOptimizer().minimize(D_loss, var_list = d_vars)
    G_solver = tf.compat.v1.train.AdamOptimizer().minimize(G_loss, var_list = g_vars + s_vars)
    GS_solver = tf.compat.v1.train.AdamOptimizer().minimize(G_loss_S, var_list = g_vars + s_vars)

    ## TimeGAN training
    # The context manager closes the session (releasing its resources) even
    # if training raises.
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        # 1. Embedding network training
        print('Start Embedding Network Training')

        for itt in range(iterations):
            # Set mini-batch
            X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)
            # Train embedder
            _, step_e_loss = sess.run([E0_solver, E_loss_T0], feed_dict={X: X_mb, T: T_mb})
            # Checkpoint
            if itt % 1000 == 0:
                print('step: '+ str(itt) + '/' + str(iterations) + ', e_loss: ' + str(np.round(np.sqrt(step_e_loss.mean()),4)) )

        print('Finish Embedding Network Training')

        # 2. Training only with supervised loss
        print('Start Training with Supervised Loss Only')

        for itt in range(iterations):
            # Set mini-batch
            X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)
            # Random vector generation
            Z_mb = random_generator(batch_size, z_dim, T_mb, max_seq_len)
            # Train generator
            _, step_g_loss_s = sess.run([GS_solver, G_loss_S], feed_dict={Z: Z_mb, X: X_mb, T: T_mb})
            # Checkpoint
            if itt % 1000 == 0:
                print('step: '+ str(itt) + '/' + str(iterations) +', s_loss: ' + str(np.round(np.sqrt(step_g_loss_s.mean()),4)) )

        print('Finish Training with Supervised Loss Only')

        # 3. Joint Training
        print('Start Joint Training')

        for itt in range(iterations):
            # Generator training (twice more than discriminator training)
            for _ in range(2):
                # Set mini-batch
                X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)
                # Random vector generation
                Z_mb = random_generator(batch_size, z_dim, T_mb, max_seq_len)
                # Train generator
                _, step_g_loss_u, step_g_loss_s, step_g_loss_v = sess.run([G_solver, G_loss_U, G_loss_S, G_loss_V], feed_dict={Z: Z_mb, X: X_mb, T: T_mb})
                # Train embedder
                _, step_e_loss_t0 = sess.run([E_solver, E_loss_T0], feed_dict={Z: Z_mb, X: X_mb, T: T_mb})

            # Discriminator training
            # Set mini-batch
            X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)
            # Random vector generation
            Z_mb = random_generator(batch_size, z_dim, T_mb, max_seq_len)
            # Check discriminator loss before updating
            check_d_loss = sess.run(D_loss, feed_dict={X: X_mb, T: T_mb, Z: Z_mb})
            # Always bind step_d_loss so the checkpoint below reports the
            # current value even when the update is skipped (otherwise it is
            # unbound at itt == 0 or stale later on).
            step_d_loss = check_d_loss
            # Train discriminator (only when the discriminator does not work well)
            if check_d_loss > 0.15:
                _, step_d_loss = sess.run([D_solver, D_loss], feed_dict={X: X_mb, T: T_mb, Z: Z_mb})

            # Print multiple checkpoints
            if itt % 1000 == 0:
                print('step: '+ str(itt) + '/' + str(iterations) +
                      ', d_loss: ' + str(np.round(step_d_loss.mean(),4)) +
                      ', g_loss_u: ' + str(np.round(step_g_loss_u.mean(),4)) +
                      ', g_loss_s: ' + str(np.round(np.sqrt(step_g_loss_s.mean()),4)) +
                      ', g_loss_v: ' + str(np.round(step_g_loss_v.mean(),4)) +
                      ', e_loss_t0: ' + str(np.round(np.sqrt(step_e_loss_t0.mean()),4)) )

        print('Finish Joint Training')

        ## Synthetic data generation
        Z_mb = random_generator(no, z_dim, ori_time, max_seq_len)
        generated_data_curr = sess.run(X_hat, feed_dict={Z: Z_mb, X: ori_data, T: ori_time})

    # Trim each padded output sequence back to its original length.
    generated_data = [seq[:t] for seq, t in zip(generated_data_curr, ori_time)]

    # Renormalization: invert MinMaxScaler with the stored per-feature minimum
    # and range. np.asarray stacks the sequences, which assumes they all share
    # one length (presumably true for rectangular input; extract_time lives in
    # utils — TODO confirm).
    generated_data = np.asarray(generated_data) * max_val + min_val

    return generated_data