Skip to content

Commit

Permalink
Undo fallback in abstract_dynamics.py
Browse files Browse the repository at this point in the history
  • Loading branch information
justinjfu authored Mar 14, 2024
1 parent cc5d1af commit 715a3a2
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions waymax/dynamics/abstract_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def apply_trajectory_update_to_state(
is_controlled: jax.Array,
timestep: int,
allow_object_injection: bool = False,
use_fallback: bool = False,
) -> datatypes.Trajectory:
"""Applies a TrajectoryUpdate to the sim trajectory at the next timestep.
Expand All @@ -150,7 +151,7 @@ def apply_trajectory_update_to_state(
For objects not in is_controlled, reference_trajectory is used.
For objects in is_controlled, but not valid in trajectory_update, fall back to
constant speed behaviour.
constant speed behaviour if the use_fallback flag is on.
Args:
trajectory_update: Updated trajectory fields for all objects after the
Expand All @@ -168,6 +169,8 @@ def apply_trajectory_update_to_state(
allow_object_injection: Whether to allow new objects to enter the scene. If
this is set to False, all objects that are not valid at the current
timestep will not be valid at the next timestep and visa versa.
use_fallback: Whether to fall back to constant speed if a controlled agent
is given an invalid action. Otherwise, the agent will be invalidated.
Returns:
Updated trajectory given update from a dynamics model at `timestep` + 1.
Expand Down Expand Up @@ -198,19 +201,25 @@ def apply_trajectory_update_to_state(
# Some fields such as the length, width and height of objects are set to be
# the same for every timestep during data loading and so we don't update these
# from the current trajectory.
# TODO: Update z using the (x, y) coordinates of the vehicle.
# Update z using the (x, y) coordinates of the vehicle.
replacement_dict = {}
for field in CONTROLLABLE_FIELDS:
# Use fallback trajectory if user doesn't not provide valid action.
new_value = jnp.where(
trajectory_update.valid,
trajectory_update[field],
fallback_trajectory[field],
)
# Only update for is_controlled objects from users.
replacement_dict[field] = jnp.where(
is_controlled, new_value, default_next_traj[field]
)
if use_fallback:
# Use fallback trajectory if user doesn't not provide valid action.
new_value = jnp.where(
trajectory_update.valid,
trajectory_update[field],
fallback_trajectory[field],
)
# Only update for is_controlled objects from users.
replacement_dict[field] = jnp.where(
is_controlled, new_value, default_next_traj[field]
)
else:
new_value = jnp.where(
is_controlled, trajectory_update[field], default_next_traj[field]
)
replacement_dict[field] = new_value

exist_and_controlled = is_controlled & current_traj.valid
# For exist_and_controlled objects, valid flags should remain the same as
Expand Down

0 comments on commit 715a3a2

Please sign in to comment.