This repo reproduces and explores the results from DeepMind's paper on Population Based Training of Neural Networks.
PBT is an optimization algorithm that maximizes the performance of a network by jointly optimizing a population of models and their hyperparameters. It determines a schedule of hyperparameter settings using an evolutionary strategy of exploration and exploitation, a much more powerful method than using a fixed set of hyperparameters for the entire training, or than grid search and hand-tuning, which are time-consuming and difficult.
It is recommended to run from a virtual environment to ensure all dependencies are met:

```
virtualenv -p python3 pbt_env
source pbt_env/bin/activate   # or activate.csh for csh users
pip3 install -r requirements.txt
```
The toy example reproduces Fig. 2 of the paper (p. 6). The idea is to maximize an unknown quadratic function `Q = 1.2 - w1^2 - w2^2`, given a surrogate function `Q_hat = 1.2 - h1*w1^2 - h2*w2^2`, where `h1` and `h2` are hyperparameters and `w1` and `w2` are weights. Training begins with a `Population` consisting of a set of `Worker`s, each with its own weights and hyperparameters. During exploration, the hyperparameters are perturbed by Gaussian noise; during exploitation, a `Worker` inherits the weights of the best `Worker` in the population. As in the paper, only two workers were used.
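The core loop is small enough to sketch. Below is a minimal, self-contained reading of that setup (the function names, learning rate, and step counts are illustrative assumptions, not the repo's actual code):

```python
# Minimal PBT sketch for the toy example (illustrative, not pbt.py itself).
import numpy as np

def Q(w):
    return 1.2 - w[0] ** 2 - w[1] ** 2          # true (unknown) objective

def grad_step(w, h, lr=0.01):
    # gradient ascent on the surrogate Q_hat = 1.2 - h1*w1^2 - h2*w2^2,
    # whose gradient w.r.t. w is -2 * h * w
    return w - lr * 2.0 * h * w

rng = np.random.default_rng(0)
workers = [
    {"w": np.array([0.9, 0.9]), "h": np.array([1.0, 0.0])},
    {"w": np.array([0.9, 0.9]), "h": np.array([0.0, 1.0])},
]

for step in range(1, 201):
    for wk in workers:
        wk["w"] = grad_step(wk["w"], wk["h"])
    if step % 10 == 0:                           # every 10 steps (user-set)
        best = max(workers, key=lambda wk: Q(wk["w"]))
        for wk in workers:
            if wk is not best:
                wk["w"] = best["w"].copy()                 # exploit
                wk["h"] = wk["h"] + rng.normal(0, 0.1, 2)  # explore

print([Q(wk["w"]) for wk in workers])            # both should approach 1.2
```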
The reproduced plots are shown below. Some key observations:

- Theta plots
  - In Exploit only, the intersection of the workers represents the inheritance of the best weights from one worker to the other; this occurs every 10 steps (set by the user)
  - In Explore only, we don't see any intersections; each point follows closely from the last via random perturbations and gradient descent steps
  - In PBT, we see the combination of the aforementioned effects
- Q plots
  - The grid-search plot never converges to 1.2 due to bad initialization. As the hyperparameters are fixed during the entire training (`Worker1` with `h=[1 0]` and `Worker2` with `h=[0 1]`), the surrogate function will never converge to the real function with `h=[1 1]`. This illustrates the shortcomings of grid search, which can limit the generalization capabilities of a model (especially with bad initializations)
Run `./pbt.py` or `./toy_example.py`.

`pbt.py` was the original implementation of the toy example, but much complexity has since been added to it to support other scripts. For a clean implementation of the toy example, please read `toy_example.py`.
`general_pbt.py` implements PBT fully asynchronously: `Worker`s run in parallel and interact via shared memory (a minimal sketch of this pattern follows the observations below). The plots below illustrate the effect of population size on `Q` (the objective function), loss, and `theta`.
Population sizes of 1, 2, 4, 8, 16, and 32 were used, and the best-performing worker from each population was graphed (see the legend for the color scheme).

- Generally, the more workers used, the faster the population converges to the optimal `Q`
- The benefits of adding more workers tend to tail off, as each subsequent increase in population size yields a smaller performance gain than the previous one (2 workers is a lot better than 1, but 16 is only marginally better than 8)
- The jumps in the green plot represent exploration and exploitation; there are no jumps in the blue plot, as there is no concept of exploitation for 1 worker (but we can see exploration if we look closely enough)
- Generally, "lines" corresponding to larger population sizes are shorter; that's because the more workers there are, the faster the population finds the optimal `theta` value
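The asynchronous pattern itself is straightforward. Here is a minimal sketch using `multiprocessing` shared memory; the layout (one shared score and weight slot per worker) and all names are assumptions for illustration, and `general_pbt.py`'s actual implementation may differ:

```python
# Asynchronous PBT sketch: workers run in parallel and exploit/explore
# through shared memory (illustrative; not general_pbt.py itself).
import multiprocessing as mp
import numpy as np

def worker(idx, shared_w, shared_q, lock, steps=200):
    rng = np.random.default_rng(idx)
    w = np.array([0.9, 0.9])
    h = rng.uniform(0.0, 1.0, 2)
    for step in range(1, steps + 1):
        w = w - 0.01 * 2.0 * h * w                   # ascend the surrogate
        if step % 10 == 0:
            q = 1.2 - w[0] ** 2 - w[1] ** 2          # evaluate true Q
            with lock:
                shared_q[idx] = q
                shared_w[2 * idx:2 * idx + 2] = list(w)
                best = int(np.argmax(shared_q[:]))
                if best != idx:
                    w = np.array(shared_w[2 * best:2 * best + 2])  # exploit
                    h = h + rng.normal(0, 0.1, 2)                  # explore

if __name__ == "__main__":
    pop = 4
    shared_w = mp.Array("d", 2 * pop)                # flattened weights
    shared_q = mp.Array("d", [float("-inf")] * pop)  # per-worker scores
    lock = mp.Lock()
    procs = [mp.Process(target=worker, args=(i, shared_w, shared_q, lock))
             for i in range(pop)]
    for p in procs: p.start()
    for p in procs: p.join()
```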
`pbtv2_tf.py` is a distributed TensorFlow implementation of the toy example. To run, you may either start the processes manually in separate terminals:

```
python3 pbtv2_tf.py --ps_hosts=localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=ps --task_index=0
python3 pbtv2_tf.py --ps_hosts=localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=0
python3 pbtv2_tf.py --ps_hosts=localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=1
...
```
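For reference, here is a minimal sketch of how these flags typically map onto a TF1 cluster; the actual model and PBT logic in `pbtv2_tf.py` are more involved, and the variable below is only a placeholder:

```python
# Sketch of the ps/worker cluster wiring implied by the flags above
# (assumes TensorFlow 1.x; not pbtv2_tf.py's actual graph code).
import argparse
import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument("--ps_hosts", type=str)
parser.add_argument("--worker_hosts", type=str)
parser.add_argument("--job_name", type=str)
parser.add_argument("--task_index", type=int)
args = parser.parse_args()

cluster = tf.train.ClusterSpec({
    "ps": args.ps_hosts.split(","),
    "worker": args.worker_hosts.split(","),
})
server = tf.train.Server(cluster, job_name=args.job_name,
                         task_index=args.task_index)

if args.job_name == "ps":
    server.join()  # parameter servers host the shared population state
else:
    # pin shared variables to the ps job; compute runs on this worker
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % args.task_index,
            cluster=cluster)):
        theta = tf.get_variable("theta", initializer=[0.9, 0.9])
    with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=(args.task_index == 0)) as sess:
        print(sess.run(theta))
```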
or use the wrapper file `pbt_wrapper.py`, where `size` is the population size:

```
python3 pbt_wrapper.py --size 20 --task toy
```
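Under the hood, a wrapper like this plausibly just spawns one ps process plus `size` worker processes with consistent host lists. A rough sketch (the port numbering follows the manual commands above; the actual `pbt_wrapper.py` may differ):

```python
# Rough sketch of a PBT launcher (illustrative; not pbt_wrapper.py itself).
import argparse
import subprocess

parser = argparse.ArgumentParser()
parser.add_argument("--size", type=int, default=2)
parser.add_argument("--task", type=str, default="toy")
args = parser.parse_args()

script = {"toy": "pbtv2_tf.py", "mueller": "mueller_tf.py"}[args.task]
ps_hosts = "localhost:2222"
worker_hosts = ",".join("localhost:%d" % (2223 + i) for i in range(args.size))
common = ["python3", script,
          "--ps_hosts=" + ps_hosts, "--worker_hosts=" + worker_hosts]

# one parameter server plus `size` workers, all sharing the same host lists
procs = [subprocess.Popen(common + ["--job_name=ps", "--task_index=0"])]
for i in range(args.size):
    procs.append(subprocess.Popen(
        common + ["--job_name=worker", "--task_index=%d" % i]))
for p in procs:
    p.wait()
```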
`mueller_tf.py` optimizes the Mueller potential from here.

```
python3 pbt_wrapper.py --size 40 --task mueller
```
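For reference, the Mueller (Mueller-Brown) potential is commonly written as a sum of four anisotropic Gaussians with standard constants; whether `mueller_tf.py` uses exactly this parameterization is an assumption:

```python
# The Mueller-Brown potential with its standard constants (for reference;
# mueller_tf.py may parameterize it differently).
import numpy as np

A  = [-200.0, -100.0, -170.0, 15.0]
a  = [-1.0, -1.0, -6.5, 0.7]
b  = [0.0, 0.0, 11.0, 0.6]
c  = [-10.0, -10.0, -6.5, 0.7]
x0 = [1.0, 0.0, -0.5, -1.0]
y0 = [0.0, 0.5, 1.5, 1.0]

def mueller(x, y):
    # sum of four Gaussian wells/barriers; the global minimum is near
    # (-0.558, 1.442) with a value of roughly -146.7
    return sum(A[i] * np.exp(a[i] * (x - x0[i]) ** 2
                             + b[i] * (x - x0[i]) * (y - y0[i])
                             + c[i] * (y - y0[i]) ** 2)
               for i in range(4))

print(mueller(-0.558, 1.442))
```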
```
tensorboard --logdir=logs
```

Check out `tensorboard/logs` for my visualization plots.
- Try different exploration and exploitation methods (e.g. truncation)
- How does the learning rate decay in Adam affect PBT's own learning rate exploration/exploitation?
- Bug: fix cases where workers end up with `nan` weights. With aggressive initialization of hyperparameters (e.g. -50 to 50 or -20 to 20 for the `exp` model), the loss becomes a very large negative number, leading to `nan` backprops. Since `nan < x` is always `False`, these workers are dead (see the sketch below for one possible fix)
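One possible fix for the last item is to make the exploit comparison nan-aware, e.g. (a sketch, not the repo's code):

```python
# Map nan scores to -inf so dead workers rank last and always exploit a
# healthy worker, instead of silently surviving comparisons.
import math

def score(q):
    # "nan < x" is always False, so nan scores break ordinary comparisons;
    # mapping them to -inf makes the ordering well defined
    return float("-inf") if math.isnan(q) else q

# toy population of (name, Q) pairs; worker w1 has diverged to nan
population = [("w0", 0.8), ("w1", float("nan")), ("w2", 1.1)]
best = max(population, key=lambda wk: score(wk[1]))
print(best)  # ("w2", 1.1): the nan worker can never be selected as best
```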