diff --git a/cyberbattle/_env/cyberbattle_env.py b/cyberbattle/_env/cyberbattle_env.py index ee7478c..6896bff 100644 --- a/cyberbattle/_env/cyberbattle_env.py +++ b/cyberbattle/_env/cyberbattle_env.py @@ -1317,7 +1317,7 @@ def get_explored_network_node_properties_bitmap_as_numpy( ] ) - def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]: + def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]: # type: ignore if self.__done: raise RuntimeError("new episode must be started with env.reset()") diff --git a/cyberbattle/_env/discriminatedunion.py b/cyberbattle/_env/discriminatedunion.py index bb72450..149233a 100644 --- a/cyberbattle/_env/discriminatedunion.py +++ b/cyberbattle/_env/discriminatedunion.py @@ -14,7 +14,7 @@ T_cov = TypeVar("T_cov", covariant=True) -class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore +class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): """ A discriminated union of simpler spaces. @@ -23,6 +23,9 @@ class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore self.observation_space = discriminatedunion.DiscriminatedUnion( {"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)}) + Generic type T_cov is the type of the contained discriminated values. + It should be defined as a typed dictionary, e.g.: TypedDict('Choices', {'foo': int, 'Bar': int}) + """ def __init__( @@ -47,7 +50,7 @@ def __init__( def seed(self, seed: Union[dict, None, int] = None): return super().seed(seed) - def sample(self, mask=None) -> T_cov: # dict[str, object]: + def sample(self, mask=None) -> T_cov: # type: ignore space_count = len(self.spaces.items()) index_k = self.union_np_random.integers(0, space_count) kth_key, kth_space = list(self.spaces.items())[index_k] diff --git a/cyberbattle/agents/baseline/agent_wrapper.py b/cyberbattle/agents/baseline/agent_wrapper.py index 63471af..c545cfc 100644 --- a/cyberbattle/agents/baseline/agent_wrapper.py +++ b/cyberbattle/agents/baseline/agent_wrapper.py @@ -25,7 +25,7 @@ def on_step( self, action: cyberbattle_env.Action, reward: float, - truncated, + truncated: bool, done: bool, observation: cyberbattle_env.Observation, ): @@ -332,14 +332,15 @@ def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]): assert np.shape(self.ravelled_size) == (), f"! {np.shape(self.ravelled_size)}" super().__init__(p, [self.ravelled_size]) - def vector_to_index(self, feature_vector): + def vector_to_index(self, feature_vector) -> int: assert len(self.dim_sizes) == len(feature_vector), ( f"feature vector of size {len(feature_vector)}, " f"expecting {len(self.dim_sizes)}: {feature_vector} -- {self.dim_sizes}" ) - index: np.int32 = np.ravel_multi_index( + index_intp = np.ravel_multi_index( list(feature_vector), list(self.dim_sizes) ) + index = index_intp.item() assert index < self.ravelled_size, ( f"feature vector out of bound ({feature_vector}, dim={self.dim_sizes}) " f"-> index={index}, max_index={self.ravelled_size-1})"