From b7f335bb68e754b25a270c9ef148d0245b02456d Mon Sep 17 00:00:00 2001 From: Jamie Leppard Date: Thu, 24 Nov 2022 15:56:57 +0000 Subject: [PATCH] experiment: Re-compute derived Param arguments Now creates new Param on rebind and call with stored + updated initialisation parameters. This allows the scale for int/float params to be recomputed if required. --- ndscan/experiment/fragment.py | 8 ++-- ndscan/experiment/parameters.py | 58 +++++++++++++++++----------- test/fixtures.py | 2 +- test/test_experiment_entrypoint.py | 62 ++++++++++++++++++++---------- 4 files changed, 83 insertions(+), 47 deletions(-) diff --git a/ndscan/experiment/fragment.py b/ndscan/experiment/fragment.py index bc048966..645b20ca 100644 --- a/ndscan/experiment/fragment.py +++ b/ndscan/experiment/fragment.py @@ -363,10 +363,10 @@ def setattr_param_like(self, "already rebound?".format(original_name)) template_param = original_owner._free_params[original_name] - new_param = deepcopy(template_param) - new_param.fqn = self.fqn + "." + name - for k, v in kwargs.items(): - setattr(new_param, k, v) + init_params = deepcopy(template_param.init_params) + init_params.update(kwargs) + init_params["fqn"] = self.fqn + "." + name + new_param = template_param.__class__(**init_params) self._free_params[name] = new_param new_handle = new_param.HandleType(self, name) setattr(self, name, new_handle) diff --git a/ndscan/experiment/parameters.py b/ndscan/experiment/parameters.py index 0d29bd82..ce887427 100644 --- a/ndscan/experiment/parameters.py +++ b/ndscan/experiment/parameters.py @@ -261,7 +261,15 @@ def resolve_numeric_scale(scale: Optional[float], unit: str) -> float: "the scale manually".format(unit)) -class FloatParam: +class ParamBase: + def __init__(self, **kwargs): + # Store kwargs for param rebinding + self.init_params = kwargs + for k, v in kwargs.items(): + setattr(self, k, v) + + +class FloatParam(ParamBase): HandleType = FloatParamHandle StoreType = FloatParamStore CompilerType = TFloat @@ -278,17 +286,18 @@ def __init__(self, step: Optional[float] = None, is_scannable: bool = True): - self.fqn = fqn - self.description = description - self.default = default - self.min = min - self.max = max - - self.unit = unit + ParamBase.__init__(self, + fqn=fqn, + description=description, + default=default, + min=min, + max=max, + unit=unit, + scale=scale, + step=step, + is_scannable=is_scannable) self.scale = resolve_numeric_scale(scale, unit) - self.step = step if step is not None else self.scale / 10.0 - self.is_scannable = is_scannable def describe(self) -> Dict[str, Any]: spec = { @@ -326,7 +335,7 @@ def make_store(self, identity: Tuple[str, str], value: float) -> FloatParamStore return FloatParamStore(identity, value) -class IntParam: +class IntParam(ParamBase): HandleType = IntParamHandle StoreType = IntParamStore CompilerType = TInt32 @@ -341,13 +350,16 @@ def __init__(self, unit: str = "", scale: Optional[int] = None, is_scannable: bool = True): - self.fqn = fqn - self.description = description - self.default = default - self.min = min - self.max = max - self.unit = unit + ParamBase.__init__(self, + fqn=fqn, + description=description, + default=default, + min=min, + max=max, + unit=unit, + scale=scale, + is_scannable=is_scannable) self.scale = resolve_numeric_scale(scale, unit) if self.scale != 1: raise NotImplementedError( @@ -386,7 +398,7 @@ def make_store(self, identity: Tuple[str, str], value: int) -> IntParamStore: return IntParamStore(identity, value) -class StringParam: +class StringParam(ParamBase): HandleType = StringParamHandle StoreType = StringParamStore CompilerType = TStr @@ -396,10 +408,12 @@ def __init__(self, description: str, default: str, is_scannable: bool = True): - self.fqn = fqn - self.description = description - self.default = default - self.is_scannable = is_scannable + + ParamBase.__init__(self, + fqn=fqn, + description=description, + default=default, + is_scannable=is_scannable) def describe(self) -> Dict[str, Any]: return { diff --git a/test/fixtures.py b/test/fixtures.py index 75cce6f6..12768090 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -53,7 +53,7 @@ def get_default_analyses(self): class ReboundAddOneFragment(ExpFragment): def build_fragment(self): self.setattr_fragment("add_one", AddOneFragment) - self.setattr_param_rebind("value", self.add_one) + self.setattr_param_rebind("value", self.add_one, unit="ms") def run_once(self): self.add_one.run_once() diff --git a/test/test_experiment_entrypoint.py b/test/test_experiment_entrypoint.py index 00137497..ef751b00 100644 --- a/test/test_experiment_entrypoint.py +++ b/test/test_experiment_entrypoint.py @@ -136,7 +136,25 @@ def test_time_series_transitory_limit(self): exp.run() def test_run_1d_scan(self): - exp = self._test_run_1d(ScanAddOneExp, "fixtures.AddOneFragment") + fragment_fqn = "fixtures.AddOneFragment" + expected_axes = [{ + "increment": 1.0, + "max": 2, + "min": 0, + "param": { + "default": "0.0", + "description": "Value to return", + "fqn": fragment_fqn + ".value", + "spec": { + "is_scannable": True, + "scale": 1.0, + "step": 0.1, + }, + "type": "float" + }, + "path": "*" + }] + exp = self._test_run_1d(ScanAddOneExp, fragment_fqn, expected_axes) self.assertEqual(exp.fragment.num_host_setup_calls, 1) self.assertEqual(exp.fragment.num_device_setup_calls, 3) self.assertEqual(exp.fragment.num_host_cleanup_calls, 1) @@ -214,13 +232,33 @@ def test_run_1d_scan(self): }) def test_run_rebound_1d_scan(self): - exp = self._test_run_1d(ScanReboundAddOneExp, "fixtures.ReboundAddOneFragment") + fragment_fqn = "fixtures.ReboundAddOneFragment" + expected_axes = [{ + "increment": 1.0, + "max": 2, + "min": 0, + "param": { + "default": "0.0", + "description": "Value to return", + "fqn": fragment_fqn + ".value", + "spec": { + "is_scannable": True, + "scale": 0.001, + "step": 0.0001, + "unit": "ms", + }, + "type": "float" + }, + "path": "*" + }] + exp = self._test_run_1d(ScanReboundAddOneExp, "fixtures.ReboundAddOneFragment", + expected_axes) self.assertEqual(exp.fragment.add_one.num_host_setup_calls, 1) self.assertEqual(exp.fragment.add_one.num_device_setup_calls, 3) self.assertEqual(exp.fragment.add_one.num_host_cleanup_calls, 1) self.assertEqual(exp.fragment.add_one.num_device_cleanup_calls, 1) - def _test_run_1d(self, klass, fragment_fqn): + def _test_run_1d(self, klass, fragment_fqn, expected_axes): exp = self.create(klass) fqn = fragment_fqn + ".value" exp.args._params["scan"]["axes"].append({ @@ -240,23 +278,7 @@ def _test_run_1d(self, klass, fragment_fqn): def d(key): return self.dataset_db.get("ndscan.rid_0." + key) - self.assertEqual(json.loads(d("axes")), [{ - "increment": 1.0, - "max": 2, - "min": 0, - "param": { - "default": "0.0", - "description": "Value to return", - "fqn": fqn, - "spec": { - "is_scannable": True, - "scale": 1.0, - "step": 0.1 - }, - "type": "float" - }, - "path": "*" - }]) + self.assertEqual(json.loads(d("axes")), expected_axes) self.assertEqual(d(SCHEMA_REVISION_KEY), SCHEMA_REVISION) self.assertEqual(d("completed"), True) self.assertEqual(d("points.axis_0"), [0, 1, 2])