diff --git a/readme.md b/README.md
similarity index 88%
rename from readme.md
rename to README.md
index 508032a..e647dde 100644
--- a/readme.md
+++ b/README.md
@@ -2,7 +2,8 @@
## Overview
This repository contains code reinforcement learning code for solving the
-Udacity Deep reinforcement Learning projects.
+Udacity Deep Reinforcement Learning projects. The code has been refactored so that the implemented
+algorithms can be applied to new environments with minimal setup.
## Prerequisites
- Anaconda
@@ -14,6 +15,9 @@ All models are developed in Pytorch
Recreate the Anaconda environment with:
`conda env create -f environment.yml`
+Activate the conda environment with:
+`conda activate drl_toolbox`
+
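+As a quick sanity check that the environment is ready (assuming `environment.yml` installs PyTorch, which the models require), run:
+`python -c "import torch; print(torch.__version__, torch.cuda.is_available())"`
+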
## Repository Structure
The code is organized as follows:
@@ -63,21 +67,56 @@ The code is organized as follows:
- [Task/environment Details](tasks/banana_collector/TASK_DETAILS.md)
- [REPORT.md](tasks/banana_collector/solutions/ray_tracing_banana/REPORT.md)
- [RESULTS.pdf](tasks/banana_collector/solutions/ray_tracing_banana/RESULTS.pdf)
- - [Train](tasks/banana_collector/solutions/ray_tracing_banana/banana_solution_train.py)
- - [Eval](tasks/banana_collector/solutions/ray_tracing_banana/banana_solution_eval.py)
-- Visual (pixel) implementation
+ - [Train DQN](tasks/banana_collector/solutions/ray_tracing_banana/banana_solution_train.py)
+ - [Eval DQN](tasks/banana_collector/solutions/ray_tracing_banana/banana_solution_eval.py)
+
+
+
+
+- Visual Banana (pixel) implementation
- [Task/environment Details](tasks/banana_collector/TASK_DETAILS.md)
- [REPORT.md](tasks/banana_collector/solutions/pixel_banana/REPORT.md)
- - [Train](tasks/banana_collector/solutions/pixel_banana/banana_visual_solution_train.py)
- - [Eval](tasks/banana_collector/solutions/pixel_banana/banana_visual_solution_train.py)
-- Reacher continuous control (20-agent) implementation
- - [Task/environment Details](tasks/reacher_continuous_control/TASK_DETAILS.md)
- - [REPORT.md](tasks/reacher_continuous_control/solutions/ddpg/REPORT.md)
- - [Train DDPG](tasks/reacher_continuous_control/solutions/ddpg/train_ddpg_baseline.py)
- - [Eval DDPG](tasks/reacher_continuous_control/solutions/ddpg/eval_ddpg_baseline.py)
- - [Train TD3](tasks/reacher_continuous_control/solutions/ddpg/train_ddpg_baseline.py)
- - [Eval TD3](tasks/reacher_continuous_control/solutions/ddpg/eval_td3_baseline.py)
+ - [Train DQN](tasks/banana_collector/solutions/pixel_banana/banana_visual_solution_train.py)
+ - [Eval DQN](tasks/banana_collector/solutions/pixel_banana/banana_visual_solution_train.py)
+
+
+
+
+- Reacher (20 homogeneous agents) implementation
+ - [Task/environment Details](tasks/reacher/TASK_DETAILS.md)
+ - [REPORT.md](tasks/reacher/solutions/ddpg/REPORT.md)
+ - [Train TD3](tasks/reacher/solutions/ddpg/train_td3_per.py)
+ - [Eval TD3](tasks/reacher/solutions/ddpg/eval_td3_per.py)
+
+
+
+- Crawler (12 homogeneous agents) implementation
+ - [Task/environment Details](tasks/crawler/TASK_DETAILS.md)
+ - [REPORT.md](tasks/crawler/solutions/ppo/REPORT.md)
+  - [Train PPO](tasks/crawler/solutions/ppo/train_ppo.py)
+  - [Eval PPO](tasks/crawler/solutions/ppo/eval_ppo.py)
+
+
+
+
+- Soccer (multi-agent) implementation
+ - [Task/environment Details](tasks/soccer/TASK_DETAILS.md)
+ - [REPORT.md](tasks/soccer/REPORT.md)
+ - [Train MAPPO](tasks/soccer/solutions/mappo/train_mappo.py)
+ - [Eval MAPPO](tasks/soccer/solutions/mappo/eval_mappo.py)
+
+
+
+
+- Tennis (multi-agent) implementation
+ - [Task/environment Details](tasks/tennis/TASK_DETAILS.md)
+ - [REPORT.md](tasks/tennis/REPORT.md)
+ - [Train MAPPO](tasks/tennis/solutions/mappo/train_mappo.py)
+ - [Eval MAPPO](tasks/tennis/solutions/mappo/eval_mappo.py)
+
+
+
## Agent Implementations and explanation
Currently only the [Deep Q-Network](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) algorithm is implemented, along with
@@ -135,7 +174,7 @@ Below, we discuss the algorithm at a high level, along with the implemented exte
YtDouble-DQN ≡ Rt + γQ(st+1, argmaxa Q(st+1, a; θt) θt-)
- See the `compute_errors` method of the [Base Policy](agents/policies/base.py) class for code implementation
+ See the `compute_errors` method of the [Base Policy](agents/policies/base_policy.py) class for code implementation
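+
+  For illustration, a minimal sketch of the Double-DQN target computation is given below. This is not the
+  repository's `compute_errors` implementation; the `online_net`/`target_net` names and tensor shapes are
+  assumptions made for the example.
+
+  ```
+  import torch
+
+  def double_dqn_target(online_net, target_net, rewards, next_states, dones, gamma=0.99):
+      """Double-DQN target: the online network selects the action, the target network evaluates it."""
+      with torch.no_grad():
+          # a* = argmax_a Q(s_{t+1}, a; theta_t)   (online network selects)
+          best_actions = online_net(next_states).argmax(dim=1, keepdim=True)
+          # Q(s_{t+1}, a*; theta_t^-)              (target network evaluates)
+          next_q = target_net(next_states).gather(1, best_actions).squeeze(1)
+          return rewards + gamma * next_q * (1.0 - dones)
+  ```
+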
###### [Prioritized Experience Replay (PER)](https://arxiv.org/abs/1511.05952)
Rather than performing learning updates on experiences as they are sampled from the environment (i.e. sequentially through time), the DQN
@@ -169,7 +208,7 @@ Below, we discuss the algorithm at a high level, along with the implemented exte
A SumTree data structure is implemented to perform weighted sampling efficiently. See the implementation
of the [PER buffer](agents/memory/prioritized_memory.py), and the [SumTree](tools/data_structures/sumtree.py).
- See the `compute_errors` method of the [Base Policy](agents/policies/base.py) class shows where importance weights
+  The `compute_errors` method of the [Base Policy](agents/policies/base_policy.py) class shows where importance weights
are applied to scale the gradients, and `step` method of the [DQNAgent](agents/dqn_agent.py) contains the implementation
of updating the priorities.
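+
+  As a reference for the two quantities discussed above, a minimal sketch of how PER priorities and
+  importance-sampling weights are typically computed is shown below. The function names and the default
+  values of `alpha`, `beta`, and `eps` are illustrative and are not taken from this repository.
+
+  ```
+  import numpy as np
+
+  def updated_priorities(td_errors, alpha=0.6, eps=1e-6):
+      """Priorities from TD errors: p_i = (|delta_i| + eps)^alpha."""
+      return (np.abs(td_errors) + eps) ** alpha
+
+  def importance_weights(priorities, total_priority, buffer_size, beta=0.4):
+      """Importance-sampling weights w_i = (N * P(i))^(-beta), normalized by the max weight."""
+      probs = np.asarray(priorities) / total_priority      # P(i) = p_i / sum_k p_k
+      weights = (buffer_size * probs) ** (-beta)
+      return weights / weights.max()
+  ```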
@@ -199,7 +238,7 @@ Below, we discuss the algorithm at a high level, along with the implemented exte
###### [Distributional (Categorical) DQN network](https://arxiv.org/abs/1707.06887)
The categorical DQN algorithm attempts to model the `return distribution` for an action, rather than the
`expected return`, thus modelling the distribution of Q(s, a). The categorical DQN is implemented in
- the `get_output` method of [dqn](agents/models/dqn.py), with corresponding [categorical policy](agents/policies/categorical.py)
+ the `get_output` method of [dqn](agents/models/dqn.py), with corresponding [categorical policy](agents/policies/categorical_policy.py)
which is responsible for computing the errors between the target and online network distributions. Please refer
to the paper for theoretical details and to this [reference implementation](https://github.com/higgsfield/RL-Adventure/blob/master/7.rainbow%20dqn.ipynb),
from which the code is adapted from.
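+
+  For concreteness, a sketch of the categorical projection step performed by such implementations is shown
+  below (adapted conceptually from the reference implementation linked above). The support bounds, the 51-atom
+  resolution, and the tensor shapes are illustrative defaults, not necessarily the values used in this repository.
+
+  ```
+  import torch
+
+  def project_distribution(next_probs, rewards, dones, gamma=0.99,
+                           v_min=-10.0, v_max=10.0, num_atoms=51):
+      """Project the target distribution of R + gamma*z onto the fixed support {z_i} (C51).
+      Shapes: next_probs (B, num_atoms) for the greedy next action, rewards/dones (B,)."""
+      batch_size = next_probs.size(0)
+      delta_z = (v_max - v_min) / (num_atoms - 1)
+      support = torch.linspace(v_min, v_max, num_atoms)
+
+      # Bellman update of every atom, clamped to the support range
+      tz = (rewards.unsqueeze(1) + gamma * (1.0 - dones).unsqueeze(1) * support).clamp(v_min, v_max)
+      b = (tz - v_min) / delta_z                 # fractional index of each projected atom
+      lower, upper = b.floor().long(), b.ceil().long()
+      # When b lands exactly on an atom, shift the pair so no probability mass is lost
+      lower[(upper > 0) & (lower == upper)] -= 1
+      upper[(lower < num_atoms - 1) & (lower == upper)] += 1
+
+      projected = torch.zeros_like(next_probs)
+      offset = (torch.arange(batch_size) * num_atoms).unsqueeze(1)
+      projected.view(-1).index_add_(0, (lower + offset).view(-1),
+                                    (next_probs * (upper.float() - b)).view(-1))
+      projected.view(-1).index_add_(0, (upper + offset).view(-1),
+                                    (next_probs * (b - lower.float())).view(-1))
+      return projected
+  ```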
@@ -374,7 +413,7 @@ Below, we discuss the algorithm at a high level, along with the implemented exte
where c1 and c2 are constants.
###### MAPPO
- The PPO algorithm above can be extended to the multi-agent scenario in an analagous way as DDPG to MADDPG. This
+  The PPO algorithm above can be extended to the multi-agent scenario in a manner analogous to the extension of DDPG to MADDPG. This
involves passing the state and actions of all other agents in the environment to the (joint_state, joint_actions)
to the Critic of each agent during training, whos value estimate will assist in guiding the learning of the policy (Actor)
network. During evaluation, only the policy network is used, and the agents are not provided any external information regarding
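+
+  As a sketch of this idea, a centralized critic can simply be fed the concatenation of all agents'
+  observations and actions during training, while each actor still acts only on its own observation.
+  The class below is illustrative only; its layer sizes and constructor arguments are assumptions and do
+  not correspond to the repository's MAPPO modules.
+
+  ```
+  import torch
+  import torch.nn as nn
+
+  class CentralizedCritic(nn.Module):
+      """Value network conditioned on the joint state and joint actions of all agents."""
+      def __init__(self, obs_dim, action_dim, n_agents, hidden=128):
+          super().__init__()
+          joint_dim = n_agents * (obs_dim + action_dim)
+          self.value = nn.Sequential(
+              nn.Linear(joint_dim, hidden), nn.ReLU(),
+              nn.Linear(hidden, hidden), nn.ReLU(),
+              nn.Linear(hidden, 1),
+          )
+
+      def forward(self, joint_states, joint_actions):
+          # joint_states: (B, n_agents * obs_dim), joint_actions: (B, n_agents * action_dim)
+          return self.value(torch.cat([joint_states, joint_actions], dim=-1))
+  ```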
diff --git a/agents/models/ppo.py b/agents/models/ppo.py
index 985943d..f6b01c3 100644
--- a/agents/models/ppo.py
+++ b/agents/models/ppo.py
@@ -2,7 +2,6 @@
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
-import numpy as np
from tools.misc import set_seed
@@ -25,15 +24,6 @@ def step_episode(self):
pass
def forward(self, state, action=None, scale=1, min_std=0.05, *args, **kargs):
- """Build Policy.
-
- Returns
- ======
- action (Tensor): predicted action or inputed action
- log_prob (Tensor): log probability of current action distribution
- ent (Tensor): entropy of current action distribution
- value (Tensor): estimate value function
- """
assert min_std >= 0 and scale >= 0
if self.continuous_actions:
action_mean = self.actor(state)
@@ -73,15 +63,6 @@ def step_episode(self):
def forward(self, agent_state: torch.FloatTensor, other_agent_states: torch.FloatTensor,
other_agent_actions: Optional[torch.FloatTensor] = None, action: Optional[torch.FloatTensor] = None, min_std=0.05, scale=1,):
- """Build Policy.
-
- Returns
- ======
- action (Tensor): predicted action or inputed action
- log_prob (Tensor): log probability of current action distribution
- ent (Tensor): entropy of current action distribution
- value (Tensor): estimate value function
- """
assert min_std > 0 and scale >= 0, (min_std, scale)
if self.continuous_actions:
diff --git a/tasks/banana_collector/solutions/pixel_banana/REPORT.md b/tasks/banana_collector/solutions/pixel_banana/REPORT.md
index f293c68..640ec3a 100644
--- a/tasks/banana_collector/solutions/pixel_banana/REPORT.md
+++ b/tasks/banana_collector/solutions/pixel_banana/REPORT.md
@@ -127,4 +127,4 @@ Value channel:
![value][image5]
Basic experiments were performed with the above dimensionality techniques, which can be found in [tools](../../../../tools/image_utils.py),
-however the network has signfiicant difficulty learning from them (at least in the constraints imposed by the memory issue).
+however the network has significant difficulty learning from them (at least in the constraints imposed by the memory issue).
diff --git a/tasks/banana_collector/solutions/pixel_banana/solution_checkpoint/trainined_visual_banana_agent.gif b/tasks/banana_collector/solutions/pixel_banana/solution_checkpoint/trainined_visual_banana_agent.gif
new file mode 100644
index 0000000..e3ccac7
Binary files /dev/null and b/tasks/banana_collector/solutions/pixel_banana/solution_checkpoint/trainined_visual_banana_agent.gif differ
diff --git a/tasks/crawler/TASK_DETAILS.md b/tasks/crawler/TASK_DETAILS.md
index 78949b9..746ff31 100644
--- a/tasks/crawler/TASK_DETAILS.md
+++ b/tasks/crawler/TASK_DETAILS.md
@@ -1,3 +1,4 @@
+[image2]: https://user-images.githubusercontent.com/10624937/43851646-d899bf20-9b00-11e8-858c-29b5c2c94ccc.png "Crawler"
### (Optional) Challenge: Crawler Environment
diff --git a/tasks/crawler/solutions/ppo/REPORT.md b/tasks/crawler/solutions/ppo/REPORT.md
new file mode 100644
index 0000000..9c9d6ef
--- /dev/null
+++ b/tasks/crawler/solutions/ppo/REPORT.md
@@ -0,0 +1,82 @@
+[scores]: solution_checkpoint/ppo_training_scores.png "PPO Baseline Results"
+
+# Crawler
+Please see the [repository overview](../../../../README.md) as well as the [task description](../../TASK_DETAILS.md)
+before reading this report. The theoretical details of the utilized algorithms can be found in the [repository overview](../../../../README.md).
+
+In this task there are 12 crawler agents whose goal is to reach a static location in the environment as quickly as possible
+(i.e., minimizing falling and maximizing speed).
+
+
+
+# Solution Overview
+
+The solution discussed in this report relies on the PPO algorithm. All 12 agents share the same PPO brain
+(actor-critic and optimizer) and the same shared trajectory buffer. During training, the agents perform
+batch learning by sampling mini-batches from this shared buffer. After a small number of learning epochs, the experience
+samples are discarded.
+
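+As a rough sketch of a single learning step on such a shared buffer (the `batch` fields and the
+actor-critic call signature below are assumptions for illustration, not the repository's exact API):
+
+```
+import torch
+
+def ppo_update(actor_critic, optimizer, batch, clip_eps=0.2, value_coef=0.5, entropy_coef=0.01):
+    """One PPO mini-batch update using the clipped surrogate objective."""
+    states, actions, old_log_probs, returns, advantages = batch
+
+    # Re-evaluate the sampled actions under the current policy
+    _, log_probs, entropy, values = actor_critic(states, action=actions)
+    ratio = (log_probs - old_log_probs).exp()
+
+    # Clipped surrogate policy loss
+    surr1 = ratio * advantages
+    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
+    policy_loss = -torch.min(surr1, surr2).mean()
+
+    # Value loss and entropy bonus combined into the total loss
+    value_loss = (returns - values.squeeze(-1)).pow(2).mean()
+    loss = policy_loss + value_coef * value_loss - entropy_coef * entropy.mean()
+
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+```
+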
+The actor-critic architecture has the following form:
+
+```
+PPO_Actor_Critic(
+ (actor): MLP(
+ (mlp_layers): Sequential(
+ (0): BatchNorm1d(129, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (1): Linear(in_features=129, out_features=128, bias=True)
+ (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (3): LeakyReLU(negative_slope=True)
+ (4): Linear(in_features=128, out_features=128, bias=True)
+ (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (6): LeakyReLU(negative_slope=True)
+ (7): Linear(in_features=128, out_features=20, bias=True)
+ (8): Tanh()
+ )
+ )
+ (critic): MLP(
+ (mlp_layers): Sequential(
+ (0): BatchNorm1d(129, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (1): Linear(in_features=129, out_features=128, bias=True)
+ (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (3): LeakyReLU(negative_slope=True)
+ (4): Linear(in_features=128, out_features=128, bias=True)
+ (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (6): LeakyReLU(negative_slope=True)
+ (7): Linear(in_features=128, out_features=1, bias=True)
+ (8): Tanh()
+ )
+ )
+)
+```
+The model hyper-parameters are given below:
+
+```
+NUM_EPISODES = 3000
+SEED = 8
+MAX_T = 2000
+WEIGHT_DECAY = 1e-4
+EPSILON = 1e-5 # epsilon of Adam
+LR = 1e-4 # learning rate of the actor-critic
+BATCH_SIZE = 1024
+DROPOUT = None
+BATCHNORM = True
+SOLVE_SCORE = 1600
+```
+
+## Results
+
+Below we show the plot of mean episode scores (across all agents) versus episode number.
+
+![Training scores][scores]
+
+The environment was solved (mean reward of >=1600) after about 320 episodes.
+Training took roughly 1.3 hours.
+
+## Discussion
+The PPO algorithm demonstrated good stability and convergence, and was experimentally shown to be rather robust to changes
+in hyperparameters.
+
+## Ideas for Future Work
+The PPO algorithm demonstrated quick convergence on this task; however, its sample efficiency leaves much to be desired.
+To increase sample efficiency, memory replay methods such as [Hindsight Experience Replay (HER)](https://papers.nips.cc/paper/7090-hindsight-experience-replay.pdf) could be implemented,
+which help the agent learn more effectively from sparse rewards.
diff --git a/tasks/crawler/solutions/ppo/train_ppo.py b/tasks/crawler/solutions/ppo/train_ppo.py
index 571cce4..b282206 100644
--- a/tasks/crawler/solutions/ppo/train_ppo.py
+++ b/tasks/crawler/solutions/ppo/train_ppo.py
@@ -18,7 +18,7 @@
MAX_T = 2000
WEIGHT_DECAY = 1e-4
EPSILON = 1e-5 # epsilon of Adam
-LR = 1e-4 # learning rate of the actor
+LR = 1e-4 # learning rate of the actor-critic
BATCH_SIZE = 1024
DROPOUT = None
BATCHNORM = True
diff --git a/tasks/reacher/TASK_DETAILS.md b/tasks/reacher/TASK_DETAILS.md
index 6c73bb7..18b6b70 100644
--- a/tasks/reacher/TASK_DETAILS.md
+++ b/tasks/reacher/TASK_DETAILS.md
@@ -1,7 +1,4 @@
-[//]: # (Image References)
-
[image1]: https://user-images.githubusercontent.com/10624937/43851024-320ba930-9aff-11e8-8493-ee547c6af349.gif "Trained Agent"
-[image2]: https://user-images.githubusercontent.com/10624937/43851646-d899bf20-9b00-11e8-858c-29b5c2c94ccc.png "Crawler"
# Project 2: Continuous Control
diff --git a/tasks/reacher/solutions/ddpg/REPORT.md b/tasks/reacher/solutions/ddpg/REPORT.md
index cacc1c0..bbcb963 100644
--- a/tasks/reacher/solutions/ddpg/REPORT.md
+++ b/tasks/reacher/solutions/ddpg/REPORT.md
@@ -1,8 +1,7 @@
[image1]: https://user-images.githubusercontent.com/10624937/43851024-320ba930-9aff-11e8-8493-ee547c6af349.gif "Trained Agent"
-[image2]: resources/ddpg_baseline.png "DDPG Baseline Results"
[image3]: resources/per_td3_baseline.png "TD3 PER Baseline Results"
-# Reacher (Continuous Control)
+# Reacher
Please see the [repository overview](../../../../README.md) as well as the [task description](../../TASK_DETAILS.md)
before reading this report. The theoretical details of the utilized algorithms can be found in the [repository overview](../../../../README.md).
diff --git a/tasks/reacher/solutions/ddpg/eval_td3_baseline.py b/tasks/reacher/solutions/ddpg/eval_td3_per.py
similarity index 86%
rename from tasks/reacher/solutions/ddpg/eval_td3_baseline.py
rename to tasks/reacher/solutions/ddpg/eval_td3_per.py
index 80ab662..650554d 100644
--- a/tasks/reacher/solutions/ddpg/eval_td3_baseline.py
+++ b/tasks/reacher/solutions/ddpg/eval_td3_per.py
@@ -2,7 +2,7 @@
import torch
from tasks.reacher.solutions.utils import get_simulator, BRAIN_NAME
from tasks.reacher.solutions.ddpg import SOLUTIONS_CHECKPOINT_DIR
-from tasks.reacher.solutions.ddpg.train_td3_baseline import get_solution_brain_set, MAX_T
+from tasks.reacher.solutions.ddpg.train_td3_per import get_solution_brain_set, MAX_T
SAVE_TAG = 'per_td3'
ACTOR_CHECKPOINT = os.path.join(SOLUTIONS_CHECKPOINT_DIR, f'{SAVE_TAG}_actor_checkpoint.pth')
diff --git a/tasks/reacher/solutions/ddpg/resources/ddpg_baseline.png b/tasks/reacher/solutions/ddpg/resources/ddpg_baseline.png
deleted file mode 100644
index 662cdc2..0000000
Binary files a/tasks/reacher/solutions/ddpg/resources/ddpg_baseline.png and /dev/null differ
diff --git a/tasks/reacher/solutions/ddpg/resources/per_td3_baseline.png b/tasks/reacher/solutions/ddpg/resources/per_td3_baseline.png
deleted file mode 100644
index bdac2be..0000000
Binary files a/tasks/reacher/solutions/ddpg/resources/per_td3_baseline.png and /dev/null differ
diff --git a/tasks/reacher/solutions/ddpg/train_td3_baseline.py b/tasks/reacher/solutions/ddpg/train_td3_per.py
similarity index 100%
rename from tasks/reacher/solutions/ddpg/train_td3_baseline.py
rename to tasks/reacher/solutions/ddpg/train_td3_per.py
diff --git a/tasks/soccer/REPORT.md b/tasks/soccer/REPORT.md
index 4251af4..f343b3a 100644
--- a/tasks/soccer/REPORT.md
+++ b/tasks/soccer/REPORT.md
@@ -1,15 +1,9 @@
[trained_soccer]:https://user-images.githubusercontent.com/10624937/42135622-e55fb586-7d12-11e8-8a54-3c31da15a90a.gif "Soccer"
[mappo_results_image]: solutions/mappo/solution_checkpoint/mappo_100_consecutive_wins_training_scores.png "MAPPO Training"
-### Multi Agent Soccer Environment
-![Soccer][trained_soccer]
-
-
-
-
# Soccer MAPPO/MATD3 Introduction
-Please see the [repository overview](../../../../README.md) as well as the [task description](../../TASK_DETAILS.md)
-before reading this report. The theoretical details of the utilized algorithms can be found in the [repository overview](../../../../README.md).
+Please see the [repository overview](../../README.md) as well as the [task description](./TASK_DETAILS.md)
+before reading this report. The theoretical details of the utilized algorithms can be found in the [repository overview](../../README.md).
In this environment, two teams (each with a Striker/Goalie agent) compete against each other in the game of soccer. The agents can move laterally
and vertically, and the strikers have the additional action of rotating left/right, resulting in 4 and 6 discrete actions for
diff --git a/tasks/soccer/setup_linux.sh b/tasks/soccer/setup_linux.sh
index aeafa5c..1c6def9 100644
--- a/tasks/soccer/setup_linux.sh
+++ b/tasks/soccer/setup_linux.sh
@@ -1,9 +1,9 @@
-# Execute this script from the /tasks/reacher_continuous_control directory
+# Execute this script from the /tasks/soccer directory
# bash ./setup_linux.sh
mkdir -p environments
-# Download the reacher environment
+# Download the soccer environment
wget https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Linux.zip --no-check-certificate
unzip Soccer_Linux.zip && mv Soccer_Linux environments/ && rm Soccer_Linux.zip
diff --git a/tasks/tennis/REPORT.md b/tasks/tennis/REPORT.md
index da5bc9e..93c2433 100644
--- a/tasks/tennis/REPORT.md
+++ b/tasks/tennis/REPORT.md
@@ -1,10 +1,9 @@
-[trained_tennis_gif]: https://user-images.githubusercontent.com/10624937/42135623-e770e354-7d12-11e8-998d-29fc74429ca2.gif "Trained Agent"
[mappo_results_image]: solutions/mappo/solution_checkpoint/mappo_training_scores.png "MAPPO Training"
[matd3_results_image]: solutions/maddpg/solution_checkpoint/independent_madtd3_training_scores.png "MATD3 Training"
# Tennis MAPPO/MATD3 Introduction
-Please see the [repository overview](../../../../README.md) as well as the [task description](./TASK_DETAILS.md)
-before reading this report. The theoretical details of the utilized algorithms can be found in the [repository overview](../../../../README.md).
+Please see the [repository overview](../../README.md) as well as the [task description](./TASK_DETAILS.md)
+before reading this report. The theoretical details of the utilized algorithms can be found in the [repository overview](../../README.md).
In this environment, two agents control rackets to bounce a ball over a net. If an agent hits the ball over the net, it receives a reward of +0.1. If an agent lets a ball hit the ground or hits the ball out of bounds, it receives a reward of -0.01. Thus, the goal of each agent is to keep the ball in play.
@@ -21,7 +20,7 @@ The unity environment consists of 2 agents which have separate brains (models/op
but can observe the states and actions of the other agents and use this information during training time.
-![Trained Agent][trained_tennis_gif]
+
# Solution Overview
@@ -171,13 +170,13 @@ a score of >0.5 in ~ 2700 episodes (15 minutes), and a score of > 1 in about 320
![Training MATD3 Agent][matd3_results_image]
-##### Discussion
+## Discussion
The MAPPO algorithm converged *significantly* faster than the MATD3 algorithm, achieving a score of >1 about 33x faster
than the MAPPO algorithm (20 minutes vs. 11 hours). It should be noted, though, that hyper-parameter tuning
(especially on the MATD3 algorithm) was not conducted due to the long training duration. Overall, this result demonstrates
the robustness of the PPO algorithm to a wide range of tasks.
-The MAPPO algorithm, beign on-policy, is shown to be relatively sample inefficient compared to off-policy algorithms such as
+The MAPPO algorithm, being on-policy, is shown to be relatively sample inefficient compared to off-policy algorithms such as
MATD3, where MAPPO achieved a score of > 1 after 3200 episodes compared to 800 episodes by MATD3. The MATD3 algorithm takes
advantage of prioritized experience replay (PER) to sample experience based on the amount of information the experience provides, while
the MAPPO algorithm has no such intelligent memory buffer and simply discards trajectories of experience after a few learning epochs.
diff --git a/tasks/tennis/TASK_DETAILS.md b/tasks/tennis/TASK_DETAILS.md
index 0689e34..689ddba 100644
--- a/tasks/tennis/TASK_DETAILS.md
+++ b/tasks/tennis/TASK_DETAILS.md
@@ -1,12 +1,8 @@
-[//]: # (Image References)
-
-[image1]: https://user-images.githubusercontent.com/10624937/42135623-e770e354-7d12-11e8-998d-29fc74429ca2.gif "Trained Agent"
-
# Multi Agent Tenis
### Introduction
-![Trained Agent][image1]
+
In this environment, two agents control rackets to bounce a ball over a net. If an agent hits the ball over the net, it receives a reward of +0.1. If an agent lets a ball hit the ground or hits the ball out of bounds, it receives a reward of -0.01. Thus, the goal of each agent is to keep the ball in play.
diff --git a/tasks/tennis/solutions/mappo/eval_mappo.py b/tasks/tennis/solutions/mappo/eval_mappo.py
index c02f3cc..720bd58 100644
--- a/tasks/tennis/solutions/mappo/eval_mappo.py
+++ b/tasks/tennis/solutions/mappo/eval_mappo.py
@@ -18,7 +18,6 @@
# Load the agents
brain_set['TennisBrain'].agents[0].target_actor_critic.load_state_dict(torch.load(SAVED_AGENT_0_FP))
- brain_set['TennisBrain'].agents[1].target_actor_critic.load_state_dict(torch.load(SAVED_AGENT_1_FP))
brain_set, average_score = simulator.evaluate(
brain_set,