diff --git a/pyriemann_qiskit/utils/distance.py b/pyriemann_qiskit/utils/distance.py index 60910207..3b65a17a 100644 --- a/pyriemann_qiskit/utils/distance.py +++ b/pyriemann_qiskit/utils/distance.py @@ -110,17 +110,17 @@ def weights_logeuclid_to_convex_hull(A, B, optimizer=ClassicalOptimizer()): n_matrices, _, _ = A.shape matrices = range(n_matrices) - def log_prod(m1, m2): - return np.nansum(logm(m1).flatten() * logm(m2).flatten()) + def trace_prod_log(m1, m2): + return np.trace(logm(m1) @ logm(m2)) prob = Model() optimizer = get_global_optimizer(optimizer) w = optimizer.get_weights(prob, matrices) wtLogAtLogAw = prob.sum( - w[i] * w[j] * log_prod(A[i], A[j]) for i in matrices for j in matrices + w[i] * w[j] * trace_prod_log(A[i], A[j]) for i in matrices for j in matrices ) - wLogBLogA = prob.sum(w[i] * log_prod(B, A[i]) for i in matrices) + wLogBLogA = prob.sum(w[i] * trace_prod_log(B, A[i]) for i in matrices) objective = wtLogAtLogAw - 2 * wLogBLogA prob.set_objective("min", objective)