Code for the NeurIPS 2023 paper "Rewiring Neurons in Non-Stationary Environments".
- Make sure you have PyTorch and JAX installed with CUDA support.
- Install SaLinA and Continual World following their instructions. Note that the latter only supports MuJoCo version 2.0.
- Install additional packages via `pip install -r requirements.txt`.
Simply run `run.py` with the desired config available in `configs`:

```
python run.py -cn=METHOD scenario=SCENARIO OPTIONAL_CONFIGS
```
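For example, to run our method on the HalfCheetah forgetting scenario (both described below):

```
python run.py -cn=rewire scenario=halfcheetah/forgetting
```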
We present 9 different CRL methods, all built on top of the Soft Actor-Critic (SAC) algorithm. To try one, just add the flag `-cn=my_method` on the command line (see the example after the list). You can find the hyperparameters in `configs`:
- `rewire`: our method in "Rewiring Neurons in Non-Stationary Environments".
- `ft_1`: fine-tune a single policy during the whole training.
- `sac_n`: fine-tune and save the policy at the end of each task; start with a randomly initialized policy when encountering a new task.
- `ft_n`: fine-tune and save the policy at the end of each task; clone the last policy when encountering a new task.
- `ft_l2`: fine-tune a single policy during the whole training with a regularization cost (a simpler version of EWC).
- `ewc`: see the paper "Overcoming Catastrophic Forgetting in Neural Networks".
- `pnn`: see the paper "Progressive Neural Networks".
- `packnet`: see the paper "PackNet: Adding Multiple Tasks to a Single Network by Iterative Pruning".
- `csp`: see the paper "Building a Subspace of Policies for Scalable Continual Learning".
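To benchmark a baseline instead of our method, swap the config name, e.g.:

```
python run.py -cn=packnet scenario=ant/transfer
```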
We integrate 9 CRL scenarios over 3 different Brax domains, plus 2 scenarios of the Continual World domain. To try one, just add the flag `scenario=...` on the command line (see the example after the list):
- `halfcheetah/forgetting`: 8 tasks - 1M samples for each task.
- `halfcheetah/transfer`: 8 tasks - 1M samples for each task.
- `halfcheetah/robustness`: 8 tasks - 1M samples for each task.
- `halfcheetah/compositionality`: 8 tasks - 1M samples for each task.
- `ant/forgetting`: 8 tasks - 1M samples for each task.
- `ant/transfer`: 8 tasks - 1M samples for each task.
- `ant/robustness`: 8 tasks - 1M samples for each task.
- `ant/compositionality`: 8 tasks - 1M samples for each task.
- `humanoid/hard`: 4 tasks - 2M samples for each task.
- `continual_world/t1-t8`: 8 triplets of 3 tasks - 1M samples for each task.
- `continual_world/cw10`: 10 tasks - 1M samples for each task.
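For instance, to run our method on the Continual World CW10 benchmark:

```
python run.py -cn=rewire scenario=continual_world/cw10
```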
The `core.py` file contains the building blocks of this framework. Each experiment consists of running a `Framework` over a `Scenario`, i.e. a sequence of train and test `Task`s. The models are learning procedures that use CRL agents to interact with the tasks and learn from them through one or more algorithms. A minimal sketch of this structure follows the directory list below.
- `frameworks` contains generic learning procedures (e.g. using only one algorithm, or adding a regularization method at the end).
- `scenarios` contains CRL scenarios, i.e. sequences of train and test tasks.
- `algorithms` contains the different RL / CL algorithms (e.g. SAC or EWC).
- `agents` contains the CRL agents (e.g. PackNet, CSP, or Rewire).
- `configs` contains the config files of the benchmarked methods/scenarios.
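As a rough mental model, the sketch below shows how a `Framework`, a `Scenario`, and its `Task`s fit together. Only those three class names come from `core.py`; the fields, methods, and training loop here are hypothetical simplifications, not the actual interfaces.

```python
# Minimal sketch of the Framework / Scenario / Task structure described above.
# Only the three class names come from core.py; everything else (fields,
# methods, the loop) is a hypothetical simplification for illustration.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Task:
    name: str
    n_samples: int  # e.g. 1_000_000 for most Brax scenarios

@dataclass
class Scenario:
    train_tasks: List[Task] = field(default_factory=list)
    test_tasks: List[Task] = field(default_factory=list)

class Framework:
    """A generic learning procedure: train a CRL agent on each task in turn."""
    def run(self, scenario: Scenario) -> None:
        for task in scenario.train_tasks:
            # The agent interacts with the task through one or more
            # algorithms (e.g. SAC), then is evaluated on the test tasks.
            print(f"train on {task.name} ({task.n_samples} samples)")
        for task in scenario.test_tasks:
            print(f"evaluate on {task.name}")

if __name__ == "__main__":
    tasks = [Task(f"halfcheetah/forgetting-{i}", 1_000_000) for i in range(8)]
    Framework().run(Scenario(train_tasks=tasks, test_tasks=tasks))
```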
Our implementation is based on: