diff --git a/bsk_rl/envs/general_satellite_tasking/gym_env.py b/bsk_rl/envs/general_satellite_tasking/gym_env.py index b0e2e91d..bc944981 100644 --- a/bsk_rl/envs/general_satellite_tasking/gym_env.py +++ b/bsk_rl/envs/general_satellite_tasking/gym_env.py @@ -5,7 +5,6 @@ from gymnasium import Env, spaces from bsk_rl.envs.general_satellite_tasking.scenario.communication import NoCommunication -from bsk_rl.envs.general_satellite_tasking.scenario.satellites import REQUIRES_RETASKING from bsk_rl.envs.general_satellite_tasking.simulation.simulator import Simulator from bsk_rl.envs.general_satellite_tasking.types import ( CommunicationMethod, @@ -187,6 +186,11 @@ def _get_info(self) -> dict[str, Any]: satellite.id: deepcopy(satellite.info) for satellite in self.satellites } info["d_ts"] = self.latest_step_duration + info["requires_retasking"] = [ + satellite.id + for satellite in self.satellites + if satellite.requires_retasking + ] return info @property @@ -221,7 +225,7 @@ def step( """Propagate the simulation, update information, and get rewards Args: - Joint action for satellites. Can be none to maintain current task. + Joint action for satellites Returns: observation, reward, terminated, truncated, info @@ -229,17 +233,16 @@ def step( if len(actions) != len(self.satellites): raise ValueError("There must be the same number of actions and satellites") for satellite, action in zip(self.satellites, actions): - old_info = satellite.info satellite.info = [] # reset satellite info log if action is not None: + satellite.requires_retasking = False satellite.set_action(action) else: - if REQUIRES_RETASKING in old_info: + if satellite.requires_retasking: print( f"Satellite {satellite.id} requires retasking " "but received no task." ) - satellite.info.append(REQUIRES_RETASKING) previous_time = self.simulator.sim_time # should these be recorded in simulator self.simulator.run() diff --git a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py index 31e2cf64..1660b1bb 100644 --- a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py +++ b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py @@ -35,8 +35,6 @@ SatObs = Any SatAct = Any -REQUIRES_RETASKING = "REQUIRES_RETASKING" - class Satellite(ABC): dyn_type: type["DynamicsModel"] # Type of dynamics model used by this satellite @@ -89,6 +87,7 @@ def __init__( self.fsw: "FSWModel" self.dynamics: "DynamicsModel" self.data_store: DataStore + self.requires_retasking: bool self.variable_interval = variable_interval self._timed_terminal_event_name = None @@ -106,6 +105,7 @@ def _generate_sat_args(self) -> None: def reset_pre_sim(self) -> None: """Called in environment reset, before simulator initialization""" self.info = [] + self.requires_retasking = True self._generate_sat_args() assert self.data_store.is_fresh self.data_store.is_fresh = False @@ -155,7 +155,7 @@ def set_fsw(self, fsw_rate: float) -> "FSWModel": def reset_post_sim(self) -> None: """Called in environment reset, after simulator initialization""" - self.info.append(REQUIRES_RETASKING) + pass @property def observation_space(self) -> spaces.Box: @@ -237,7 +237,7 @@ def _update_timed_terminal_event( [f"self.TotalSim.CurrentNanos * {macros.NANO2SEC} >= {t_close}"], [ self._info_command(f"timed termination at {t_close:.1f} " + info), - self._satellite_command + f".info.append('{REQUIRES_RETASKING}')", + self._satellite_command + ".requires_retasking = True", ] + extra_actions, terminal=self.variable_interval, @@ -750,7 +750,7 @@ def _update_image_event(self, target: Target) -> None: [ self._info_command(f"imaged {target}"), self._satellite_command + ".imaged += 1", - self._satellite_command + f".info.append('{REQUIRES_RETASKING}')", + self._satellite_command + ".requires_retasking = True", ], terminal=self.variable_interval, ) diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py index e3834991..d7b0745c 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py @@ -100,14 +100,12 @@ def test_init(self, sat_init): def test_explicit_normalization(self, sat_init): sat = so.TimeState(normalization_time=10.0) - sat.info = MagicMock() sat.simulator = MagicMock(sim_time=1.0) sat.reset_post_sim() assert sat.normalized_time() == 0.1 def test_implicit_normalization(self, sat_init): sat = so.TimeState(normalization_time=None) - sat.info = MagicMock() sat.simulator = MagicMock(sim_time=1.0, time_limit=10.0) sat.reset_post_sim() assert sat.normalized_time() == 0.1 diff --git a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py index 8dcd6054..6fd9f015 100644 --- a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py +++ b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py @@ -7,10 +7,7 @@ GeneralSatelliteTasking, SingleSatelliteTasking, ) -from bsk_rl.envs.general_satellite_tasking.scenario.satellites import ( - REQUIRES_RETASKING, - Satellite, -) +from bsk_rl.envs.general_satellite_tasking.scenario.satellites import Satellite class TestGeneralSatelliteTasking: @@ -67,6 +64,7 @@ def test_get_info(self): env.latest_step_duration = 10.0 expected = {sat.id: {"sat_index": i} for i, sat in enumerate(mock_sats)} expected["d_ts"] = 10.0 + expected["requires_retasking"] = [sat.id for sat in mock_sats] assert env._get_info() == expected def test_action_space(self): @@ -147,23 +145,6 @@ def test_step_bad_action(self): with pytest.raises(ValueError): env.step((0,)) - @patch.multiple(Satellite, __abstractmethods__=set()) - def test_step_retask_needed(self, capfd): - mock_sat = MagicMock() - env = SingleSatelliteTasking( - satellites=[mock_sat], - env_type=MagicMock(), - env_features=MagicMock(), - data_manager=MagicMock(reward=MagicMock(return_value=25.0)), - ) - env.simulator = MagicMock(sim_time=101.0) - env.step(None) - assert REQUIRES_RETASKING not in mock_sat.info - mock_sat.info = [REQUIRES_RETASKING] - env.step(None) - assert REQUIRES_RETASKING in mock_sat.info - assert "requires retasking but received no task" in capfd.readouterr().out - @pytest.mark.parametrize("sat_death", [True, False]) @pytest.mark.parametrize("timeout", [True, False]) @pytest.mark.parametrize("terminate_on_time_limit", [True, False]) @@ -190,6 +171,23 @@ def test_step_stopped(self, sat_death, timeout, terminate_on_time_limit): assert terminated == (sat_death or (timeout and terminate_on_time_limit)) assert truncated == timeout + @patch.multiple(Satellite, __abstractmethods__=set()) + def test_step_retask_needed(self, capfd): + mock_sat = MagicMock() + env = SingleSatelliteTasking( + satellites=[mock_sat], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + ) + env.simulator = MagicMock(sim_time=101.0) + env.step(None) + assert mock_sat.requires_retasking + mock_sat.requires_retasking = True + env.step(None) + assert mock_sat.requires_retasking + assert "requires retasking but received no task" in capfd.readouterr().out + def test_render(self): pass