Skip to content

Commit

Permalink
experiment: Re-compute derived Param arguments
Browse files Browse the repository at this point in the history
Now creates new Param on rebind and call with stored + updated
initialisation parameters. This allows the scale for int/float
params to be recomputed if required.
  • Loading branch information
JammyL authored and dnadlinger committed May 24, 2023
1 parent 14e1f4f commit b7f335b
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 47 deletions.
8 changes: 4 additions & 4 deletions ndscan/experiment/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,10 @@ def setattr_param_like(self,
"already rebound?".format(original_name))

template_param = original_owner._free_params[original_name]
new_param = deepcopy(template_param)
new_param.fqn = self.fqn + "." + name
for k, v in kwargs.items():
setattr(new_param, k, v)
init_params = deepcopy(template_param.init_params)
init_params.update(kwargs)
init_params["fqn"] = self.fqn + "." + name
new_param = template_param.__class__(**init_params)
self._free_params[name] = new_param
new_handle = new_param.HandleType(self, name)
setattr(self, name, new_handle)
Expand Down
58 changes: 36 additions & 22 deletions ndscan/experiment/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,15 @@ def resolve_numeric_scale(scale: Optional[float], unit: str) -> float:
"the scale manually".format(unit))


class FloatParam:
class ParamBase:
def __init__(self, **kwargs):
# Store kwargs for param rebinding
self.init_params = kwargs
for k, v in kwargs.items():
setattr(self, k, v)


class FloatParam(ParamBase):
HandleType = FloatParamHandle
StoreType = FloatParamStore
CompilerType = TFloat
Expand All @@ -278,17 +286,18 @@ def __init__(self,
step: Optional[float] = None,
is_scannable: bool = True):

self.fqn = fqn
self.description = description
self.default = default
self.min = min
self.max = max

self.unit = unit
ParamBase.__init__(self,
fqn=fqn,
description=description,
default=default,
min=min,
max=max,
unit=unit,
scale=scale,
step=step,
is_scannable=is_scannable)
self.scale = resolve_numeric_scale(scale, unit)

self.step = step if step is not None else self.scale / 10.0
self.is_scannable = is_scannable

def describe(self) -> Dict[str, Any]:
spec = {
Expand Down Expand Up @@ -326,7 +335,7 @@ def make_store(self, identity: Tuple[str, str], value: float) -> FloatParamStore
return FloatParamStore(identity, value)


class IntParam:
class IntParam(ParamBase):
HandleType = IntParamHandle
StoreType = IntParamStore
CompilerType = TInt32
Expand All @@ -341,13 +350,16 @@ def __init__(self,
unit: str = "",
scale: Optional[int] = None,
is_scannable: bool = True):
self.fqn = fqn
self.description = description
self.default = default
self.min = min
self.max = max

self.unit = unit
ParamBase.__init__(self,
fqn=fqn,
description=description,
default=default,
min=min,
max=max,
unit=unit,
scale=scale,
is_scannable=is_scannable)
self.scale = resolve_numeric_scale(scale, unit)
if self.scale != 1:
raise NotImplementedError(
Expand Down Expand Up @@ -386,7 +398,7 @@ def make_store(self, identity: Tuple[str, str], value: int) -> IntParamStore:
return IntParamStore(identity, value)


class StringParam:
class StringParam(ParamBase):
HandleType = StringParamHandle
StoreType = StringParamStore
CompilerType = TStr
Expand All @@ -396,10 +408,12 @@ def __init__(self,
description: str,
default: str,
is_scannable: bool = True):
self.fqn = fqn
self.description = description
self.default = default
self.is_scannable = is_scannable

ParamBase.__init__(self,
fqn=fqn,
description=description,
default=default,
is_scannable=is_scannable)

def describe(self) -> Dict[str, Any]:
return {
Expand Down
2 changes: 1 addition & 1 deletion test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_default_analyses(self):
class ReboundAddOneFragment(ExpFragment):
def build_fragment(self):
self.setattr_fragment("add_one", AddOneFragment)
self.setattr_param_rebind("value", self.add_one)
self.setattr_param_rebind("value", self.add_one, unit="ms")

def run_once(self):
self.add_one.run_once()
Expand Down
62 changes: 42 additions & 20 deletions test/test_experiment_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,25 @@ def test_time_series_transitory_limit(self):
exp.run()

def test_run_1d_scan(self):
exp = self._test_run_1d(ScanAddOneExp, "fixtures.AddOneFragment")
fragment_fqn = "fixtures.AddOneFragment"
expected_axes = [{
"increment": 1.0,
"max": 2,
"min": 0,
"param": {
"default": "0.0",
"description": "Value to return",
"fqn": fragment_fqn + ".value",
"spec": {
"is_scannable": True,
"scale": 1.0,
"step": 0.1,
},
"type": "float"
},
"path": "*"
}]
exp = self._test_run_1d(ScanAddOneExp, fragment_fqn, expected_axes)
self.assertEqual(exp.fragment.num_host_setup_calls, 1)
self.assertEqual(exp.fragment.num_device_setup_calls, 3)
self.assertEqual(exp.fragment.num_host_cleanup_calls, 1)
Expand Down Expand Up @@ -214,13 +232,33 @@ def test_run_1d_scan(self):
})

def test_run_rebound_1d_scan(self):
exp = self._test_run_1d(ScanReboundAddOneExp, "fixtures.ReboundAddOneFragment")
fragment_fqn = "fixtures.ReboundAddOneFragment"
expected_axes = [{
"increment": 1.0,
"max": 2,
"min": 0,
"param": {
"default": "0.0",
"description": "Value to return",
"fqn": fragment_fqn + ".value",
"spec": {
"is_scannable": True,
"scale": 0.001,
"step": 0.0001,
"unit": "ms",
},
"type": "float"
},
"path": "*"
}]
exp = self._test_run_1d(ScanReboundAddOneExp, "fixtures.ReboundAddOneFragment",
expected_axes)
self.assertEqual(exp.fragment.add_one.num_host_setup_calls, 1)
self.assertEqual(exp.fragment.add_one.num_device_setup_calls, 3)
self.assertEqual(exp.fragment.add_one.num_host_cleanup_calls, 1)
self.assertEqual(exp.fragment.add_one.num_device_cleanup_calls, 1)

def _test_run_1d(self, klass, fragment_fqn):
def _test_run_1d(self, klass, fragment_fqn, expected_axes):
exp = self.create(klass)
fqn = fragment_fqn + ".value"
exp.args._params["scan"]["axes"].append({
Expand All @@ -240,23 +278,7 @@ def _test_run_1d(self, klass, fragment_fqn):
def d(key):
return self.dataset_db.get("ndscan.rid_0." + key)

self.assertEqual(json.loads(d("axes")), [{
"increment": 1.0,
"max": 2,
"min": 0,
"param": {
"default": "0.0",
"description": "Value to return",
"fqn": fqn,
"spec": {
"is_scannable": True,
"scale": 1.0,
"step": 0.1
},
"type": "float"
},
"path": "*"
}])
self.assertEqual(json.loads(d("axes")), expected_axes)
self.assertEqual(d(SCHEMA_REVISION_KEY), SCHEMA_REVISION)
self.assertEqual(d("completed"), True)
self.assertEqual(d("points.axis_0"), [0, 1, 2])
Expand Down

0 comments on commit b7f335b

Please sign in to comment.