From b67d48204339751a6e89bb4a088be046d5567f36 Mon Sep 17 00:00:00 2001 From: MatthewGerber Date: Sun, 11 Aug 2024 11:22:47 -0400 Subject: [PATCH] Watch out for truncated returns. --- .../policy_gradient/monte_carlo/reinforce.py | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/rlai/policy_gradient/monte_carlo/reinforce.py b/src/rlai/policy_gradient/monte_carlo/reinforce.py index 4bd9f7b..1f0a077 100644 --- a/src/rlai/policy_gradient/monte_carlo/reinforce.py +++ b/src/rlai/policy_gradient/monte_carlo/reinforce.py @@ -297,42 +297,34 @@ def improve( agent.v_S.plot(pdf) plt.figure(figsize=(10, 10)) - time_steps = [step.t for step in steps] + + steps_t = [step.t for step in steps] + + non_truncated_steps = [step for step in steps if step.returns is not None] + non_truncated_steps_t = [step.t for step in non_truncated_steps] # plot rewards and returns plt.plot( - time_steps, + steps_t, [step.reward.r for step in steps], color='red', label='Reward: r(t)' ) plt.plot( - time_steps, - [ - step.returns.return_value - for step in steps - if step.returns is not None - ], + non_truncated_steps_t, + [step.returns.return_value for step in non_truncated_steps], # type: ignore[union-attr] color='green', label='Return: g(t)' ) plt.plot( - time_steps, - [ - step.returns.baseline_return_value - for step in steps - if step.returns is not None - ], + non_truncated_steps_t, + [step.returns.baseline_return_value for step in non_truncated_steps], # type: ignore[union-attr] color='violet', label='Value: v(t)', ) plt.plot( - time_steps, - [ - step.returns.target - for step in steps - if step.returns is not None - ], + non_truncated_steps_t, + [step.returns.target for step in non_truncated_steps], # type: ignore[union-attr] color='orange', label='Target: g(t) - v(t)' ) @@ -348,7 +340,7 @@ def improve( # plot gamma (discount) in a twin-x axes gamma_axe: plt.Axes = plt.twinx() # type: ignore[assignment] gamma_axe.plot( - time_steps, + steps_t, [step.gamma for step in steps], color='blue', label='gamma(t)'