Skip to content

Commit

Permalink
fixed the way conditions are handled in config_space
Browse files Browse the repository at this point in the history
  • Loading branch information
mdorier committed Aug 2, 2024
1 parent d95691d commit 9ba6acf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
40 changes: 24 additions & 16 deletions python/mochi/bedrock/config_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,18 @@ def __init__(self, name: str|None = None,
meta: dict|None = None):
self._frozen = False
self._inner = ConfigSpace.ConfigurationSpace(name=name, seed=seed, meta=meta)
self._conditions = {} # associates the name of the hp the condition applies to the condition
self._conditions = {} # associates the name of the condition's child to the tuple
# (condition_type, parent_name, value)

def add(self, arg: (Hyperparameter|ConditionLike|ForbiddenLike)):
if self._frozen:
raise PermissionError("ConfigurationSpace is already frozen")
if isinstance(arg, Condition):
if arg.child.name not in self._conditions:
self._conditions[arg.child.name] = []
self._conditions[arg.child.name].append(arg)
self._conditions[arg.child.name].append(
(type(arg), arg.parent.name, arg.value))
elif isinstance(arg, Conjunction):
if arg.child.name not in self._conditions:
self._conditions[arg.child.name] = []
for component in arg.components:
self.add(component)
else:
Expand All @@ -88,21 +88,18 @@ def add_configuration_space(self, prefix: str,
parent_hyperparameter: dict|None = None):
if self._frozen:
raise PermissionError("ConfigurationSpace is already frozen")
conditions_from_sub = []
for param, conditions in configuration_space._conditions.items():
for cond in conditions:
conditions_from_sub.append(
(type(cond), cond.child.name, cond.parent.name, cond.value))
self._inner.add_configuration_space(
prefix=prefix, configuration_space=configuration_space._inner,
delimiter=delimiter, parent_hyperparameter=parent_hyperparameter)
for cond_from_sub in conditions_from_sub:
cond_type, child_name, parent_name, value = cond_from_sub
for child_name, cond_list in configuration_space._conditions.items():
child_name = prefix + delimiter + child_name
parent_name = prefix + delimiter + parent_name
child = self._inner[child_name]
parent = self._inner[parent_name]
self.add(cond_type(child, parent, value))
for cond_tuple in cond_list:
cond_type, parent_name, value = cond_tuple
parent_name = prefix + delimiter + parent_name
if child_name not in self._conditions:
self._conditions[child_name] = []
self._conditions[child_name].append(
(cond_type, parent_name, value))

def __getitem__(self, name):
return self._inner[name]
Expand All @@ -117,7 +114,18 @@ def items(self):
return self._inner.items()

def freeze(self):
for name, conditions in self._conditions.items():

def convert_condition_tuples(child_name, conditions: list[tuple[type,str,Any]]):
result = []
child = self._inner[child_name]
for cond_tuple in conditions:
cond_type, parent_name, value = cond_tuple
parent = self._inner[parent_name]
result.append(cond_type(child, parent, value))
return result

for child_name, conditions in self._conditions.items():
conditions = convert_condition_tuples(child_name, conditions)
if len(conditions) == 1:
self._inner.add(conditions[0])
else:
Expand Down
2 changes: 1 addition & 1 deletion python/mochi/bedrock/test_config_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def resolve_provider_dependencies(config: 'Configuration', prefix: str) -> dict:
}
]

space = ProcSpec.space(num_pools=(2, max_num_pools), num_xstreams=(2, 5),
space = ProcSpec.space(num_pools=(1, max_num_pools), num_xstreams=(2, 5),
provider_space_factories=provider_space_factories).freeze()
#print(space)
config = space.sample_configuration()
Expand Down

0 comments on commit 9ba6acf

Please sign in to comment.