Accompanying code for the paper Coprocessor Actor Critic: A Model-Based Reinforcement Learning Approach For Adaptive Brain Stimulation.
Use the following commands to create a Conda environment with the required packages:
PIP_NO_DEPS=1 conda env create -f environment.yml
conda activate coproc
pip install -e .
pip install -e myosuite
To train a healthy brain policy, run train/train_brain.py:
python train/train_brain.py --gym_env myoHandReachFixed-v0 --brain michaels --timesteps 500000
In our paper, we use Michaels model brains for MyoSuite environments and SAC for other environments.
To train a coprocessor using CopAC, first train an optimal policy with train/train_sac_env.py:
python train/train_sac_env.py --gym_env myoHandReachFixed-v0 --timesteps 5000000
Then run train/train_action_conv.py:
python train/train_action_conv.py --gym_env myoHandReachFixed-v0 --brain michaels --region M1 --pct_lesion 0.9 --stim_dim 2 --action_conv qmax --num_q_update_traj 5
To run without Q-updates, use --num_q_update_traj 0. To run without both Q-updates and Q-max, use --action_conv random.
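The CopAC steps and the two ablations above differ only in a few flags. Below is a minimal, hypothetical Python driver (the helpers copac_command and run are illustrative and not part of this repository) that assembles the commands exactly as written in this README and prints them; pass dry_run=False to execute them in order. It assumes it is launched from the repository root inside the coproc environment, and that leaving --num_q_update_traj at the script's default is acceptable for the random variant.

```python
import subprocess
import sys

ENV = "myoHandReachFixed-v0"

# Step 1: train the optimal SAC policy on the environment (flags from this README).
SAC_CMD = [sys.executable, "train/train_sac_env.py",
           "--gym_env", ENV, "--timesteps", "5000000"]


def copac_command(variant: str) -> list[str]:
    """Build the train_action_conv.py command for one CopAC variant (hypothetical helper)."""
    # variant -> (--action_conv value, --num_q_update_traj value or None)
    variants = {
        "full": ("qmax", "5"),          # Q-max with Q-updates
        "no_q_updates": ("qmax", "0"),  # Q-max without Q-updates
        "random": ("random", None),     # neither Q-updates nor Q-max; script default assumed
    }
    action_conv, n_traj = variants[variant]
    cmd = [sys.executable, "train/train_action_conv.py",
           "--gym_env", ENV, "--brain", "michaels", "--region", "M1",
           "--pct_lesion", "0.9", "--stim_dim", "2",
           "--action_conv", action_conv]
    if n_traj is not None:
        cmd += ["--num_q_update_traj", n_traj]
    return cmd


def run(cmd: list[str], dry_run: bool = True) -> None:
    """Print the command; execute it only when dry_run is False."""
    print(" ".join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)  # raise on the first failing step


if __name__ == "__main__":
    run(SAC_CMD)  # the coprocessor step needs the trained optimal policy
    for v in ("full", "no_q_updates", "random"):
        run(copac_command(v))
```

Running the steps with check=True means a failed policy-training run stops the sequence instead of launching coprocessor training without its prerequisite.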
The following command trains a SAC coprocessor:
python train/train_sac_coproc.py --gym_env myoHandReachFixed-v0 --brain michaels --region M1 --pct_lesion 0.9 --stim_dim 2
Due to dependency incompatibilities, MBPO must be run in a separate environment. First set up the environment:
conda create -n coproc-mbrl python=3.9
conda activate coproc-mbrl
sh install_mbrl_deps.sh
Then train the MBPO coprocessor:
python -m mbrl-lib.mbrl.examples.main algorithm=mbpo overrides=mbpo_coproc_myo-hand overrides.env_cfg.pct_lesion=0.9 overrides.env_cfg.stim_dim=2
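The MBPO command takes its settings as key=value arguments, which appear to be Hydra-style overrides (mbrl-lib's example scripts are configured with Hydra): algorithm= and overrides= select configuration options, while dotted keys such as overrides.env_cfg.pct_lesion set individual nested fields. The sketch below rebuilds the command above from a dictionary, which may be convenient when changing pct_lesion or stim_dim; mbpo_command is a hypothetical helper, not part of the repository.

```python
import shlex


def mbpo_command(pct_lesion: float, stim_dim: int) -> list[str]:
    """Assemble the MBPO coprocessor command from Hydra-style overrides (hypothetical helper)."""
    overrides = {
        "algorithm": "mbpo",                            # select the algorithm config
        "overrides": "mbpo_coproc_myo-hand",            # select the environment overrides config
        "overrides.env_cfg.pct_lesion": pct_lesion,     # nested field: lesion fraction
        "overrides.env_cfg.stim_dim": stim_dim,         # nested field: stimulation dimension
    }
    # dicts keep insertion order, so the arguments match the README ordering
    return ["python", "-m", "mbrl-lib.mbrl.examples.main"] + [
        f"{key}={value}" for key, value in overrides.items()
    ]


if __name__ == "__main__":
    print(shlex.join(mbpo_command(0.9, 2)))
```

With pct_lesion=0.9 and stim_dim=2 this prints the same command as shown above; run it from the repository root inside the coproc-mbrl environment.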
To learn the optimal policy from offline data, first collect healthy brain rollouts.
python coprocessors/utils/collect_offline_data.py --gym_env myoHandReachFixed-v0 --brain michaels --data_size 1000
Then train the optimal policy on the collected rollouts.
python train/train_offline.py --env myoHandReachFixed-v0 --data_size 1000 --episodes 5000
Finally, run CopAC with the offline policy.
python train/train_action_conv.py --gym_env myoHandReachFixed-v0 --brain michaels --region M1 --pct_lesion 0.9 --stim_dim 2 --action_conv qmax_offline
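The three offline steps share the environment and, in the first two, the --data_size value. Below is a minimal, hypothetical driver (offline_pipeline and run_all are illustrative names, not part of the repository) that builds the commands from this README with those values defined once. It assumes that train_offline.py locates the collected rollouts through a matching --data_size; that is an inference from the commands, not something this README states. Note that train_offline.py takes --env rather than --gym_env.

```python
import subprocess
import sys


def offline_pipeline(env: str = "myoHandReachFixed-v0", data_size: int = 1000,
                     episodes: int = 5000) -> list[list[str]]:
    """Return the offline-data commands in execution order (hypothetical helper)."""
    n = str(data_size)  # reused so collection and training stay consistent
    return [
        # 1. Collect healthy-brain rollouts.
        [sys.executable, "coprocessors/utils/collect_offline_data.py",
         "--gym_env", env, "--brain", "michaels", "--data_size", n],
        # 2. Train the optimal policy from the offline data (this script uses --env).
        [sys.executable, "train/train_offline.py",
         "--env", env, "--data_size", n, "--episodes", str(episodes)],
        # 3. Run CopAC with the offline policy.
        [sys.executable, "train/train_action_conv.py",
         "--gym_env", env, "--brain", "michaels", "--region", "M1",
         "--pct_lesion", "0.9", "--stim_dim", "2",
         "--action_conv", "qmax_offline"],
    ]


def run_all(cmds: list[list[str]], dry_run: bool = True) -> None:
    """Print each command; execute them in order only when dry_run is False."""
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop if a step fails


if __name__ == "__main__":
    run_all(offline_pipeline())
```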