Skip to content

Commit

Permalink
chore: Update dependencies and pyright fixes
Browse files Browse the repository at this point in the history
- Update networkx
- Update pytest
- Update setuptools
- Update pyright
- Update plotly

chore: Remove networkx from createstubs.sh

chore: Update progressbar2 dependency to version 4.4.2

chore: Update wheel dependency to version 0.44.0

chore: Update pip dependency to version 24.2

chore: Update prompt-toolkit dependency to version 3.0.47

Refactor CyberBattleEnv class to use np_random consistently

- Remove trailing whitespace in notebook scripts
- Update dependencies: prompt-toolkit, pip, wheel, progressbar2
- Refactor CyberBattleEnv class to use np_random consistently
- Remove networkx from createstubs.sh
- Update dependencies: networkx, pytest, setuptools, pyright, plotly

.

.

.

.

.

.

.
  • Loading branch information
William Blum committed Aug 6, 2024
1 parent 55f8d00 commit 783221e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion cyberbattle/_env/cyberbattle_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,7 +1317,7 @@ def get_explored_network_node_properties_bitmap_as_numpy(
]
)

def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]:
def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]: # type: ignore
if self.__done:
raise RuntimeError("new episode must be started with env.reset()")

Expand Down
7 changes: 5 additions & 2 deletions cyberbattle/_env/discriminatedunion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
T_cov = TypeVar("T_cov", covariant=True)


class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore
class DiscriminatedUnion(spaces.Dict, Generic[T_cov]):
"""
A discriminated union of simpler spaces.
Expand All @@ -23,6 +23,9 @@ class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore
self.observation_space = discriminatedunion.DiscriminatedUnion(
{"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)})
Generic type T_cov is the type of the contained discriminated values.
It should be defined as a typed dictionary, e.g.: TypedDict('Choices', {'foo': int, 'Bar': int})
"""

def __init__(
Expand All @@ -47,7 +50,7 @@ def __init__(
def seed(self, seed: Union[dict, None, int] = None):
return super().seed(seed)

def sample(self, mask=None) -> T_cov: # dict[str, object]:
def sample(self, mask=None) -> T_cov: # type: ignore
space_count = len(self.spaces.items())
index_k = self.union_np_random.integers(0, space_count)
kth_key, kth_space = list(self.spaces.items())[index_k]
Expand Down
7 changes: 4 additions & 3 deletions cyberbattle/agents/baseline/agent_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def on_step(
self,
action: cyberbattle_env.Action,
reward: float,
truncated,
truncated: bool,
done: bool,
observation: cyberbattle_env.Observation,
):
Expand Down Expand Up @@ -332,14 +332,15 @@ def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]):
assert np.shape(self.ravelled_size) == (), f"! {np.shape(self.ravelled_size)}"
super().__init__(p, [self.ravelled_size])

def vector_to_index(self, feature_vector):
def vector_to_index(self, feature_vector) -> int:
assert len(self.dim_sizes) == len(feature_vector), (
f"feature vector of size {len(feature_vector)}, "
f"expecting {len(self.dim_sizes)}: {feature_vector} -- {self.dim_sizes}"
)
index: np.int32 = np.ravel_multi_index(
index_intp = np.ravel_multi_index(
list(feature_vector), list(self.dim_sizes)
)
index = index_intp.item()
assert index < self.ravelled_size, (
f"feature vector out of bound ({feature_vector}, dim={self.dim_sizes}) "
f"-> index={index}, max_index={self.ravelled_size-1})"
Expand Down

0 comments on commit 783221e

Please sign in to comment.