diff --git a/bsk_rl/envs/general_satellite_tasking/gym_env.py b/bsk_rl/envs/general_satellite_tasking/gym_env.py index 62e47937..bc944981 100644 --- a/bsk_rl/envs/general_satellite_tasking/gym_env.py +++ b/bsk_rl/envs/general_satellite_tasking/gym_env.py @@ -186,6 +186,11 @@ def _get_info(self) -> dict[str, Any]: satellite.id: deepcopy(satellite.info) for satellite in self.satellites } info["d_ts"] = self.latest_step_duration + info["requires_retasking"] = [ + satellite.id + for satellite in self.satellites + if satellite.requires_retasking + ] return info @property @@ -229,7 +234,15 @@ def step( raise ValueError("There must be the same number of actions and satellites") for satellite, action in zip(self.satellites, actions): satellite.info = [] # reset satellite info log - satellite.set_action(action) + if action is not None: + satellite.requires_retasking = False + satellite.set_action(action) + else: + if satellite.requires_retasking: + print( + f"Satellite {satellite.id} requires retasking " + "but received no task." + ) previous_time = self.simulator.sim_time # should these be recorded in simulator self.simulator.run() diff --git a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py index 20c2df0c..1660b1bb 100644 --- a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py +++ b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py @@ -87,6 +87,7 @@ def __init__( self.fsw: "FSWModel" self.dynamics: "DynamicsModel" self.data_store: DataStore + self.requires_retasking: bool self.variable_interval = variable_interval self._timed_terminal_event_name = None @@ -104,6 +105,7 @@ def _generate_sat_args(self) -> None: def reset_pre_sim(self) -> None: """Called in environment reset, before simulator initialization""" self.info = [] + self.requires_retasking = True self._generate_sat_args() assert self.data_store.is_fresh self.data_store.is_fresh = False @@ -235,6 +237,7 @@ def _update_timed_terminal_event( [f"self.TotalSim.CurrentNanos * {macros.NANO2SEC} >= {t_close}"], [ self._info_command(f"timed termination at {t_close:.1f} " + info), + self._satellite_command + ".requires_retasking = True", ] + extra_actions, terminal=self.variable_interval, @@ -747,6 +750,7 @@ def _update_image_event(self, target: Target) -> None: [ self._info_command(f"imaged {target}"), self._satellite_command + ".imaged += 1", + self._satellite_command + ".requires_retasking = True", ], terminal=self.variable_interval, ) diff --git a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py index 7d267c56..6fd9f015 100644 --- a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py +++ b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py @@ -64,6 +64,7 @@ def test_get_info(self): env.latest_step_duration = 10.0 expected = {sat.id: {"sat_index": i} for i, sat in enumerate(mock_sats)} expected["d_ts"] = 10.0 + expected["requires_retasking"] = [sat.id for sat in mock_sats] assert env._get_info() == expected def test_action_space(self): @@ -170,6 +171,23 @@ def test_step_stopped(self, sat_death, timeout, terminate_on_time_limit): assert terminated == (sat_death or (timeout and terminate_on_time_limit)) assert truncated == timeout + @patch.multiple(Satellite, __abstractmethods__=set()) + def test_step_retask_needed(self, capfd): + mock_sat = MagicMock() + env = SingleSatelliteTasking( + satellites=[mock_sat], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + ) + env.simulator = MagicMock(sim_time=101.0) + env.step(None) + assert mock_sat.requires_retasking + mock_sat.requires_retasking = True + env.step(None) + assert mock_sat.requires_retasking + assert "requires retasking but received no task" in capfd.readouterr().out + def test_render(self): pass