diff --git a/waymax/dynamics/abstract_dynamics.py b/waymax/dynamics/abstract_dynamics.py index 7ef57d9..08f1450 100644 --- a/waymax/dynamics/abstract_dynamics.py +++ b/waymax/dynamics/abstract_dynamics.py @@ -140,6 +140,7 @@ def apply_trajectory_update_to_state( is_controlled: jax.Array, timestep: int, allow_object_injection: bool = False, + use_fallback: bool = False, ) -> datatypes.Trajectory: """Applies a TrajectoryUpdate to the sim trajectory at the next timestep. @@ -150,7 +151,7 @@ def apply_trajectory_update_to_state( For objects not in is_controlled, reference_trajectory is used. For objects in is_controlled, but not valid in trajectory_update, fall back to - constant speed behaviour. + constant speed behaviour if the use_fallback flag is on. Args: trajectory_update: Updated trajectory fields for all objects after the @@ -168,6 +169,8 @@ def apply_trajectory_update_to_state( allow_object_injection: Whether to allow new objects to enter the scene. If this is set to False, all objects that are not valid at the current timestep will not be valid at the next timestep and visa versa. + use_fallback: Whether to fall back to constant speed if a controlled agent + is given an invalid action. Otherwise, the agent will be invalidated. Returns: Updated trajectory given update from a dynamics model at `timestep` + 1. @@ -198,19 +201,25 @@ def apply_trajectory_update_to_state( # Some fields such as the length, width and height of objects are set to be # the same for every timestep during data loading and so we don't update these # from the current trajectory. - # TODO: Update z using the (x, y) coordinates of the vehicle. + # Update z using the (x, y) coordinates of the vehicle. replacement_dict = {} for field in CONTROLLABLE_FIELDS: - # Use fallback trajectory if user doesn't not provide valid action. - new_value = jnp.where( - trajectory_update.valid, - trajectory_update[field], - fallback_trajectory[field], - ) - # Only update for is_controlled objects from users. 
- replacement_dict[field] = jnp.where( - is_controlled, new_value, default_next_traj[field] - ) + if use_fallback: + # Use fallback trajectory if user does not provide a valid action. + new_value = jnp.where( + trajectory_update.valid, + trajectory_update[field], + fallback_trajectory[field], + ) + # Only update for is_controlled objects from users. + replacement_dict[field] = jnp.where( + is_controlled, new_value, default_next_traj[field] + ) + else: + new_value = jnp.where( + is_controlled, trajectory_update[field], default_next_traj[field] + ) + replacement_dict[field] = new_value exist_and_controlled = is_controlled & current_traj.valid # For exist_and_controlled objects, valid flags should remain the same as