From e484e3f7773534000f16dd9a0dced1d34a4c0ad9 Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Sun, 4 Feb 2024 07:29:35 -0700 Subject: [PATCH] Issue #113: Reset previous action on reset --- .../scenario/sat_actions.py | 4 ++++ .../scenario/test_sat_actions.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py index a50ac081..fada4063 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py @@ -32,7 +32,11 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.action_list = [] self.action_map = {} + + def reset_pre_sim(self) -> None: + """Reset the previous action key.""" self.prev_action_key = None # Used to avoid retasking of BSK tasks + return super().reset_pre_sim() def add_action( self, act_fn, act_name: Optional[str] = None, n_actions: Optional[int] = None diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_actions.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_actions.py index 088f3fc0..e1b5632a 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_actions.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_actions.py @@ -9,6 +9,10 @@ @patch.multiple(sa.DiscreteSatAction, __abstractmethods__=set()) @patch("bsk_rl.envs.general_satellite_tasking.scenario.satellites.Satellite.__init__") +@patch( + "bsk_rl.envs.general_satellite_tasking.scenario.satellites.Satellite.reset_pre_sim", + MagicMock, +) class TestDiscreteSatAction: def test_init(self, sat_init): sa.DiscreteSatAction() @@ -51,6 +55,7 @@ def test_add_multiple_actions(self, sat_init, n_actions): ) def test_set_action(self, sat_init, disable_timed): sat = sa.DiscreteSatAction() + sat.reset_pre_sim() sat.action_list = [MagicMock(return_value="act_key")] sat.set_action(0) disable_timed.assert_called_once() @@ -62,9 +67,19 @@ def test_action_space(self, sat_init): sat.action_list = [0, 1, 2] assert sat.action_space == spaces.Discrete(3) + def test_reset_pre_sim(self, sat_init): + sat = sa.DiscreteSatAction() + sat.prev_action_key = "some_action" + sat.reset_pre_sim() + assert sat.prev_action_key is None + @patch.multiple(sa.DiscreteSatAction, __abstractmethods__=set()) @patch("bsk_rl.envs.general_satellite_tasking.scenario.satellites.Satellite.__init__") +@patch( + "bsk_rl.envs.general_satellite_tasking.scenario.satellites.Satellite.reset_pre_sim", + MagicMock, +) class TestFSWAction: def test_init(self, sat_init): FSWAct = sa.fsw_action_gen("cool_action") @@ -75,6 +90,7 @@ def test_init(self, sat_init): def make_action_sat(self): FSWAct = sa.fsw_action_gen("cool_action", 60.0) sat = FSWAct() + sat.reset_pre_sim() sat.fsw = MagicMock(cool_action=MagicMock()) sat.log_info = MagicMock() sat._disable_timed_terminal_event = MagicMock() @@ -93,6 +109,7 @@ def test_act(self, sat_init): def make_action_sat_configured(self): FSWAct = sa.fsw_action_gen("cool_action", 59.0).configure(action_duration=60.0) sat = FSWAct() + sat.reset_pre_sim() sat.fsw = MagicMock(cool_action=MagicMock()) sat.log_info = MagicMock() sat._disable_timed_terminal_event = MagicMock() @@ -158,6 +175,7 @@ def test_image_retask(self, sat_init, target): @pytest.mark.parametrize("target", [1, "target_1", MockTarget()]) def test_set_action(self, sat_init, discrete_set, target): sat = sa.ImagingActions() + sat.prev_action_key = None sat._disable_image_event = MagicMock() sat.image = MagicMock() sat.set_action(target)