Skip to content

Commit

Permalink
Migrate to gymnasium (#143)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: William Blum <william.blum@microsoft.com>
  • Loading branch information
blumu and William Blum authored Aug 7, 2024
1 parent 399c327 commit 09622b8
Show file tree
Hide file tree
Showing 32 changed files with 1,389,677 additions and 1,389,677 deletions.
4 changes: 2 additions & 2 deletions cyberbattle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

"""Initialize CyberBattleSim module"""

from gym.envs.registration import registry, EnvSpec
from gym.error import Error
from gymnasium.envs.registration import registry, EnvSpec
from gymnasium.error import Error

from . import simulation
from . import agents
Expand Down
12 changes: 6 additions & 6 deletions cyberbattle/_env/cyberbattle_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import logging
import networkx
from networkx import convert_matrix
from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict
from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict, cast

from gym import spaces, Env
from gym.utils import seeding
from gymnasium import spaces, Env
from gymnasium.utils import seeding

import numpy

Expand Down Expand Up @@ -146,7 +146,7 @@ def inverse_dict(self: Dict[Key, Value]) -> Dict[Value, Key]:


class DummySpace(spaces.Space):
"""This class ensures that the values in the gym.spaces.Dict space are derived from gym.Space"""
"""This class ensures that the values in the gym.spaces.Dict space are derived from gymnasium.Space"""

def __init__(self, sample: object):
self._sample = sample
Expand Down Expand Up @@ -526,7 +526,7 @@ def __init__(
maximum_node_count = self.__bounds.maximum_node_count
port_count = self.__bounds.port_count

action_spaces_dict: dict[str, spaces.Space] = {
action_spaces = {
"local_vulnerability": spaces.MultiDiscrete(
# source_node_id, vulnerability_id
[maximum_node_count, local_vulnerabilities_count]
Expand All @@ -547,7 +547,7 @@ def __init__(
),
}

self.action_space = DiscriminatedUnion[Action](spaces=action_spaces_dict)
self.action_space = DiscriminatedUnion[Action](cast(dict, action_spaces)) # type: ignore

self.observation_space = ObservationSpaceType(self.__bounds)

Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/_env/cyberbattle_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from cyberbattle._env.option_wrapper import ContextWrapper, random_options
from cyberbattle._env.cyberbattle_env import AttackerGoal, CyberBattleEnv
import pytest
import gym
import gymnasium as gym
import numpy as np
from typing import cast

Expand Down
10 changes: 5 additions & 5 deletions cyberbattle/_env/discriminatedunion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
"""A discriminated union space for Gym"""

from collections import OrderedDict
from typing import Mapping, Optional, TypeVar, Union
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Dict as TypingDict, Generic, cast
import numpy as np

from gym import spaces
from gym.utils import seeding
from gymnasium import spaces
from gymnasium.utils import seeding

T_cov = TypeVar("T_cov", covariant=True)

Expand Down Expand Up @@ -80,10 +80,10 @@ def __getitem__(self, key: str) -> spaces.Space:
def __repr__(self) -> str:
return self.__class__.__name__ + "(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"

def to_jsonable(self, sample_n: list) -> dict:
def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
return super().to_jsonable(sample_n)

def from_jsonable(self, sample_n: TypingDict[str, list]) -> list[dict]:
def from_jsonable(self, sample_n: TypingDict[str, list]) -> list[OrderedDict[str, Any]]:
ret = super().from_jsonable(sample_n)
assert len(ret) == 1
return ret
Expand Down
4 changes: 2 additions & 2 deletions cyberbattle/_env/flatten_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from collections import OrderedDict
from sqlite3 import NotSupportedError
from gym import spaces
from gymnasium import spaces
import numpy as np
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action, CyberBattleSpaceKind
from gym.core import ObservationWrapper, ActionWrapper
from gymnasium.core import ObservationWrapper, ActionWrapper


class FlattenObservationWrapper(ObservationWrapper):
Expand Down
4 changes: 2 additions & 2 deletions cyberbattle/_env/graph_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional

import networkx as nx
from gym.spaces import Space, Dict
from gymnasium.spaces import Space, Dict


class BaseGraph(Space):
Expand Down Expand Up @@ -82,7 +82,7 @@ class MultiDiGraph(BaseGraph):


if __name__ == "__main__":
from gym.spaces import Box, Discrete
from gymnasium.spaces import Box, Discrete
import matplotlib.pyplot as plt # type:ignore

space = DiGraph(
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/_env/graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Union, Tuple

import gym
import gymnasium as gym
import numpy as onp
import networkx as nx

Expand Down
4 changes: 2 additions & 2 deletions cyberbattle/_env/option_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from typing import NamedTuple

import gym
from gym.spaces import Space, Discrete, Tuple
import gymnasium as gym
from gymnasium.spaces import Space, Discrete, Tuple
import numpy as onp
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv

Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/agents/baseline/agent_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Optional, List, Tuple, overload
import enum
import numpy as np
from gym import spaces, Wrapper
from gymnasium import spaces, Wrapper
from numpy import ndarray
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import logging
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/agents/baseline/baseline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""Test training of baseline agents."""

import torch
import gym
import gymnasium as gym
import logging
import sys
import cyberbattle._env.cyberbattle_env as cyberbattle_env
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# %%
import sys
import logging
import gym
import gymnasium as gym
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# %%
import sys
import logging
import gym
import gymnasium as gym
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import sys
import logging
import gym
import gymnasium as gym
import torch

import cyberbattle.agents.baseline.learner as learner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from cyberbattle._env.cyberbattle_env import AttackerGoal
from cyberbattle.agents.baseline.agent_randomcredlookup import CredentialCacheExploiter
import cyberbattle.agents.baseline.learner as learner
import gym
import gymnasium as gym
import logging
import sys
import cyberbattle.agents.baseline.plotting as p
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/agents/baseline/notebooks/notebook_tabularq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import sys
import logging
from typing import cast
import gym
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt # type:ignore
from cyberbattle.agents.baseline.learner import TrainedLearner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# %%
import sys
import logging
import gym
import gymnasium as gym
import importlib

import cyberbattle.agents.baseline.learner as learner
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/agents/baseline/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

import torch
import gym
import gymnasium as gym
import logging
import sys
import asciichartpy
Expand Down
2,680 changes: 1,340 additions & 1,340 deletions notebooks/chainnetwork-optionwrapper.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 09622b8

Please sign in to comment.