From 9232b73291f9d8d396498d7abfdb14278c69fdcc Mon Sep 17 00:00:00 2001 From: gcattan Date: Tue, 17 Dec 2024 21:49:20 +0100 Subject: [PATCH 1/3] Update distance.py --- pyriemann_qiskit/utils/distance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyriemann_qiskit/utils/distance.py b/pyriemann_qiskit/utils/distance.py index 60910207..ac6c49bb 100644 --- a/pyriemann_qiskit/utils/distance.py +++ b/pyriemann_qiskit/utils/distance.py @@ -111,7 +111,7 @@ def weights_logeuclid_to_convex_hull(A, B, optimizer=ClassicalOptimizer()): matrices = range(n_matrices) def log_prod(m1, m2): - return np.nansum(logm(m1).flatten() * logm(m2).flatten()) + return np.trace(logm(m1) @ logm(m2)) prob = Model() optimizer = get_global_optimizer(optimizer) From 425c99dd6dcdadad05d8c64de234b4e987b635ca Mon Sep 17 00:00:00 2001 From: gcattan Date: Fri, 20 Dec 2024 22:26:40 +0100 Subject: [PATCH 2/3] Update pyriemann_qiskit/utils/distance.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Barthélemy --- pyriemann_qiskit/utils/distance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyriemann_qiskit/utils/distance.py b/pyriemann_qiskit/utils/distance.py index ac6c49bb..a082b29d 100644 --- a/pyriemann_qiskit/utils/distance.py +++ b/pyriemann_qiskit/utils/distance.py @@ -110,7 +110,7 @@ def weights_logeuclid_to_convex_hull(A, B, optimizer=ClassicalOptimizer()): n_matrices, _, _ = A.shape matrices = range(n_matrices) - def log_prod(m1, m2): + def trace_prod_log(m1, m2): return np.trace(logm(m1) @ logm(m2)) prob = Model() From bd54e8834c24b73f807d1ffff6b6cd6dff0db087 Mon Sep 17 00:00:00 2001 From: gcattan Date: Fri, 20 Dec 2024 22:27:21 +0100 Subject: [PATCH 3/3] Update distance.py --- pyriemann_qiskit/utils/distance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyriemann_qiskit/utils/distance.py b/pyriemann_qiskit/utils/distance.py index a082b29d..3b65a17a 100644 --- a/pyriemann_qiskit/utils/distance.py +++ b/pyriemann_qiskit/utils/distance.py @@ -118,9 +118,9 @@ def trace_prod_log(m1, m2): w = optimizer.get_weights(prob, matrices) wtLogAtLogAw = prob.sum( - w[i] * w[j] * log_prod(A[i], A[j]) for i in matrices for j in matrices + w[i] * w[j] * trace_prod_log(A[i], A[j]) for i in matrices for j in matrices ) - wLogBLogA = prob.sum(w[i] * log_prod(B, A[i]) for i in matrices) + wLogBLogA = prob.sum(w[i] * trace_prod_log(B, A[i]) for i in matrices) objective = wtLogAtLogAw - 2 * wLogBLogA prob.set_objective("min", objective)