diff --git a/compute_distance.py b/compute_distance.py index 3f2605b..b621b7d 100644 --- a/compute_distance.py +++ b/compute_distance.py @@ -5,12 +5,12 @@ from time import time -def distance(D1, D2, gamma, niter, data_path, verbose = 0): +def distance(D1, D2, lambd, niter, data_path, verbose = 0): """ calculates all pairwise distances d(d1,d2) with: - d1 in D1 = [d11, d12, ...] - d2 in D2 = [d21, d22, ...] - - gamma: entropic regularization parameter + - lambd: entropic regularization parameter - niter: number of iterations of Sinkhorn's algorithm - data_path: path to data files - verbose = {0: display nothing, 1: display intermedidate computation times} @@ -84,7 +84,7 @@ def distance(D1, D2, gamma, niter, data_path, verbose = 0): # define graphs init = tf.initialize_all_variables() - op_xi = tf.exp(tf.scalar_mul(-1./gamma,tf.pow(xi, 2))) + op_xi = tf.exp(tf.scalar_mul(-1./lambd,tf.pow(xi, 2))) update_xi = tf.assign(xi, op_xi) op_A = tf.div(P, tf.matmul(B, xi))