diff --git a/pyproject.toml b/pyproject.toml index 02ea968e..24b0d74b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ "matplotlib", "numpy", "pandas", + "pettingzoo", "pytest", "pytest-cov", "pytest-repeat", diff --git a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py index bc944981..257baf09 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py +++ b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py @@ -1,8 +1,10 @@ +import functools from copy import deepcopy -from typing import Any, Iterable, Optional, Union +from typing import Any, Generic, Iterable, Optional, TypeVar, Union import numpy as np from gymnasium import Env, spaces +from pettingzoo.utils.env import AgentID, ParallelEnv from bsk_rl.envs.general_satellite_tasking.scenario.communication import NoCommunication from bsk_rl.envs.general_satellite_tasking.simulation.simulator import Simulator @@ -14,13 +16,13 @@ Satellite, ) -SatObs = Any -SatAct = Any +SatObs = TypeVar("SatObs") +SatAct = TypeVar("SatAct") MultiSatObs = tuple[SatObs, ...] MultiSatAct = Iterable[SatAct] -class GeneralSatelliteTasking(Env): +class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]): def __init__( self, satellites: Union[Satellite, list[Satellite]], @@ -67,7 +69,7 @@ def __init__( communicator: Object to manage communication between satellites sim_rate: Rate for model simulation [s]. max_step_duration: Maximum time to propagate sim at a step [s]. - failure_penalty: Reward for satellite failure. + failure_penalty: Reward for satellite failure. Should be nonpositive. time_limit: Time at which to truncate the simulation [s]. terminate_on_time_limit: Send terminations signal time_limit instead of just truncation. @@ -189,10 +191,29 @@ def _get_info(self) -> dict[str, Any]: info["requires_retasking"] = [ satellite.id for satellite in self.satellites - if satellite.requires_retasking + if satellite.requires_retasking and satellite.is_alive() ] return info + def _get_reward(self): + """Return a scalar reward for the step.""" + reward = sum(self.reward_dict.values()) + for satellite in self.satellites: + if not satellite.is_alive(): + reward += self.failure_penalty + return reward + + def _get_terminated(self) -> bool: + """Return the terminated flag for the step.""" + if self.terminate_on_time_limit and self._get_truncated(): + return True + else: + return not all(satellite.is_alive() for satellite in self.satellites) + + def _get_truncated(self) -> bool: + """Return the truncated flag for the step.""" + return self.simulator.sim_time >= self.time_limit + @property def action_space(self) -> spaces.Space[MultiSatAct]: """Compose satellite action spaces @@ -219,17 +240,7 @@ def observation_space(self) -> spaces.Space[MultiSatObs]: [satellite.observation_space for satellite in self.satellites] ) - def step( - self, actions: MultiSatAct - ) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]: - """Propagate the simulation, update information, and get rewards - - Args: - Joint action for satellites - - Returns: - observation, reward, terminated, truncated, info - """ + def _step(self, actions: MultiSatAct) -> None: if len(actions) != len(self.satellites): raise ValueError("There must be the same number of actions and satellites") for satellite, action in zip(self.satellites, actions): @@ -252,23 +263,27 @@ def step( satellite.id: satellite.data_store.internal_update() for satellite in self.satellites } - reward = self.data_manager.reward(new_data) + self.reward_dict = self.data_manager.reward(new_data) self.communicator.communicate() - terminated = False - for satellite in self.satellites: - if not satellite.is_alive(): - terminated = True - reward += self.failure_penalty + def step( + self, actions: MultiSatAct + ) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]: + """Propagate the simulation, update information, and get rewards + + Args: + Joint action for satellites - truncated = False - if self.simulator.sim_time >= self.time_limit: - truncated = True - if self.terminate_on_time_limit: - terminated = True + Returns: + observation, reward, terminated, truncated, info + """ + self._step(actions) observation = self._get_obs() + reward = self._get_reward() + terminated = self._get_terminated() + truncated = self._get_truncated() info = self._get_info() return observation, reward, terminated, truncated, info @@ -282,7 +297,7 @@ def close(self) -> None: del self.simulator -class SingleSatelliteTasking(GeneralSatelliteTasking): +class SingleSatelliteTasking(GeneralSatelliteTasking, Generic[SatObs, SatAct]): """A special case of the GeneralSatelliteTasking for one satellite. For compatibility with standard training APIs, actions and observations are directly exposed for the single satellite and are not wrapped in a tuple. @@ -296,7 +311,7 @@ def __init__(self, *args, **kwargs) -> None: ) @property - def action_space(self) -> spaces.Discrete: + def action_space(self) -> spaces.Space[SatAct]: """Return the single satellite action space""" return self.satellite.action_space @@ -316,3 +331,155 @@ def step(self, action) -> tuple[Any, float, bool, bool, dict[str, Any]]: def _get_obs(self) -> Any: return self.satellite.get_obs() + + +class MultiagentSatelliteTasking( + GeneralSatelliteTasking, ParallelEnv, Generic[SatObs, SatAct, AgentID] +): + """Implements the environment with the PettingZoo parallel API.""" + + def reset( + self, seed: int | None = None, options=None + ) -> tuple[MultiSatObs, dict[str, Any]]: + self.newly_dead = [] + return super().reset(seed, options) + + @property + def agents(self) -> list[AgentID]: + """Agents currently in the environment""" + truncated = super()._get_truncated() + return [ + satellite.id + for satellite in self.satellites + if (satellite.is_alive() and not truncated) + ] + + @property + def num_agents(self) -> int: + """Number of agents currently in the environment""" + return len(self.agents) + + @property + def possible_agents(self) -> list[AgentID]: + """Return the list of all possible agents.""" + return [satellite.id for satellite in self.satellites] + + @property + def max_num_agents(self) -> int: + """Maximum number of agents possible in the environment""" + return len(self.possible_agents) + + @property + def previously_dead(self) -> list[AgentID]: + """Return the list of agents that died at least one step ago.""" + return list(set(self.possible_agents) - set(self.agents) - set(self.newly_dead)) + + @property + def observation_spaces(self) -> dict[AgentID, spaces.Box]: + """Return the observation space for each agent""" + return { + agent: obs_space + for agent, obs_space in zip(self.possible_agents, super().observation_space) + } + + @functools.lru_cache(maxsize=None) + def observation_space(self, agent: AgentID) -> spaces.Space[SatObs]: + """Return the observation space for a certain agent""" + return self.observation_spaces[agent] + + @property + def action_spaces(self) -> dict[AgentID, spaces.Space[SatAct]]: + """Return the action space for each agent""" + return { + agent: act_space + for agent, act_space in zip(self.possible_agents, super().action_space) + } + + @functools.lru_cache(maxsize=None) + def action_space(self, agent: AgentID) -> spaces.Space[SatAct]: + """Return the action space for a certain agent""" + return self.action_spaces[agent] + + def _get_obs(self) -> dict[AgentID, SatObs]: + """Format the observation per the PettingZoo Parallel API""" + return { + agent: satellite.get_obs() + for agent, satellite in zip(self.possible_agents, self.satellites) + if agent not in self.previously_dead + } + + def _get_reward(self) -> dict[AgentID, float]: + """Format the reward per the PettingZoo Parallel API""" + reward = deepcopy(self.reward_dict) + for agent, satellite in zip(self.possible_agents, self.satellites): + if not satellite.is_alive(): + reward[agent] += self.failure_penalty + + reward_keys = list(reward.keys()) + for agent in reward_keys: + if agent in self.previously_dead: + del reward[agent] + + return reward + + def _get_terminated(self) -> dict[AgentID, bool]: + """Format terminations per the PettingZoo Parallel API""" + if self.terminate_on_time_limit and super()._get_truncated(): + return { + agent: True + for agent in self.possible_agents + if agent not in self.previously_dead + } + else: + return { + agent: not satellite.is_alive() + for agent, satellite in zip(self.possible_agents, self.satellites) + if agent not in self.previously_dead + } + + def _get_truncated(self) -> dict[AgentID, bool]: + """Format truncations per the PettingZoo Parallel API""" + truncated = super()._get_truncated() + return { + agent: truncated + for agent in self.possible_agents + if agent not in self.previously_dead + } + + def _get_info(self) -> dict[AgentID, dict]: + """Format info per the PettingZoo Parallel API""" + info = super()._get_info() + for agent in self.possible_agents: + if agent in self.previously_dead: + del info[agent] + return info + + def step( + self, + actions: dict[AgentID, SatAct], + ) -> tuple[ + dict[AgentID, SatObs], + dict[AgentID, float], + dict[AgentID, bool], + dict[AgentID, bool], + dict[AgentID, dict], + ]: + """Step the environment and return PettingZoo Parallel API format""" + previous_alive = self.agents + + action_vector = [] + for agent in self.possible_agents: + if agent in actions.keys(): + action_vector.append(actions[agent]) + else: + action_vector.append(None) + self._step(action_vector) + + self.newly_dead = list(set(previous_alive) - set(self.agents)) + + observation = self._get_obs() + reward = self._get_reward() + terminated = self._get_terminated() + truncated = self._get_truncated() + info = self._get_info() + return observation, reward, terminated, truncated, info diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py index daa06579..6b15e585 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py @@ -120,31 +120,34 @@ def __init__(self, env_features: Optional["EnvironmentFeatures"] = None) -> None data """ self.env_features = env_features + self.DataType = self.DataStore.DataType def reset(self) -> None: - self.data = self.DataStore.DataType() - self.cum_reward = 0.0 + self.data = self.DataType() + self.cum_reward = {} def create_data_store(self, satellite: "Satellite") -> None: """Create a data store for a satellite""" satellite.data_store = self.DataStore(self, satellite) + self.cum_reward[satellite.id] = 0.0 @abstractmethod # pragma: no cover - def _calc_reward(self, new_data_dict: dict[str, DataType]) -> float: + def _calc_reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: """Calculate step reward based on all satellite data from a step Args: new_data_dict: Satellite-DataType pairs of new data from a step Returns: - Step reward + Step reward for each satellite """ pass - def reward(self, new_data_dict: dict[str, DataType]) -> float: + def reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: """Calls _calc_reward and logs cumulative reward""" reward = self._calc_reward(new_data_dict) - self.cum_reward += reward + for satellite_id, sat_reward in reward.items(): + self.cum_reward[satellite_id] += sat_reward return reward @@ -167,7 +170,7 @@ class NoDataManager(DataManager): DataStore = NoDataStore def _calc_reward(self, new_data_dict): - return 0 + return {sat: 0.0 for sat in new_data_dict.keys()} ####################################### @@ -261,7 +264,9 @@ def __init__( super().__init__(env_features) self.reward_fn = reward_fn - def _calc_reward(self, new_data_dict: dict[str, UniqueImageData]) -> float: + def _calc_reward( + self, new_data_dict: dict[str, UniqueImageData] + ) -> dict[str, float]: """Reward new each unique image once using self.reward_fn() Args: @@ -270,11 +275,19 @@ def _calc_reward(self, new_data_dict: dict[str, UniqueImageData]) -> float: Returns: reward: Cumulative reward across satellites for one step """ - reward = 0.0 - for new_data in new_data_dict.values(): + reward = {} + imaged_targets = sum( + [new_data.imaged for new_data in new_data_dict.values()], [] + ) + for sat_id, new_data in new_data_dict.items(): + reward[sat_id] = 0.0 for target in new_data.imaged: if target not in self.data.imaged: - reward += self.reward_fn(target.priority) + reward[sat_id] += self.reward_fn( + target.priority + ) / imaged_targets.count(target) + + for new_data in new_data_dict.values(): self.data += new_data return reward @@ -454,7 +467,9 @@ def reward_fn(p): self.reward_fn = reward_fn - def _calc_reward(self, new_data_dict: ["NadirScanningTimeData"]) -> float: + def _calc_reward( + self, new_data_dict: dict[str, "NadirScanningTimeData"] + ) -> dict[str, float]: """Calculate step reward based on all satellite data from a step Args: @@ -463,8 +478,8 @@ def _calc_reward(self, new_data_dict: ["NadirScanningTimeData"]) -> float: Returns: Step reward """ - reward = 0.0 - for scanning_time in new_data_dict.values(): - reward += self.reward_fn(scanning_time.scanning_time) + reward = {} + for sat, scanning_time in new_data_dict.items(): + reward[sat] = self.reward_fn(scanning_time.scanning_time) return reward diff --git a/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py b/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py index ab344ed1..c068c9b2 100644 --- a/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py +++ b/tests/integration/envs/general_satellite_tasking/test_int_full_environments.py @@ -1,6 +1,8 @@ import gymnasium as gym import pytest +from pettingzoo.test.parallel_test import parallel_api_test +from bsk_rl.envs.general_satellite_tasking.gym_env import MultiagentSatelliteTasking from bsk_rl.envs.general_satellite_tasking.scenario import data from bsk_rl.envs.general_satellite_tasking.scenario import satellites as sats from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import ( @@ -35,6 +37,30 @@ disable_env_checker=True, ) +parallel_env = MultiagentSatelliteTasking( + satellites=[ + sats.FullFeaturedSatellite( + "Sentinel-2A", + sat_args=sats.FullFeaturedSatellite.default_sat_args(oe=random_orbit), + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + sats.FullFeaturedSatellite( + "Sentinel-2B", + sat_args=sats.FullFeaturedSatellite.default_sat_args(oe=random_orbit), + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + ], + env_type=environment.GroundStationEnvModel, + env_args=None, + env_features=StaticTargets(n_targets=1000), + data_manager=data.UniqueImagingManager(), + sim_rate=0.5, + max_step_duration=1e9, + time_limit=5700.0, +) + @pytest.mark.parametrize("env", [multi_env]) def test_reproducibility(env): @@ -61,3 +87,10 @@ def test_reproducibility(env): break assert reward_sum_2 == reward_sum_1 + + +@pytest.mark.repeat(5) +def test_parallel_api(): + with pytest.warns(UserWarning): + # expect an erroneous warning about the info dict due to our additional info + parallel_api_test(parallel_env) diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py index 93f7ad14..19f6442b 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_data.py @@ -38,21 +38,23 @@ def test_reset(self): data.DataManager.DataStore = MagicMock() dm = data.DataManager(MagicMock()) dm.reset() - assert dm.cum_reward == 0 + assert dm.cum_reward == {} def test_create_data_store(self): sat = MagicMock() data.DataManager.DataStore = MagicMock(return_value="ds") dm = data.DataManager(MagicMock()) + dm.reset() dm.create_data_store(sat) assert sat.data_store == "ds" + assert sat.id in dm.cum_reward def test_reward(self): dm = data.DataManager(MagicMock()) - dm._calc_reward = MagicMock(return_value=10.0) - dm.cum_reward = 0 - assert 10.0 == dm.reward({"new": "data"}) - assert dm.cum_reward == 10.0 + dm._calc_reward = MagicMock(return_value={"sat": 10.0}) + dm.cum_reward = {"sat": 5.0} + assert {"sat": 10.0} == dm.reward({"sat": "data"}) + assert dm.cum_reward == {"sat": 15.0} class TestNoData: @@ -73,7 +75,7 @@ class TestNoDataManager: def test_calc_reward(self): dm = data.NoDataManager(MagicMock()) reward = dm._calc_reward({"sat1": 0, "sat2": 1}) - assert reward == 0 + assert reward == {"sat1": 0.0, "sat2": 0.0} class TestUniqueImageData: @@ -159,7 +161,7 @@ def test_calc_reward(self): "sat2": data.UniqueImageData([MagicMock(priority=0.2)]), } ) - assert reward == approx(0.3) + assert reward == {"sat1": approx(0.1), "sat2": approx(0.2)} def test_calc_reward_existing(self): tgt = MagicMock(priority=0.2) @@ -171,7 +173,19 @@ def test_calc_reward_existing(self): "sat2": data.UniqueImageData([tgt]), } ) - assert reward == approx(0.1) + assert reward == {"sat1": approx(0.1), "sat2": 0.0} + + def test_calc_reward_repeated(self): + tgt = MagicMock(priority=0.2) + dm = data.UniqueImagingManager(MagicMock()) + dm.data = data.UniqueImageData([]) + reward = dm._calc_reward( + { + "sat1": data.UniqueImageData([tgt]), + "sat2": data.UniqueImageData([tgt]), + } + ) + assert reward == {"sat1": approx(0.1), "sat2": approx(0.1)} def test_calc_reward_custom_fn(self): dm = data.UniqueImagingManager(MagicMock(), reward_fn=lambda x: 1 / x) @@ -182,7 +196,7 @@ def test_calc_reward_custom_fn(self): "sat2": data.UniqueImageData([MagicMock(priority=2)]), } ) - assert reward == approx(1.5) + assert reward == {"sat1": approx(1.0), "sat2": 0.5} class TestNadirScanningTimeData: @@ -240,7 +254,7 @@ def test_calc_reward(self): "sat2": data.NadirScanningTimeData(2), } ) - assert reward == approx(3) + assert reward == {"sat1": 1.0, "sat2": 2.0} def test_calc_reward_existing(self): dm = data.NadirScanningManager(MagicMock()) @@ -252,7 +266,7 @@ def test_calc_reward_existing(self): "sat2": data.NadirScanningTimeData(3), } ) - assert reward == approx(5) + assert reward == {"sat1": 2.0, "sat2": 3.0} def test_calc_reward_custom_fn(self): dm = data.NadirScanningManager(MagicMock(), reward_fn=lambda x: 1 / x) @@ -263,4 +277,4 @@ def test_calc_reward_custom_fn(self): "sat2": data.NadirScanningTimeData(2), } ) - assert reward == approx(1.0) + assert reward == {"sat1": 0.5, "sat2": 0.5} diff --git a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py index 6fd9f015..69492961 100644 --- a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py +++ b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py @@ -5,6 +5,7 @@ from bsk_rl.envs.general_satellite_tasking.gym_env import ( GeneralSatelliteTasking, + MultiagentSatelliteTasking, SingleSatelliteTasking, ) from bsk_rl.envs.general_satellite_tasking.scenario.satellites import Satellite @@ -119,7 +120,9 @@ def test_step(self): satellites=mock_sats, env_type=MagicMock(), env_features=MagicMock(), - data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + data_manager=MagicMock( + reward=MagicMock(return_value={sat.id: 12.5 for sat in mock_sats}) + ), ) env.simulator = MagicMock(sim_time=101.0) _, reward, _, _, info = env.step((0, 10)) @@ -154,7 +157,9 @@ def test_step_stopped(self, sat_death, timeout, terminate_on_time_limit): satellites=mock_sats, env_type=MagicMock(), env_features=MagicMock(), - data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + data_manager=MagicMock( + reward=MagicMock(return_value={sat.id: 12.5 for sat in mock_sats}) + ), terminate_on_time_limit=terminate_on_time_limit, ) env.simulator = MagicMock(sim_time=101.0) @@ -178,7 +183,7 @@ def test_step_retask_needed(self, capfd): satellites=[mock_sat], env_type=MagicMock(), env_features=MagicMock(), - data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + data_manager=MagicMock(reward=MagicMock(return_value={mock_sat.id: 25.0})), ) env.simulator = MagicMock(sim_time=101.0) env.step(None) @@ -256,3 +261,249 @@ def test_step(self, step_patch): def test_get_obs(self): env, mock_sat = self.make_env() assert env._get_obs() == mock_sat.get_obs() + + +class TestMultiagentSatelliteTasking: + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.Simulator", + ) + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.MultiagentSatelliteTasking._get_obs", + ) + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.MultiagentSatelliteTasking._get_info", + ) + def test_reset(self, mock_sim, obs_fn, info_fn): + mock_sat_1 = MagicMock() + mock_sat_2 = MagicMock() + mock_sat_1.sat_args_generator = {} + mock_sat_2.sat_args_generator = {} + mock_data = MagicMock(env_features=None) + env = MultiagentSatelliteTasking( + satellites=[mock_sat_1, mock_sat_2], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=mock_data, + ) + env.env_args_generator = {"utc_init": "a long time ago"} + env.communicator = MagicMock() + obs, info = env.reset() + obs_fn.assert_called_once() + info_fn.assert_called_once() + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_agents(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock() for i in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + assert env.agents == [sat.id for sat in env.satellites] + assert env.num_agents == 3 + assert env.possible_agents == [sat.id for sat in env.satellites] + assert env.max_num_agents == 3 + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_get_obs(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock(get_obs=MagicMock(return_value=i)) for i in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.newly_dead = [] + assert env._get_obs() == {sat.id: i for i, sat in enumerate(env.satellites)} + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_get_info(self): + mock_sats = [MagicMock(info={"sat_index": i}) for i in range(3)] + env = MultiagentSatelliteTasking( + satellites=mock_sats, + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.newly_dead = [] + env.latest_step_duration = 10.0 + expected = {sat.id: {"sat_index": i} for i, sat in enumerate(mock_sats)} + expected["d_ts"] = 10.0 + expected["requires_retasking"] = [sat.id for sat in mock_sats] + assert env._get_info() == expected + + def test_action_spaces(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(action_space=spaces.Discrete(i + 1)) for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + assert env.action_spaces == { + env.satellites[0].id: spaces.Discrete(1), + env.satellites[1].id: spaces.Discrete(2), + env.satellites[2].id: spaces.Discrete(3), + } + + def test_obs_spaces(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(observation_space=spaces.Discrete(i + 1)) for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.simulator = MagicMock() + env.reset = MagicMock() + assert env.observation_spaces == { + env.satellites[0].id: spaces.Discrete(1), + env.satellites[1].id: spaces.Discrete(2), + env.satellites[2].id: spaces.Discrete(3), + } + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_get_reward(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(is_alive=MagicMock(return_value=False)) for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + failure_penalty=-20.0, + ) + env.newly_dead = [sat.id for sat in env.satellites] + env.reward_dict = {sat.id: 10.0 for i, sat in enumerate(env.satellites)} + assert env._get_reward() == { + sat.id: -10.0 for i, sat in enumerate(env.satellites) + } + + @pytest.mark.parametrize("timeout", [False, True]) + @pytest.mark.parametrize("terminate_on_time_limit", [False, True]) + def test_get_terminated(self, timeout, terminate_on_time_limit): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(is_alive=MagicMock(return_value=True if i != 0 else False)) + for i in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + terminate_on_time_limit=terminate_on_time_limit, + time_limit=100, + ) + env.simulator = MagicMock(sim_time=101 if timeout else 99) + + if not timeout or not terminate_on_time_limit: + env.newly_dead = [sat.id for sat in env.satellites] + assert env._get_terminated() == { + env.satellites[0].id: True, + env.satellites[1].id: False, + env.satellites[2].id: False, + } + else: + env.newly_dead = [sat.id for sat in env.satellites] + assert env._get_terminated() == { + env.satellites[0].id: True, + env.satellites[1].id: True, + env.satellites[2].id: True, + } + + @pytest.mark.parametrize("time", [99, 101]) + def test_get_truncated(self, time): + env = MultiagentSatelliteTasking( + satellites=[MagicMock() for i in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + time_limit=100, + ) + env.simulator = MagicMock(sim_time=time) + env.newly_dead = [sat.id for sat in env.satellites] if time >= 100 else [] + assert env._get_truncated() == { + env.satellites[0].id: time >= 100, + env.satellites[1].id: time >= 100, + env.satellites[2].id: time >= 100, + } + + def test_close(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock()], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.simulator = MagicMock() + env.close() + assert not hasattr(env, "simulator") + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + def test_dead(self): + env = MultiagentSatelliteTasking( + satellites=[MagicMock() for _ in range(3)], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + env.satellites[1].is_alive = MagicMock(return_value=False) + env.satellites[2].is_alive = MagicMock(return_value=False) + env.newly_dead = [env.satellites[2].id] + assert env.previously_dead == [env.satellites[1].id] + assert env.agents == [env.satellites[0].id] + assert env.possible_agents == [sat.id for sat in env.satellites] + + mst = "bsk_rl.envs.general_satellite_tasking.gym_env.MultiagentSatelliteTasking." + + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._get_truncated", + MagicMock(return_value=False), + ) + @patch(mst + "_get_obs", MagicMock()) + @patch(mst + "_get_reward", MagicMock()) + @patch(mst + "_get_terminated", MagicMock()) + @patch(mst + "_get_truncated", MagicMock()) + @patch(mst + "_get_info", MagicMock()) + @patch( + "bsk_rl.envs.general_satellite_tasking.gym_env.GeneralSatelliteTasking._step", + MagicMock(), + ) + def test_step(self): + env = MultiagentSatelliteTasking( + satellites=[ + MagicMock(is_alive=MagicMock(return_value=True)) for _ in range(3) + ], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(), + ) + + def kill_sat_2(): + env.satellites[2].is_alive.return_value = False + + env._step.side_effect = lambda _: kill_sat_2() + env.satellites[1].is_alive.return_value = False + env.step( + { + env.satellites[0].id: 0, + env.satellites[2].id: 2, + } + ) + env._step.assert_called_with([0, None, 2]) + assert env.newly_dead == [env.satellites[2].id]