Skip to content

Commit

Permalink
Watch out for truncated returns.
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewGerber committed Aug 11, 2024
1 parent 69cd486 commit b67d482
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions src/rlai/policy_gradient/monte_carlo/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,42 +297,34 @@ def improve(
agent.v_S.plot(pdf)

plt.figure(figsize=(10, 10))
time_steps = [step.t for step in steps]

steps_t = [step.t for step in steps]

non_truncated_steps = [step for step in steps if step.returns is not None]
non_truncated_steps_t = [step.t for step in non_truncated_steps]

# plot rewards and returns
plt.plot(
time_steps,
steps_t,
[step.reward.r for step in steps],
color='red',
label='Reward: r(t)'
)
plt.plot(
time_steps,
[
step.returns.return_value
for step in steps
if step.returns is not None
],
non_truncated_steps_t,
[step.returns.return_value for step in non_truncated_steps], # type: ignore[union-attr]
color='green',
label='Return: g(t)'
)
plt.plot(
time_steps,
[
step.returns.baseline_return_value
for step in steps
if step.returns is not None
],
non_truncated_steps_t,
[step.returns.baseline_return_value for step in non_truncated_steps], # type: ignore[union-attr]
color='violet',
label='Value: v(t)',
)
plt.plot(
time_steps,
[
step.returns.target
for step in steps
if step.returns is not None
],
non_truncated_steps_t,
[step.returns.target for step in non_truncated_steps], # type: ignore[union-attr]
color='orange',
label='Target: g(t) - v(t)'
)
Expand All @@ -348,7 +340,7 @@ def improve(
# plot gamma (discount) in a twin-x axes
gamma_axe: plt.Axes = plt.twinx() # type: ignore[assignment]
gamma_axe.plot(
time_steps,
steps_t,
[step.gamma for step in steps],
color='blue',
label='gamma(t)'
Expand Down

0 comments on commit b67d482

Please sign in to comment.