diff --git a/python/mochi/bedrock/config_space.py b/python/mochi/bedrock/config_space.py index a006cd7..5a651da 100644 --- a/python/mochi/bedrock/config_space.py +++ b/python/mochi/bedrock/config_space.py @@ -163,7 +163,7 @@ def add(self, arg): if isinstance(arg, Hyperparameter): arg.name = self.prefix + arg.name if isinstance(arg, list): - for a in args: + for a in arg: self.add(a) else: self.cs.add(arg) @@ -171,6 +171,15 @@ def add(self, arg): def __getattr__(self, key): return getattr(self.cs, key) + def __iter__(self): + return self.cs.__iter__() + + def __getitem__(self, name): + return self.cs[name] + + def __contains__(self, name): + return name in self.cs + def CategoricalOrConst(name: str, items: Sequence[Any]|Any, *, default: Any|None = None, weights: Sequence[float]|None = None, diff --git a/python/mochi/bedrock/spec.py b/python/mochi/bedrock/spec.py index dc1abc1..b531e41 100644 --- a/python/mochi/bedrock/spec.py +++ b/python/mochi/bedrock/spec.py @@ -1383,7 +1383,8 @@ def set_provider_hyperparameters(self, configuration_space: CS) -> None: @abstractmethod def resolve_to_provider_spec( - self, name: str, provider_id: int, config: Config, prefix: str) -> 'ProviderSpec': + self, name: str, provider_id: int, + config: Config, prefix: str) -> 'ProviderSpec': """ This method should convert a Configuration object into a ProviderSpec, by extracting the configuration and dependencies from the sampled parameters. @@ -1467,68 +1468,6 @@ def from_json(json_string: str) -> 'ProviderSpec': """ return ProviderSpec.from_dict(json.loads(json_string)) - @staticmethod - def space(*, type: str, tags: list[str] = [], - provider_config_space: Optional[CS] = None, - provider_config_resolver: Callable[[Config, str], dict]|None = None, - dependency_config_space: Optional[CS] = None, - dependency_resolver: Callable[[Config, str], dict]|None = None) -> CS: - """ - Create a ConfigurationSpace for a ProviderSpec. - - - type: type of provider. - - tags: list of tags the provider will use. - - provider_config_space: a ConfigurationSpace for the "config" field of the provider. - - provider_config_resolver: a function (or callable) taking a Configuration and a prefix - and returning the provider's "config" field (dict) from the Configuration. - - dependency_config_space: a ConfigurationSpace for the "dependencies" field of the provider. - - dependency_resolver: a function (or callable) taking a Configuration and a prefix - and returning the provider's "dependencies" field (dict) from the Configuration. - """ - from .config_space import ConfigurationSpace, FloatOrConst, Categorical, Constant - cs = ConfigurationSpace() - cs.add(Constant('type', type)) - cs.add(Constant('tags', tags)) - - if provider_config_space is not None: - cs.add_configuration_space( - prefix='config', delimiter='.', - configuration_space=provider_config_space) - cs.add(Constant('config_resolver', provider_config_resolver)) - if dependency_config_space is not None: - cs.add_configuration_space( - prefix='dependencies', delimiter='.', - configuration_space=dependency_config_space) - cs.add(Constant('dependency_resolver', dependency_resolver)) - return cs - - @staticmethod - def from_config(*, name: str, provider_id: int, - config: Config, prefix: str = '') -> 'ProviderSpec': - """ - Create a ProviderSpec from a given Configuration object. - - This function must be also given the name and provider Id to give the provider, - as well as the list of pools of the underlying ProcSpec. - """ - from .config_space import Configuration - type = config[f'{prefix}type'] - tags = config[f'{prefix}tags'] - provider_config_resolver = config[f'{prefix}config_resolver'] - dependency_resolver = config[f'{prefix}dependency_resolver'] - if provider_config_resolver is None: - provider_config = {} - else: - provider_config = provider_config_resolver(config, f'{prefix}config.') - if dependency_resolver is None: - provider_dependencies = {} - else: - provider_dependencies = dependency_resolver(config, f'{prefix}dependencies.') - return ProviderSpec( - name=name, type=type, provider_id=provider_id, - config=provider_config, tags=tags, - dependencies=provider_dependencies) - @attr.s(auto_attribs=True, on_setattr=_check_validators, kw_only=True) class BedrockSpec: diff --git a/python/mochi/bedrock/test_config_space.py b/python/mochi/bedrock/test_config_space.py index 17e3c0f..9165f0d 100644 --- a/python/mochi/bedrock/test_config_space.py +++ b/python/mochi/bedrock/test_config_space.py @@ -1,5 +1,27 @@ import unittest from mochi.bedrock.spec import * +from .config_space import ConfigurationSpace, Integer, CategoricalChoice + +class MyDatabaseSpaceBuilder(ProviderConfigSpaceBuilder): + + def set_provider_hyperparameters(self, configuration_space: CS) -> None: + configuration_space.add(Integer("x", (0,9))) + configuration_space.add(Integer("y", (1,42))) + # add a pool dependency + num_pools = configuration_space["margo.argobots.num_pools"] + configuration_space.add(CategoricalChoice("pool", num_options=num_pools)) + + def resolve_to_provider_spec( + self, name: str, provider_id: int, config: Config, prefix: str) -> ProviderSpec: + cfg = { + "x" : int(config[prefix + "x"]), + "y" : int(config[prefix + "y"]) + } + dep = { + "pool" : int(config[prefix + "pool"]) + } + return ProviderSpec(name=name, type="yokan", provider_id=provider_id, + tags=["tag1", "tag2"], config=cfg, dependencies=dep) class TestConfigSpace(unittest.TestCase): @@ -69,110 +91,62 @@ def test_proc_config_space(self): spec = ProcSpec.from_config(config=config, address='na+sm') #print(spec.to_json(indent=4)) - def test_provider_config_space(self): - from .config_space import ConfigurationSpace, Integer - # pools to select from - # pools = [PoolSpec(name=f'pool_{i}') for i in range(5)] - - # config space for the provider - config_cs = ConfigurationSpace() - config_cs.add(Integer("x", (0,9))) - config_cs.add(Integer("y", (1,42))) - - # function to resolve the configuration - def resolve_provider_config(config: 'Configuration', prefix: str) -> dict: - x = config[f'{prefix}x'] - y = config[f'{prefix}y'] - return {'x':x, 'y':y} - - # function to resolve the dependencies - def resolve_provider_dependencies(config: 'Configuration', prefix: str) -> dict: - return {'abc': 'def'} - - space = ProviderSpec.space( - type='yokan', - tags=['tag1', 'tag2'], - provider_config_space=config_cs, - provider_config_resolver=resolve_provider_config, - dependency_resolver=resolve_provider_dependencies).freeze() - #print(space) - - config = space.sample_configuration() - #print(config) - - spec = ProviderSpec.from_config( - name='my_yokan_provider', provider_id=1, - config=config) - #print(spec.to_json(indent=4)) - def test_proc_with_providers(self): - from .config_space import ConfigurationSpace, Integer - - max_num_pools = 3 - - """ - provider_config_cs = ConfigurationSpace() - provider_config_cs.add(Integer("x", (0,9))) - provider_config_cs.add(Integer("y", (1,42))) - - def resolve_provider_config(config: 'Configuration', prefix: str) -> dict: - x = config[f'{prefix}x'] - y = config[f'{prefix}y'] - return {'x':x, 'y':y} - - def resolve_provider_dependencies(config: 'Configuration', prefix: str) -> dict: - return {'abc': 'def'} - provider_space_factories = [ { "family": "databases", - "space": ProviderSpec.space( - type='yokan', - #max_num_pools=max_num_pools, - tags=['tag1', 'tag2'], - provider_config_space=provider_config_cs, - provider_config_resolver=resolve_provider_config, - dependency_resolver=resolve_provider_dependencies), + "builder": MyDatabaseSpaceBuilder(), "count": (1,3) } ] - """ + space = ProcSpec.space(num_pools=(1, 3), num_xstreams=(2, 5), + provider_space_factories=provider_space_factories).freeze() + #print(space) + config = space.sample_configuration() + #print(config) + spec = ProcSpec.from_config(address='na+sm', config=config) + #print(spec.to_json(indent=4)) - class MyProviderSpaceBuilder(ProviderConfigSpaceBuilder): + def test_service_config_space(self): - def set_provider_hyperparameters(self, configuration_space: CS) -> None: - configuration_space.add(Integer("x", (0,9))) - configuration_space.add(Integer("y", (1,42))) + proc_type_a = ProcSpec.space(num_pools=2, num_xstreams=3) + proc_type_b = ProcSpec.space(num_pools=1, num_xstreams=2) - def resolve_to_provider_spec( - self, name: str, provider_id: int, config: Config, prefix: str) -> ProviderSpec: - cfg = { - "x" : int(config[prefix + "x"]), - "y" : int(config[prefix + "y"]) + space = ServiceSpec.space( + process_space_factories=[ + { + 'family': 'proc_type_a', + 'space': proc_type_a, + 'count': 2 + }, + { + 'family': 'proc_type_b', + 'space': proc_type_b, + 'count': 2 } - return ProviderSpec( - name=name, type="yokan", provider_id=provider_id, - tags=["tag1", "tag2"], config=cfg, dependencies={}) + ]).freeze() + #print(space) + + config = space.sample_configuration() + #print(config) + + spec = ServiceSpec.from_config(address='na+sm', config=config) + #print(spec.to_json(indent=4)) + + def test_service_config_space_with_providers(self): provider_space_factories = [ { "family": "databases", - "builder": MyProviderSpaceBuilder(), + "builder": MyDatabaseSpaceBuilder(), "count": (1,3) } ] - space = ProcSpec.space(num_pools=(1, max_num_pools), num_xstreams=(2, 5), - provider_space_factories=provider_space_factories).freeze() - print(space) - config = space.sample_configuration() - print(config) - spec = ProcSpec.from_config(address='na+sm', config=config) - print(spec.to_json(indent=4)) - def test_service_config_space(self): - - proc_type_a = ProcSpec.space(num_pools=2, num_xstreams=3) - proc_type_b = ProcSpec.space(num_pools=1, num_xstreams=2) + proc_type_a = ProcSpec.space(num_pools=3, num_xstreams=3, + provider_space_factories=provider_space_factories) + proc_type_b = ProcSpec.space(num_pools=2, num_xstreams=2, + provider_space_factories=provider_space_factories) space = ServiceSpec.space( process_space_factories=[ @@ -187,13 +161,13 @@ def test_service_config_space(self): 'count': 2 } ]).freeze() - #print(space) + print(space) config = space.sample_configuration() - #print(config) + print(config) spec = ServiceSpec.from_config(address='na+sm', config=config) - #print(spec.to_json(indent=4)) + print(spec.to_json(indent=4)) if __name__ == '__main__': unittest.main()