From c4ab0ca63193c86412b67702243808aac69de133 Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Thu, 10 Oct 2024 10:09:36 -0600 Subject: [PATCH] Issue #0: Clean up extraneous logging --- src/bsk_rl/data/base.py | 3 ++- src/bsk_rl/gym.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/bsk_rl/data/base.py b/src/bsk_rl/data/base.py index 1910ff7..d9e9d60 100644 --- a/src/bsk_rl/data/base.py +++ b/src/bsk_rl/data/base.py @@ -192,7 +192,8 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]: for new_data in new_data_dict.values(): self.data += new_data - logger.info(f"Data reward: {reward}") + nonzero_reward = {k: v for k, v in reward.items() if v != 0} + logger.info(f"Data reward: {nonzero_reward}") return reward diff --git a/src/bsk_rl/gym.py b/src/bsk_rl/gym.py index 06831ff..a2e8881 100644 --- a/src/bsk_rl/gym.py +++ b/src/bsk_rl/gym.py @@ -32,7 +32,6 @@ class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]): - def __init__( self, satellites: Union[Satellite, list[Satellite]], @@ -314,7 +313,6 @@ def _get_obs(self) -> MultiSatObs: tuple: Joint observation """ if self.generate_obs_retasking_only: - return tuple( ( satellite.get_obs() @@ -698,9 +696,14 @@ def step( terminated = self._get_terminated() truncated = self._get_truncated() info = self._get_info() - logger.info(f"Step reward: {reward}") - logger.info(f"Episode terminated: {terminated}") - logger.info(f"Episode truncated: {truncated}") + nonzero_reward = {k: v for k, v in reward.items() if v != 0} + logger.info(f"Step reward: {nonzero_reward}") + if any(terminated.values()): + terminated_true = [k for k, v in terminated.items() if v] + logger.info(f"Episode terminated: {terminated_true}") + if any(truncated.values()): + truncated_true = [k for k, v in truncated.items() if v] + logger.info(f"Episode truncated: {truncated_true}") logger.debug(f"Step info: {info}") logger.debug(f"Step observation: {observation}") return observation, reward, terminated, truncated, info