From 71d25f60fb8ff16d60db3b4d6c3d9c49507150e2 Mon Sep 17 00:00:00 2001 From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com> Date: Fri, 17 Nov 2023 21:33:18 +0200 Subject: [PATCH] add simulator test --- tests/test_simulator.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/test_simulator.py b/tests/test_simulator.py index 0fef977..76e8bf3 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -1,7 +1,5 @@ -import pytest - -import networkx as nx - +import numpy as np +from unittest.mock import MagicMock import tsgm @@ -23,3 +21,14 @@ def test_sine_cosine_simulator(): params1 = sin_cosine_sim.params() params2 = new_sim.params() assert params1.keys() == params2.keys() and params1["max_scale"] == params2["max_scale"] and params1["max_const"] == params2["max_const"] + + +def test_simulator_base(): + MockDriver = MagicMock() + data = tsgm.dataset.Dataset(x=np.ones((3, 2, 1)), y=np.ones(3)) + s = tsgm.simulator.Simulator(data, MockDriver) + cloned = s.clone() + assert np.array_equal(s._data.X, cloned._data.X) and np.array_equal(s._data.y, cloned._data.y) + + s.fit() + MockDriver.fit.assert_called_once_with(s._data.X, s._data.y)