Skip to content

Commit

Permalink
Feature type defined using dynamic type assertion to work around limi…
Browse files Browse the repository at this point in the history
…tations of python's type system (#139)

(lack of type-based method overloading)

Co-authored-by: William Blum <william.blum@microsoft.com>
  • Loading branch information
blumu and William Blum authored Aug 6, 2024
1 parent 55f8d00 commit 5993016
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 66 deletions.
5 changes: 1 addition & 4 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
"jupyter.jupyterServerType": "local",
"jupyter.notebookFileRoot": "${workspaceFolder}",
"files.trimFinalNewlines": true,
"editor.defaultFormatter": "ms-python.flake8",
"flake8.args": [
],
"editor.defaultFormatter": "ms-python.black-formatter",
"files.trimTrailingWhitespace": true,

}
2 changes: 1 addition & 1 deletion cyberbattle/_env/cyberbattle_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,7 +1317,7 @@ def get_explored_network_node_properties_bitmap_as_numpy(
]
)

def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]:
def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]: # type: ignore
if self.__done:
raise RuntimeError("new episode must be started with env.reset()")

Expand Down
7 changes: 5 additions & 2 deletions cyberbattle/_env/discriminatedunion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
T_cov = TypeVar("T_cov", covariant=True)


class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore
class DiscriminatedUnion(spaces.Dict, Generic[T_cov]):
"""
A discriminated union of simpler spaces.
Expand All @@ -23,6 +23,9 @@ class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore
self.observation_space = discriminatedunion.DiscriminatedUnion(
{"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)})
Generic type T_cov is the type of the contained discriminated values.
It should be defined as a typed dictionary, e.g.: TypedDict('Choices', {'foo': int, 'Bar': int})
"""

def __init__(
Expand All @@ -47,7 +50,7 @@ def __init__(
def seed(self, seed: Union[dict, None, int] = None):
return super().seed(seed)

def sample(self, mask=None) -> T_cov: # dict[str, object]:
def sample(self, mask=None) -> T_cov: # type: ignore
space_count = len(self.spaces.items())
index_k = self.union_np_random.integers(0, space_count)
kth_key, kth_space = list(self.spaces.items())[index_k]
Expand Down
Loading

0 comments on commit 5993016

Please sign in to comment.