Sampling-based model predictive control on GPU with JAX and MuJoCo MJX.
Hydrax implements various sampling-based MPC algorithms on GPU. It is heavily inspired by MJPC, but focuses exclusively on sampling-based algorithms, runs on hardware accelerators via JAX and MJX, and includes support for online domain randomization.
Available methods:
Algorithm | Description | Import |
---|---|---|
Predictive sampling | Take the lowest-cost rollout at each iteration. | hydrax.algs.PredictiveSampling |
MPPI | Take an exponentially weighted average of the rollouts. | hydrax.algs.MPPI |
Cross Entropy Method | Fit a Gaussian distribution to the n best "elite" rollouts. | hydrax.algs.CEM |
Evosax | Any of the 30+ evolution strategies implemented in evosax. Includes CMA-ES, differential evolution, and many more. | hydrax.algs.Evosax |
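To make the difference between the first three update rules concrete, here is a minimal plain-Python sketch. It is not Hydrax's implementation: the function names, the `temperature` parameter, and the per-time-step Gaussian fit are illustrative assumptions, and control sequences are plain lists of scalars rather than JAX arrays.

```python
import math


def predictive_sampling_update(controls, costs):
    """Return the sampled control sequence with the lowest cost."""
    best = min(range(len(costs)), key=lambda i: costs[i])
    return list(controls[best])


def mppi_update(controls, costs, temperature=1.0):
    """Return the exponentially weighted average of the sampled sequences."""
    min_cost = min(costs)  # subtract the minimum for numerical stability
    weights = [math.exp(-(c - min_cost) / temperature) for c in costs]
    total = sum(weights)
    horizon = len(controls[0])
    return [
        sum(w * u[t] for w, u in zip(weights, controls)) / total
        for t in range(horizon)
    ]


def cem_update(controls, costs, num_elites):
    """Fit a per-time-step Gaussian (mean, std) to the lowest-cost samples."""
    order = sorted(range(len(costs)), key=lambda i: costs[i])
    elites = [controls[i] for i in order[:num_elites]]
    horizon = len(controls[0])
    mean = [sum(u[t] for u in elites) / num_elites for t in range(horizon)]
    std = [
        math.sqrt(sum((u[t] - mean[t]) ** 2 for u in elites) / num_elites)
        for t in range(horizon)
    ]
    return mean, std


# Four sampled control sequences over a two-step horizon, with their costs.
controls = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0], [4.0, 4.0]]
costs = [3.0, 1.0, 2.0, 10.0]

ps_plan = predictive_sampling_update(controls, costs)
mppi_plan = mppi_update(controls, costs, temperature=1.0)
cem_mean, cem_std = cem_update(controls, costs, num_elites=2)
```

Predictive sampling keeps only the single best sample, MPPI blends all samples with weights that decay with cost, and CEM refits its sampling distribution to the elite subset.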
Set up a conda env with CUDA support (first time only):
conda env create -f environment.yml
Enter the conda env:
conda activate hydrax
Install the package and dependencies:
pip install -e .
(Optional) set up pre-commit hooks:
pre-commit autoupdate
pre-commit install
(Optional) run unit tests:
pytest
Launch an interactive pendulum swingup simulation with predictive sampling:
python examples/pendulum.py ps
Launch an interactive planar walker simulation (shown above) with MPPI:
python examples/walker.py mppi
Other demos can be found in the examples folder.
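Hydrax considers optimal control problems of the form
$$
\min_{u_0, \dots, u_T} \; \sum_{t=0}^{T} \ell(x_t, u_t) + \phi(x_{T+1}) \quad \mathrm{s.t.} \quad x_{t+1} = f(x_t, u_t),
$$
where $x_t$ is the system state and $u_t$ is the control input at time step $t$, $\ell$ is the running cost, $\phi$ is the terminal cost, $f$ is the system dynamics, and $T$ is the planning horizon.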
To design a new task, you'll need to specify the cost ($\ell$, $\phi$) and the dynamics ($f$) by inheriting from hydrax.task_base.Task:
import jax
from mujoco import mjx

from hydrax.task_base import Task

class MyNewTask(Task):
    def __init__(self, ...):
        # Create or load a mujoco model defining the dynamics (f)
        mj_model = ...
        super().__init__(mj_model, ...)

    def running_cost(self, x: mjx.Data, u: jax.Array) -> float:
        # Implement the running cost (l) here
        return ...

    def terminal_cost(self, x: mjx.Data) -> float:
        # Implement the terminal cost (phi) here
        return ...
The dynamics ($f$) are specified by the mujoco.MjModel that is passed to the constructor. Other constructor arguments specify details such as the planning horizon.
For the cost, simply implement the running_cost ($\ell$) and terminal_cost ($\phi$) methods.
See hydrax.tasks for some example task implementations.
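As a rough illustration of how the running and terminal costs combine along a rollout, the following plain-Python sketch evaluates the total cost of a control sequence. The 1D integrator dynamics, the quadratic costs, and all function names are toy assumptions made for this example; a real task would use a MuJoCo model and the Task interface instead.

```python
def dynamics(x, u, dt=0.1):
    """Toy stand-in for f: a 1D single integrator, x_next = x + dt * u."""
    return x + dt * u


def running_cost(x, u):
    """Toy running cost (l): squared distance to the origin plus control effort."""
    return x**2 + 0.1 * u**2


def terminal_cost(x):
    """Toy terminal cost (phi): a heavier penalty on the final state."""
    return 10.0 * x**2


def total_cost(x0, controls):
    """Roll out the dynamics and accumulate the running and terminal costs."""
    x, cost = x0, 0.0
    for u in controls:
        cost += running_cost(x, u)
        x = dynamics(x, u)
    return cost + terminal_cost(x)


# Steering the state toward the origin is cheaper than applying no control.
idle = total_cost(1.0, [0.0, 0.0, 0.0])
steer = total_cost(1.0, [-1.0, -1.0, -1.0])
```

A sampling-based planner evaluates this kind of total cost for many candidate control sequences at once and uses the results to choose the next plan.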
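Hydrax considers sampling-based MPC algorithms that share the following generic structure: at each planning iteration, sample control sequences $U$ from a policy with parameters $\theta$, roll out the system under each sampled sequence in parallel, and update $\theta$ based on the resulting rollout data.
The meaning of the parameters $\theta$ depends on the algorithm.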
To implement a new planning algorithm, you'll need to inherit from hydrax.alg_base.SamplingBasedController and implement the four methods shown below:
class MyControlAlgorithm(SamplingBasedController):
    def init_params(self) -> Any:
        # Initialize the policy parameters (theta).
        ...
        return params

    def sample_controls(self, params: Any) -> Tuple[jax.Array, Any]:
        # Sample control sequences U from the policy. Return the samples
        # and the (updated) parameters.
        ...
        return controls, params

    def update_params(self, params: Any, rollouts: Trajectory) -> Any:
        # Update the policy parameters (theta) based on the trajectory data
        # (costs, controls, observations, etc) stored in the rollouts.
        ...
        return new_params

    def get_action(self, params: Any, t: float) -> Any:
        # Return the control action applied t seconds into the trajectory.
        ...
        return u
These four methods define a unique sampling-based MPC algorithm. Hydrax takes care of the rest, including parallelizing rollouts on GPU and collecting the rollout data in a Trajectory object.
Note: because of the way JAX handles randomness, we assume the PRNG key is stored as one of the parameters $\theta$. This is why sample_controls returns updated parameters along with the control samples.
For some examples, take a look at hydrax.algs.
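The four-method pattern can be sketched without Hydrax or JAX. The toy controller below is an illustrative assumption, not a subclass of SamplingBasedController: it uses lists instead of arrays, receives costs directly instead of a Trajectory, indexes actions by step rather than by time in seconds, and stores an integer seed in the parameters as a stand-in for the PRNG key. Keeping the current plan as one of the candidates is a design choice of this sketch.

```python
import random


def rollout_cost(x0, controls, dt=0.1):
    """Cost of a toy 1D integrator rollout (stand-in for a simulator)."""
    x, cost = x0, 0.0
    for u in controls:
        cost += x**2 + 0.1 * u**2
        x = x + dt * u
    return cost + 10.0 * x**2


class ToyPredictiveSampling:
    """Illustrative stand-in; not derived from SamplingBasedController."""

    def __init__(self, horizon, num_samples, noise_level, seed=0):
        self.horizon = horizon
        self.num_samples = num_samples
        self.noise_level = noise_level
        self.seed = seed

    def init_params(self):
        # theta: the mean control sequence, plus a seed standing in for the PRNG key.
        return {"mean": [0.0] * self.horizon, "seed": self.seed}

    def sample_controls(self, params):
        rng = random.Random(params["seed"])
        controls = [
            [m + self.noise_level * rng.gauss(0.0, 1.0) for m in params["mean"]]
            for _ in range(self.num_samples)
        ]
        controls.append(list(params["mean"]))  # keep the current plan as a candidate
        # Advance the seed so the next call draws fresh samples.
        return controls, {"mean": params["mean"], "seed": params["seed"] + 1}

    def update_params(self, params, controls, costs):
        # Predictive sampling: adopt the lowest-cost sample as the new mean.
        best = min(range(len(costs)), key=lambda i: costs[i])
        return {"mean": list(controls[best]), "seed": params["seed"]}

    def get_action(self, params, step):
        return params["mean"][step]


ctrl = ToyPredictiveSampling(horizon=5, num_samples=16, noise_level=0.5, seed=0)
params = ctrl.init_params()
history = [rollout_cost(1.0, params["mean"])]
for _ in range(20):
    controls, params = ctrl.sample_controls(params)
    costs = [rollout_cost(1.0, u) for u in controls]
    params = ctrl.update_params(params, controls, costs)
    history.append(rollout_cost(1.0, params["mean"]))
```

Because the random state travels inside the parameters, each call to `sample_controls` is a pure function of its input, which mirrors why the real interface returns updated parameters alongside the samples.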
One benefit of GPU-based simulation is the ability to roll out trajectories with different model parameters in parallel. Such domain randomization can improve robustness and help reduce the sim-to-real gap.
Hydrax provides tools to make online domain randomization easy. In particular, you can add domain randomization to any task by overriding the domain_randomize_model and domain_randomize_data methods of a given Task. For example:
class MyDomainRandomizedTask(Task):
    ...

    def domain_randomize_model(self, rng: jax.Array) -> Dict[str, jax.Array]:
        """Randomize the friction coefficients."""
        n_geoms = self.model.geom_friction.shape[0]
        multiplier = jax.random.uniform(rng, (n_geoms,), minval=0.5, maxval=2.0)
        new_frictions = self.model.geom_friction.at[:, 0].set(
            self.model.geom_friction[:, 0] * multiplier
        )
        return {"geom_friction": new_frictions}

    def domain_randomize_data(
        self, data: mjx.Data, rng: jax.Array
    ) -> Dict[str, jax.Array]:
        """Randomly shift the measured configurations."""
        shift = 0.005 * jax.random.normal(rng, (self.model.nq,))
        return {"qpos": data.qpos + shift}
These methods return a dictionary of randomized parameters, given a particular random seed (rng). Hydrax takes care of the details of applying these parameters to the model and data, and performing rollouts in parallel.
To use a domain randomized task, you'll need to tell the planner how many random models to use with the num_randomizations argument. For example,
task = MyDomainRandomizedTask(...)
ctrl = PredictiveSampling(
    task,
    num_samples=32,
    noise_level=0.1,
    num_randomizations=16,
)
sets up a predictive sampling controller that rolls out 32 control sequences across 16 domain randomized models.
The resulting Trajectory rollouts will have dimensions (num_randomizations, num_samples, num_time_steps, ...).
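The leading (num_randomizations, num_samples) layout can be pictured with a small plain-Python sketch. Everything here is a toy assumption: the randomized actuator gain, the 1D system, and the nested-list layout stand in for randomized MuJoCo models and batched arrays, and the sketch assumes the same sampled control sequences are evaluated in every randomized domain.

```python
import random


def rollout_cost(x0, controls, gain, dt=0.1):
    """Cost of one rollout of a toy 1D system with a randomized actuator gain."""
    x, cost = x0, 0.0
    for u in controls:
        cost += x**2 + 0.1 * u**2
        x = x + dt * gain * u
    return cost + 10.0 * x**2


num_randomizations, num_samples, horizon = 4, 3, 5
rng = random.Random(0)

# One randomized model parameter per domain (here: an actuator gain multiplier).
gains = [rng.uniform(0.5, 2.0) for _ in range(num_randomizations)]

# Sampled control sequences, shared across all randomized domains.
controls = [
    [rng.gauss(0.0, 1.0) for _ in range(horizon)] for _ in range(num_samples)
]

# costs[r][s] is the cost of sample s under randomized model r.
costs = [[rollout_cost(1.0, u, g) for u in controls] for g in gains]

# Average over the randomization axis to get one cost per sample.
avg_costs = [
    sum(costs[r][s] for r in range(num_randomizations)) / num_randomizations
    for s in range(num_samples)
]
```

Reducing over the first axis leaves one cost per control sample, which is the quantity the planner needs in order to rank candidates.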
With domain randomization, we need to somehow aggregate costs across the different domains. By default, we take the average cost over the randomizations, similar to domain randomization in reinforcement learning. Other strategies are available via the RiskStrategy interface.
For example, to plan using the worst-case maximum cost across randomizations:
from hydrax.risk import WorstCase
...
task = MyDomainRandomizedTask(...)
ctrl = PredictiveSampling(
    task,
    num_samples=32,
    noise_level=0.1,
    num_randomizations=16,
    risk_strategy=WorstCase(),
)
Available risk strategies:
Strategy | Description | Import |
---|---|---|
Average (default) | Take the expected cost across randomizations. | hydrax.risk.AverageCost |
Worst-case | Take the maximum cost across randomizations. | hydrax.risk.WorstCase |
Best-case | Take the minimum cost across randomizations. | hydrax.risk.BestCase |
Exponential | Take an exponentially weighted average, governed by a weighting parameter. | hydrax.risk.ExponentialWeightedAverage |
VaR | Use the Value at Risk (VaR). | hydrax.risk.ValueAtRisk |
CVaR | Use the Conditional Value at Risk (CVaR). | hydrax.risk.ConditionalValueAtRisk |
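To show how these strategies differ on the same data, the sketch below reduces a list of per-domain costs for a single control sample to one number. These are common textbook definitions written in plain Python, not Hydrax's implementations: the function names, the `gamma` and `alpha` parameters, and the empirical quantile convention are assumptions, and the actual hydrax.risk classes may define them differently.

```python
import math


def average_cost(costs):
    """Mean cost across randomizations."""
    return sum(costs) / len(costs)


def worst_case(costs):
    """Maximum cost across randomizations."""
    return max(costs)


def best_case(costs):
    """Minimum cost across randomizations."""
    return min(costs)


def exponential_weighted_average(costs, gamma):
    """Softmax-weighted average: gamma > 0 emphasizes high costs, gamma < 0 low costs."""
    shift = max(gamma * c for c in costs)  # for numerical stability
    weights = [math.exp(gamma * c - shift) for c in costs]
    return sum(w * c for w, c in zip(weights, costs)) / sum(weights)


def value_at_risk(costs, alpha):
    """Empirical alpha-quantile: smallest cost with a fraction >= alpha at or below it."""
    ordered = sorted(costs)
    index = max(0, math.ceil(alpha * len(ordered)) - 1)
    return ordered[index]


def conditional_value_at_risk(costs, alpha):
    """Mean of the costs at or above the value at risk."""
    var = value_at_risk(costs, alpha)
    tail = [c for c in costs if c >= var]
    return sum(tail) / len(tail)


# Costs of one control sample under four randomized models.
costs = [1.0, 2.0, 3.0, 10.0]

avg = average_cost(costs)
worst = worst_case(costs)
best = best_case(costs)
risk_averse = exponential_weighted_average(costs, gamma=1.0)
risk_seeking = exponential_weighted_average(costs, gamma=-1.0)
var = value_at_risk(costs, alpha=0.75)
cvar = conditional_value_at_risk(costs, alpha=0.75)
```

With one outlier domain at cost 10, the average sits well above most domains, while the risk-averse choices move toward the outlier and the risk-seeking ones largely ignore it.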
@misc{kurtz2024hydrax,
title={Hydrax: Sampling-based model predictive control on GPU with JAX and MuJoCo MJX},
author={Kurtz, Vince},
year={2024},
note={https://github.com/vincekurtz/hydrax}
}