Skip to content

Commit

Permalink
added client list in specs
Browse files Browse the repository at this point in the history
  • Loading branch information
mdorier committed Jan 31, 2024
1 parent 90ef7fc commit 75f35f0
Showing 1 changed file with 93 additions and 1 deletion.
94 changes: 93 additions & 1 deletion python/mochi/bedrock/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,7 @@ class ProviderSpec:
factory=list)

def to_dict(self) -> dict:
"""Convert the SSGSpec into a dictionary.
"""Convert the ProviderSpec into a dictionary.
"""
return {'name': self.name,
'type': self.type,
Expand Down Expand Up @@ -1343,6 +1343,75 @@ def from_json(json_string: str, abt_spec: ArgobotsSpec) -> 'ProviderSpec':
return ProviderSpec.from_dict(json.loads(json_string), abt_spec)


@attr.s(auto_attribs=True, on_setattr=_check_validators, kw_only=True)
class ClientSpec:
"""Client specification.
:param name: Name of the client
:type name: str
:param type: Type of client
:type type: str
:param config: Configuration
:type config: dict
:param dependencies: Dependencies
:type dependencies: dict
:param tags: Tags
:type tags: List[str]
"""

name: str = attr.ib(
validator=[instance_of(str), _validate_object_name],
on_setattr=attr.setters.frozen)
type: str = attr.ib(
validator=instance_of(str),
on_setattr=attr.setters.frozen)
config: dict = attr.ib(
validator=instance_of(dict),
factory=dict)
dependencies: dict = attr.ib(
validator=instance_of(dict),
factory=dict)
tags: List[str] = attr.ib(
validator=instance_of(List),
factory=list)

def to_dict(self) -> dict:
"""Convert the ClientSpec into a dictionary.
"""
return {'name': self.name,
'type': self.type,
'dependencies': self.dependencies,
'config': self.config,
'tags': self.tags}

@staticmethod
def from_dict(data: dict) -> 'ClientSpec':
"""Construct a ClientSpec from a dictionary.
:param data: Dictionary
:type data: dict
"""
return ClientSpec(**args)

def to_json(self, *args, **kwargs) -> str:
"""Convert the ClientSpec into a JSON string.
"""
return json.dumps(self.to_dict(), *args, **kwargs)

@staticmethod
def from_json(json_string: str) -> 'ClientSpec':
"""Construct a ClientSpec from a JSON string.
:param json_string: JSON string
:type json_string: str
"""
return ClientSpec.from_dict(json.loads(json_string))


@attr.s(auto_attribs=True, on_setattr=_check_validators, kw_only=True)
class BedrockSpec:
"""Bedrock specification.
Expand Down Expand Up @@ -1436,6 +1505,9 @@ class ProcSpec:
:param providers: List of ProviderSpec
:type providers: list
:param clients: List of ClientSpec
:type clients: list
"""

margo: MargoSpec = attr.ib(
Expand All @@ -1453,6 +1525,9 @@ class ProcSpec:
_providers: list = attr.ib(
factory=list,
validator=instance_of(list))
_clients: list = attr.ib(
factory=list,
validator=instance_of(list))
bedrock: BedrockSpec = attr.ib(
default=Factory(lambda self: BedrockSpec(pool=self.margo.rpc_pool),
takes_self=True),
Expand All @@ -1479,6 +1554,13 @@ def providers(self) -> SpecListDecorator:
"""
return SpecListDecorator(list=self._providers, type=ProviderSpec)

@property
def clients(self) -> SpecListDecorator:
"""Return a decorator to access the internal list of ClientSpec
and validate changes to this list.
"""
return SpecListDecorator(list=self._clients, type=ClientSpec)

def to_dict(self) -> dict:
"""Convert the ProcSpec into a dictionary.
"""
Expand All @@ -1487,6 +1569,7 @@ def to_dict(self) -> dict:
'ssg': [g.to_dict() for g in self._ssg],
'libraries': self.libraries,
'providers': [p.to_dict() for p in self._providers],
'clients': [c.to_dict() for c in self._clients],
'bedrock': self.bedrock.to_dict()}
return data

Expand All @@ -1511,6 +1594,9 @@ def from_dict(data: dict) -> 'ProcSpec':
if 'providers' in data:
for p in data['providers']:
providers.append(ProviderSpec.from_dict(p, margo.argobots))
if 'clients' in data:
for c in data['clients']:
clients.append(ClientSpec.from_dict(c))
if 'bedrock' in data:
bedrock = BedrockSpec.from_dict(data['bedrock'], margo.argobots)
return ProcSpec(margo=margo,
Expand Down Expand Up @@ -1556,6 +1642,10 @@ def validate(self) -> NoReturn:
if p.type not in self._libraries:
raise ValueError('Could not find module library for' +
f'module type {p.name}')
for c in self._clients:
if c.type not in self._libraries:
raise ValueError('Could not find module library for' +
f'module type {p.name}')


attr.resolve_types(MercurySpec, globals(), locals())
Expand All @@ -1564,6 +1654,8 @@ def validate(self) -> NoReturn:
attr.resolve_types(XstreamSpec, globals(), locals())
attr.resolve_types(ArgobotsSpec, globals(), locals())
attr.resolve_types(MargoSpec, globals(), locals())
attr.resolve_types(ProviderSpec, globals(), locals())
attr.resolve_types(ClientSpec, globals(), locals())
attr.resolve_types(AbtIOSpec, globals(), locals())
attr.resolve_types(SwimSpec, globals(), locals())
attr.resolve_types(SSGSpec, globals(), locals())
Expand Down

0 comments on commit 75f35f0

Please sign in to comment.