diff --git a/pyriemann_qiskit/utils/distance.py b/pyriemann_qiskit/utils/distance.py index 60910207..ac6c49bb 100644 --- a/pyriemann_qiskit/utils/distance.py +++ b/pyriemann_qiskit/utils/distance.py @@ -111,7 +111,7 @@ def weights_logeuclid_to_convex_hull(A, B, optimizer=ClassicalOptimizer()): matrices = range(n_matrices) def log_prod(m1, m2): - return np.nansum(logm(m1).flatten() * logm(m2).flatten()) + return np.trace(logm(m1) @ logm(m2)) prob = Model() optimizer = get_global_optimizer(optimizer)