From 6bcca31329715ccd329f58a40e55faa283ab9797 Mon Sep 17 00:00:00 2001 From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com> Date: Fri, 17 Nov 2023 21:22:35 +0200 Subject: [PATCH] add visualization tests --- tests/test_visualizations.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_visualizations.py b/tests/test_visualizations.py index b686a13..edb84c4 100644 --- a/tests/test_visualizations.py +++ b/tests/test_visualizations.py @@ -59,6 +59,9 @@ def test_visualize_ts_lineplot(unite_features): ys = np.array([1, 2]) tsgm.utils.visualize_ts_lineplot(Xs, ys, num=1, unite_features=unite_features) + ys = np.array([[1, 2], [1, 2]]) + tsgm.utils.visualize_ts_lineplot(Xs, ys, num=1, unite_features=unite_features) + def test_visualize_training_loss(): loss = np.array([[10.0], [9.0], [8.0], [7.0]]) @@ -76,3 +79,7 @@ def test_visualize_original_and_reconst_ts(): reconstructed = original tsgm.utils.visualize_original_and_reconst_ts(original, reconstructed) + +def test_visualize_training_loss(): + loss_vector = np.ones((100, 100)) + tsgm.utils.visualize_training_loss(loss_vector, labels=("a", "b"))