Skip to content

Commit

Permalink
parameter based simulator - use base params
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Sep 25, 2023
1 parent d4fd54b commit 189f496
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
10 changes: 7 additions & 3 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ def test_sine_cosine_simulator():
sin_cosine_sim = tsgm.simulator.SineConstSimulator(data)
syn_dataset = sin_cosine_sim.generate(10)
assert type(syn_dataset) is tsgm.dataset.Dataset
assert sin_cosine_sim.params() == {"max_scale": 10.0, "max_const": 5.0}
params = sin_cosine_sim.params()
assert params["max_scale"] == 10.0 and params["max_const"] == 5.0
assert syn_dataset.X.shape == (10, 12, 23)
assert syn_dataset.y.shape == (10,)

sin_cosine_sim.set_params(max_scale=10, max_const=123)
assert sin_cosine_sim.params() == {"max_scale": 10.0, "max_const": 123}
params = sin_cosine_sim.params()
assert params["max_scale"] == 10.0 and params["max_const"] == 123

new_sim = sin_cosine_sim.clone()
assert sin_cosine_sim.params() == new_sim.params()
params1 = sin_cosine_sim.params()
params2 = new_sim.params()
assert params1.keys() == params2.keys() and params1["max_scale"] == params2["max_scale"] and params1["max_const"] == params2["max_const"]
18 changes: 6 additions & 12 deletions tsgm/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __init__(self, data: tsgm.dataset.DatasetProperties):
super().__init__(data)

def params(self):
params = self.__dict__
del params["data"]
params = copy.deepcopy(self.__dict__)
del params["_data"], params["_driver"]
return params

def set_params(self, params: dict) -> None:
Expand All @@ -67,19 +67,12 @@ def __init__(self, data: tsgm.dataset.DatasetProperties, max_scale: float = 10.0

self.set_params(max_scale, max_const)

def set_params(self, max_scale, max_const):
def set_params(self, max_scale, max_const, *args, **kwargs):
self._scale = tfp.distributions.Uniform(0, max_scale)
self._const = tfp.distributions.Uniform(0, max_const)
self._shift = tfp.distributions.Uniform(0, 2)

self._max_scale = max_scale
self._max_const = max_const

def params(self):
return {
"max_scale": self._max_scale,
"max_const": self._max_const,
}
super().set_params({"max_scale": max_scale, "max_const": max_const})

def generate(self, num_samples: int, *args) -> tsgm.dataset.Dataset:
result_X, result_y = [], []
Expand All @@ -98,5 +91,6 @@ def generate(self, num_samples: int, *args) -> tsgm.dataset.Dataset:

def clone(self) -> "SineConstSimulator":
copy_simulator = SineConstSimulator(self._data)
copy_simulator.set_params(**self.params())
params = self.params()
copy_simulator.set_params(max_scale=params["max_scale"], max_const=params["max_const"])
return copy_simulator

0 comments on commit 189f496

Please sign in to comment.