JAX implementations of various deep reinforcement learning algorithms.
Main libraries used:
- JAX - main framework
- Haiku - neural networks
- Optax - gradient-based optimisation

| Algorithm | Paper |
|---|---|
| Proximal Policy Optimization (PPO) | https://arxiv.org/abs/1707.06347 |
| Deep Q-Network (DQN) | https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf |
| Double Deep Q-Network (DDQN) | https://arxiv.org/abs/1509.06461 |
| Deep Recurrent Q-Network (DRQN) | https://arxiv.org/abs/1507.06527 |
| Deep Deterministic Policy Gradient (DDPG) | https://arxiv.org/abs/1509.02971 |
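
DQN and Double DQN differ mainly in how the bootstrap target is formed. In the common target-network formulation of DQN, the target takes the maximum of the target network's action values at the next state. Double DQN instead lets the online network choose the action and the target network evaluate it.

The sketch below implements both targets in plain JAX. It assumes batched arrays and a `dones` flag that zeroes the bootstrap term at episode ends. The function names and the example batch are hypothetical, not this repository's API.

```python
import jax
import jax.numpy as jnp


def dqn_target(rewards, dones, next_q_target, gamma=0.99):
    """DQN: bootstrap from the target network's maximum action value."""
    return rewards + gamma * (1.0 - dones) * jnp.max(next_q_target, axis=-1)


def double_dqn_target(rewards, dones, next_q_online, next_q_target, gamma=0.99):
    """Double DQN: the online network selects the action, the target network evaluates it."""
    best_actions = jnp.argmax(next_q_online, axis=-1)
    evaluated = jnp.take_along_axis(
        next_q_target, best_actions[:, None], axis=-1
    ).squeeze(-1)
    return rewards + gamma * (1.0 - dones) * evaluated


def td_loss(q_values, actions, targets):
    """Mean squared error between Q(s, a) for the taken actions and the targets."""
    chosen = jnp.take_along_axis(q_values, actions[:, None], axis=-1).squeeze(-1)
    # stop_gradient treats the targets as constants when differentiating
    return jnp.mean((chosen - jax.lax.stop_gradient(targets)) ** 2)


# Made-up batch of two transitions; the second one ends its episode.
rewards = jnp.array([1.0, 0.0])
dones = jnp.array([0.0, 1.0])
next_q_online = jnp.array([[1.0, 2.0], [0.5, 0.2]])
next_q_target = jnp.array([[3.0, 1.0], [4.0, 5.0]])

y_dqn = dqn_target(rewards, dones, next_q_target, gamma=0.9)
y_ddqn = double_dqn_target(rewards, dones, next_q_online, next_q_target, gamma=0.9)
```

For this batch, `y_dqn` is about `[3.7, 0.0]` and `y_ddqn` is about `[1.9, 0.0]`. In the first transition the online network prefers action 1, which the target network values below its own maximum. This decoupling is how Double DQN reduces the overestimation caused by taking a max over noisy estimates.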

Tabular algorithms:

- Q-learning
- Double Q-learning
- SARSA
- Expected SARSA
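
The four tabular methods differ only in the bootstrap term of the temporal-difference target. The sketch below uses immutable JAX arrays as Q-tables. It is not this repository's API. The function names, the epsilon-greedy policy used for Expected SARSA, the `done` flag convention, and the example values are assumptions.

```python
import jax.numpy as jnp

# Q-tables have shape (num_states, num_actions). `done` is 1.0 when s_next is
# terminal, which removes the bootstrap term from the target.


def _td_step(Q, s, a, target, alpha):
    # Move Q[s, a] a fraction alpha toward the target; .at[].add returns a new array.
    return Q.at[s, a].add(alpha * (target - Q[s, a]))


def q_learning_update(Q, s, a, r, s_next, done=0.0, alpha=0.1, gamma=0.99):
    # Off-policy: bootstrap from the greedy action at s_next.
    target = r + gamma * (1.0 - done) * jnp.max(Q[s_next])
    return _td_step(Q, s, a, target, alpha)


def sarsa_update(Q, s, a, r, s_next, a_next, done=0.0, alpha=0.1, gamma=0.99):
    # On-policy: bootstrap from the action actually taken at s_next.
    target = r + gamma * (1.0 - done) * Q[s_next, a_next]
    return _td_step(Q, s, a, target, alpha)


def epsilon_greedy_probs(q_row, epsilon):
    # epsilon / n on every action, plus (1 - epsilon) on the greedy action.
    n = q_row.shape[0]
    probs = jnp.full(n, epsilon / n)
    return probs.at[jnp.argmax(q_row)].add(1.0 - epsilon)


def expected_sarsa_update(Q, s, a, r, s_next, epsilon=0.1, done=0.0, alpha=0.1, gamma=0.99):
    # Bootstrap from the expected value of Q[s_next] under the epsilon-greedy policy.
    probs = epsilon_greedy_probs(Q[s_next], epsilon)
    target = r + gamma * (1.0 - done) * jnp.dot(probs, Q[s_next])
    return _td_step(Q, s, a, target, alpha)


def double_q_update(QA, QB, s, a, r, s_next, done=0.0, alpha=0.1, gamma=0.99):
    # QA picks the greedy action and QB evaluates it; the caller is expected to
    # swap the roles of QA and QB at random (e.g. a coin flip) on each step.
    a_star = jnp.argmax(QA[s_next])
    target = r + gamma * (1.0 - done) * QB[s_next, a_star]
    return _td_step(QA, s, a, target, alpha)


# Tiny made-up example: 3 states, 2 actions, one transition from state 0 to state 1.
Q = jnp.zeros((3, 2)).at[1].set(jnp.array([1.0, 2.0]))
QB = jnp.zeros((3, 2)).at[1].set(jnp.array([5.0, 3.0]))

q1 = q_learning_update(Q, 0, 0, 1.0, 1, alpha=0.5, gamma=0.9)
q2 = sarsa_update(Q, 0, 0, 1.0, 1, 0, alpha=0.5, gamma=0.9)
q3 = expected_sarsa_update(Q, 0, 0, 1.0, 1, epsilon=0.2, alpha=0.5, gamma=0.9)
q4 = double_q_update(Q, QB, 0, 0, 1.0, 1, alpha=0.5, gamma=0.9)
```

Every method moves only `Q[0, 0]` in this example, and each reaches a different value: about 1.4, 0.95, 1.355 and 1.85 respectively. The difference comes entirely from which estimate of the next state's value is used in the target.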

Install with pip: `pip install git+https://github.com/hamishs/JAX-RL`