-
Notifications
You must be signed in to change notification settings - Fork 26
/
graph_new.py
executable file
·71 lines (56 loc) · 2.77 KB
/
graph_new.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 23 15:49:15 2018
Graph definition for all models
@author: anazabal, olmosUC3M, ivaleraM
"""
import importlib

import numpy as np
import tensorflow as tf

import VAE_functions
def HVAE_graph(model_name, types_file, batch_size, learning_rate=1e-3, z_dim=2, y_dim=1, s_dim=2, y_dim_partition=None):
    """Build the TensorFlow 1.x computation graph for an HI-VAE style model.

    The module named ``model_name`` is imported dynamically. It must expose
    ``encoder``, ``decoder``, ``cost_function`` and ``samples_generator``
    with the call signatures used below.

    Args:
        model_name: Importable module name of the model definition.
        types_file: Variable-types description, forwarded to
            ``VAE_functions.place_holder_types`` (presumably a file path --
            confirm against that helper).
        batch_size: Mini-batch size used for the placeholders and the model.
        learning_rate: Step size of the Adam optimizer that minimizes -ELBO.
        z_dim: Latent z dimensionality, forwarded to the model functions.
        y_dim: Per-variable Y dimensionality. It is used only when
            ``y_dim_partition`` is not given.
        s_dim: Latent s dimensionality, forwarded to the model functions.
        y_dim_partition: Optional sequence giving the Y dimensionality of
            each variable. ``None`` (the default) or an empty sequence means
            every variable gets ``y_dim`` dimensions.

    Returns:
        dict mapping node names (``'loss'``, ``'optim'``, ``'samples'``,
        ``'log_p_x_test'``, ...) to the corresponding graph tensors/ops.
    """
    # We select the model for the VAE. importlib.import_module returns the
    # named module itself, which matters for dotted names. __import__ would
    # return only the top-level package.
    print('[*] Importing model: ' + model_name)
    model = importlib.import_module(model_name)

    # Load placeholders
    print('[*] Defining placeholders')
    batch_data_list, batch_data_list_observed, miss_list, tau, tau2, types_list = \
        VAE_functions.place_holder_types(types_file, batch_size)

    # Batch normalization of the data
    X_list, normalization_params = VAE_functions.batch_normalization(
        batch_data_list_observed, types_list, miss_list)

    # Set dimensionality of Y: use the caller's explicit per-variable
    # partition, or y_dim for every variable. The check uses len() rather
    # than truthiness so a NumPy array partition is also accepted; truthiness
    # of a multi-element array raises ValueError.
    if y_dim_partition is None or len(y_dim_partition) == 0:
        y_dim_partition = y_dim * np.ones(len(types_list), dtype=int)
    y_dim_output = np.sum(y_dim_partition)

    # Encoder definition
    print('[*] Defining Encoder...')
    samples, q_params = model.encoder(X_list, miss_list, batch_size, z_dim, s_dim, tau)

    # The decoder returns an updated `samples`, which replaces the encoder's.
    print('[*] Defining Decoder...')
    theta, samples, p_params, log_p_x, log_p_x_missing = model.decoder(
        batch_data_list, miss_list, types_list, samples, q_params,
        normalization_params, batch_size, z_dim, y_dim_output,
        y_dim_partition, tau2)

    print('[*] Defining Cost function...')
    ELBO, loss_reconstruction, KL_z, KL_s = model.cost_function(
        log_p_x, p_params, q_params, types_list, z_dim, y_dim_output, s_dim)

    # Maximizing the ELBO is expressed as minimizing its negation.
    optim = tf.train.AdamOptimizer(learning_rate).minimize(-ELBO)

    # Generator function for testing purposes
    samples_test, test_params, log_p_x_test, log_p_x_missing_test = model.samples_generator(
        batch_data_list, X_list, miss_list, types_list, batch_size, z_dim,
        y_dim_output, y_dim_partition, s_dim, tau, tau2, normalization_params)

    # Packing results
    tf_nodes = {'ground_batch': batch_data_list,
                'ground_batch_observed': batch_data_list_observed,
                'miss_list': miss_list,
                'tau_GS': tau,
                'tau_var': tau2,
                'samples': samples,
                'log_p_x': log_p_x,
                'log_p_x_missing': log_p_x_missing,
                'loss_re': loss_reconstruction,
                'loss': -ELBO,
                'optim': optim,
                'KL_s': KL_s,
                'KL_z': KL_z,
                'p_params': p_params,
                'q_params': q_params,
                'samples_test': samples_test,
                'test_params': test_params,
                # Previously bound to the training log_p_x by mistake, which
                # silently discarded the generator's test log-likelihood.
                'log_p_x_test': log_p_x_test,
                'log_p_x_missing_test': log_p_x_missing_test}

    return tf_nodes