Figure 1: Conservative zero-shot RL methods suppress the values or measures on actions not in the dataset for all tasks. Black dots represent state-action samples present in the dataset.
This is the official codebase for Zero-Shot Reinforcement Learning from Low Quality Data by Scott Jeen, Tom Bewley and Jonathan Cullen.
This work proposes methods for performing zero-shot RL when the pre-training datasets are small and homogeneous. We show that by suppressing the predicted values (or measures) for actions not in the dataset (Figure 1), we can resolve the overestimation bias that arises when the dataset does not exhaustively cover the state-action space. We demonstrate this on the ExORL (Figure 2) and D4RL (Figure 3) benchmarks, showing improved performance over existing methods.
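For intuition, the sketch below shows the kind of CQL-style conservatism penalty this idea corresponds to for a generic critic: predicted values on actions sampled away from the dataset are pushed down relative to values on dataset actions. This is a minimal illustration, not the authors' FB losses; the q_net critic and the uniform action sampling are assumptions made only for this example.

```python
import torch

# Schematic sketch (not the repo's exact loss): a CQL-style conservatism
# penalty that suppresses predicted values on out-of-dataset actions relative
# to values on dataset actions. q_net(obs, act) is an assumed critic returning
# one value per state-action pair; actions are assumed to lie in [-1, 1].
def conservative_penalty(q_net, obs, dataset_actions, num_samples=10):
    batch_size, act_dim = dataset_actions.shape
    # Uniformly sampled actions stand in for "actions not in the dataset".
    random_actions = 2 * torch.rand(batch_size, num_samples, act_dim) - 1
    obs_rep = obs.unsqueeze(1).expand(-1, num_samples, -1).reshape(-1, obs.shape[-1])
    q_random = q_net(obs_rep, random_actions.reshape(-1, act_dim)).reshape(batch_size, num_samples)
    q_data = q_net(obs, dataset_actions).reshape(batch_size)
    # logsumexp approximates a soft maximum over the sampled actions; minimising
    # this term pushes off-dataset values down relative to dataset values.
    return (torch.logsumexp(q_random, dim=1) - q_data).mean()
```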
Figure 2: Aggregate ExORL performance. (Left) Normalised average performance w.r.t. single-task baseline algorithm CQL. (Right) Performance profiles showing distribution of scores across all tasks and domains. Both conservative FB variants stochastically dominate vanilla FB.
Figure 3: Aggregate D4RL performance. Normalised average performance w.r.t. single-task baseline algorithm CQL.
For further detail, check out the paper. Direct any correspondence to Scott Jeen or raise an issue!
Assuming you have MuJoCo installed, set up a conda env with Python 3.9.16 as usual:
conda create --name zsrl python=3.9.16
then install the dependencies from requirements.txt:
pip install -r requirements.txt
We provide implementations of the following algorithms:
Algorithm | Authors | Type | Command Line Argument |
---|---|---|---|
Conservative Q-Learning (CQL) | Kumar et al. (2020) | Single-task Offline RL | cql |
Offline TD3 | Fujimoto et al. (2021) | Single-task Offline RL | td3 |
Goal-conditioned Implicit Q-Learning (GC-IQL) | Park et al. (2023) | Goal-conditioned RL | gciql |
Universal Successor Features learned with Laplacian Eigenfunctions (SF-LAP) | Borsa et al. (2018) | Zero-shot RL | sf-lap |
FB Representations | Touati et al. (2023) | Zero-shot RL | fb |
Value-Conservative FB Representations | Jeen et al. (2024) | Zero-shot RL | vcfb |
Measure-Conservative FB Representations | Jeen et al. (2024) | Zero-shot RL | mcfb |
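Each agent is selected by passing its command line argument to the training scripts; for example, swapping the argument in the ExORL training command shown further below trains the vanilla FB baseline instead (an illustrative combination, assuming the rnd walker dataset has been prepared as described in the next section):
python main_exorl.py fb walker rnd --eval_task stand run walk flip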
In the paper we report results with agents trained on datasets collected from different exploratory algorithms on different domains. The domains are:
Domain | Eval Tasks | Dimensionality | Type | Reward | Command Line Argument |
---|---|---|---|---|---|
Walker | stand walk run flip | Low | Locomotion | Dense | walker |
Quadruped | stand roll roll_fast jump escape | High | Locomotion | Dense | quadruped |
Point-mass Maze | reach_top_left reach_top_right reach_bottom_left reach_bottom_right | Low | Goal-reaching | Sparse | point_mass_maze |
Jaco | reach_top_left reach_top_right reach_bottom_left reach_bottom_right | High | Goal-reaching | Sparse | jaco |
and the dataset collecting algorithms are:
Dataset Collecting Algorithm | State Coverage | Command Line Argument |
---|---|---|
Random Network Distillation (RND) | High | rnd |
Diversity is All You Need (DIAYN) | Medium | diayn |
Random | Low | random |
State coverage illustrations on point_mass_maze are provided in Figure 4. For each domain, datasets need to be downloaded manually from the ExORL benchmark and then reformatted. Once downloaded, to reformat the rnd dataset for the walker domain, separate the domain and dataset command line args with an _ and run:
python exorl_reformatter.py walker_rnd
This will create a single dataset.npz file in the dataset/walker/rnd/buffer directory.
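As a quick sanity check that the reformatting worked, you can inspect the resulting archive with numpy; the sketch below simply lists whatever arrays exorl_reformatter.py stored (the exact key names are not documented here, so none are assumed):

```python
import numpy as np

# List every array stored in the reformatted ExORL buffer and its shape.
data = np.load("dataset/walker/rnd/buffer/dataset.npz")
for key in data.files:
    print(key, data[key].shape)
```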
Figure 4: State coverage by dataset on point_mass_maze.
To train a standard Value-Conservative Forward Backward Representation with the rnd (100k) dataset to solve all tasks in the walker domain, run:
python main_exorl.py vcfb walker rnd --eval_task stand run walk flip
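The same interface applies to the other domains; for example, to train a Measure-Conservative FB agent on the point_mass_maze goal-reaching tasks (assuming the rnd dataset has also been downloaded and reformatted for that domain), run:
python main_exorl.py mcfb point_mass_maze rnd --eval_task reach_top_left reach_top_right reach_bottom_left reach_bottom_right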
For D4RL, in the paper we report results on the following domains:
Domain | Command Line Argument |
---|---|
Walker | walker |
Cheetah | cheetah |
trained on the following datasets:
Datasets | Description | Command Line Argument |
---|---|---|
Medium | Generated by training an SAC policy, early-stopping the training, and collecting 1M samples from this partially-trained policy | medium |
Medium-replay | Generated by recording all samples in the replay buffer observed during training until the policy reaches the “medium” level of performance. | medium-replay |
Medium-expert | Generated by mixing equal amounts of expert demonstrations and suboptimal data, either from a partially trained policy or by unrolling a uniform-at-random policy. | medium-expert |
You'll need to manually download the D4RL datasets following their instructions, rename them to dataset.hdf5, and place them in the correct directory inside /datasets, e.g. the walker medium-expert dataset should be saved to datasets/walker/medium-expert/dataset.hdf5.
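To check a downloaded file is in place and readable, a minimal sketch using h5py is below; the top-level keys typically follow the usual D4RL convention (observations, actions, rewards, terminals, ...) but may differ slightly between datasets:

```python
import h5py

# Print the top-level entries of the relocated D4RL file and their shapes.
with h5py.File("datasets/walker/medium-expert/dataset.hdf5", "r") as f:
    for key in f.keys():
        item = f[key]
        print(key, item.shape if hasattr(item, "shape") else "(group)")
```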
To train a standard Value-Conservative Forward Backward Representation with the medium-expert dataset on walker, run:
python main_d4rl.py vcfb walker medium-expert
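To cover all three dataset settings from the table above, a simple loop works; this is just a convenience sketch that shells out to the same command:

```python
import subprocess

# Train VCFB on walker for each D4RL dataset reported in the paper.
for dataset in ["medium", "medium-replay", "medium-expert"]:
    subprocess.run(["python", "main_d4rl.py", "vcfb", "walker", dataset], check=True)
```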
If you find this work informative, please consider citing the paper!
@article{jeen2024,
url = {https://arxiv.org/abs/2309.15178},
author = {Jeen, Scott and Bewley, Tom and Cullen, Jonathan M.},
title = {Zero-Shot Reinforcement Learning from Low Quality Data},
journal = {Advances in Neural Information Processing Systems 38},
year = {2024},
}
This work is licensed under a standard MIT License; see LICENSE.md for further details.