-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 693736b
Showing
179 changed files
with
66,784 additions
and
0 deletions.
There are no files selected for viewing
177 changes: 177 additions & 0 deletions
177
_sources/autoapi/waymax/agents/actor_core/index.rst.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
:py:mod:`waymax.agents.actor_core` | ||
================================== | ||
|
||
.. py:module:: waymax.agents.actor_core | ||
.. autoapi-nested-parse:: | ||
|
||
Abstract definition of a Waymax actor for use at inference-time. | ||
|
||
|
||
|
||
Module Contents | ||
--------------- | ||
|
||
Classes | ||
~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.actor_core.WaymaxActorOutput | ||
waymax.agents.actor_core.WaymaxActorCore | ||
|
||
|
||
|
||
Functions | ||
~~~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.actor_core.register_actor_core | ||
waymax.agents.actor_core.actor_core_factory | ||
waymax.agents.actor_core.merge_actions | ||
|
||
|
||
|
||
Attributes | ||
~~~~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.actor_core.ActorState | ||
waymax.agents.actor_core.Params | ||
waymax.agents.actor_core._ActorCore | ||
|
||
|
||
.. py:data:: ActorState | ||
.. py:data:: Params | ||
.. py:class:: WaymaxActorOutput | ||
Output of the Waymax actor including an action and its internal state. | ||
|
||
.. attribute:: actor_state | ||
|
||
Internal state for whatever the agent needs to keep as its | ||
state. This can be recurrent embeddings or accounting information. | ||
|
||
.. attribute:: action | ||
|
||
Action of shape (..., num_objects) predicted by the Waymax actor at | ||
the most recent simulation step given the inputs in the `select_action` | ||
function of `WaymaxActorCore`. | ||
|
||
.. attribute:: is_controlled | ||
|
||
A binary indicator of shape (..., num_objects) representing | ||
which objects are controlled by the actor. | ||
|
||
.. py:attribute:: actor_state | ||
:type: ActorState | ||
|
||
|
||
|
||
.. py:attribute:: action | ||
:type: waymax.datatypes.Action | ||
|
||
|
||
|
||
.. py:attribute:: is_controlled | ||
:type: jax.Array | ||
|
||
|
||
|
||
.. py:method:: validate() | ||
Validates shapes. | ||
|
||
|
||
|
||
.. py:class:: WaymaxActorCore | ||
Bases: :py:obj:`abc.ABC` | ||
|
||
Interface that defines actor functionality for inference. | ||
|
||
.. py:property:: name | ||
:type: str | ||
:abstractmethod: | ||
|
||
Name of the agent used for inspection and logging. | ||
|
||
.. py:method:: init(rng: jax.Array, state: waymax.datatypes.SimulatorState) -> ActorState | ||
:abstractmethod: | ||
|
||
Initializes the actor's internal state. | ||
|
||
ActorState is a generic type which can contain anything that the agent | ||
needs to pass through to the next call, e.g. for recurrent state or | ||
batch normalization. The `init` function takes a random key to help | ||
randomize initialization and the initial timestep. | ||
|
||
:param rng: A random key. | ||
:param state: The initial simulator state. | ||
|
||
:returns: The actor's initial state. | ||
|
||
|
||
.. py:method:: select_action(params: Params, state: waymax.datatypes.SimulatorState, actor_state: ActorState, rng: jax.Array) -> WaymaxActorOutput | ||
:abstractmethod: | ||
|
||
Selects an action given the current simulator state. | ||
|
||
:param params: Actor parameters, e.g. neural network weights. | ||
:param state: The current simulator state. | ||
:param actor_state: The actor state, e.g. recurrent state or batch normalization. | ||
:param rng: A random key. | ||
|
||
:returns: An actor output containing the next action and actor state. | ||
|
||
|
||
|
||
.. py:data:: _ActorCore | ||
.. py:function:: register_actor_core(actor_core_cls: type[_ActorCore]) -> type[_ActorCore] | ||
Registers an ActorCore class as a PyTree node. | ||
|
||
|
||
.. py:function:: actor_core_factory(init: Callable[[jax.Array, waymax.datatypes.SimulatorState], ActorState], select_action: Callable[[Params, waymax.datatypes.SimulatorState, ActorState, jax.Array], WaymaxActorOutput], name: str = 'WaymaxActorCore') -> WaymaxActorCore | ||
Creates a WaymaxActorCore from pure functions. | ||
|
||
:param init: A function that initializes the actor's internal state. This is a | ||
generic type which can contain anything that the agent needs to pass | ||
through to the next call. The `init` function takes a random key to help | ||
randomize initialization and the initial timestep. It should return its | ||
specific internal state. | ||
:param select_action: A function that selects an action given the current simulator | ||
state of the environment, the previous actor state and an optional random | ||
key. Returns the action and the updated internal actor state. | ||
:param name: Name of the agent used for inspection and logging. | ||
|
||
:returns: An actor core instance defined by init and select_action. | ||
|
||
|
||
.. py:function:: merge_actions(actor_outputs: Sequence[WaymaxActorOutput]) -> waymax.datatypes.Action | ||
Combines multiple actor_outputs into one action instance. | ||
|
||
:param actor_outputs: A sequence of WaymaxActorOutput to be combined, each | ||
corresponds to a different actor. Note different actor should not be | ||
controlling the same object (i.e. is_controlled flags from different | ||
actors should be disjoint). Note all actors must use the same dynamics | ||
model. | ||
|
||
:returns: An `Action` consists of information from all actor outputs. | ||
|
||
|
34 changes: 34 additions & 0 deletions
34
_sources/autoapi/waymax/agents/agent_builder/index.rst.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
:py:mod:`waymax.agents.agent_builder` | ||
===================================== | ||
|
||
.. py:module:: waymax.agents.agent_builder | ||
.. autoapi-nested-parse:: | ||
|
||
Waymax sim agent builder functions. | ||
|
||
|
||
|
||
Module Contents | ||
--------------- | ||
|
||
|
||
Functions | ||
~~~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.agent_builder.create_sim_agents_from_config | ||
|
||
|
||
|
||
.. py:function:: create_sim_agents_from_config(config: waymax.config.SimAgentConfig) -> waymax.agents.actor_core.WaymaxActorCore | ||
Constructs sim agent WaymaxActorCore objects from a config. | ||
|
||
:param config: Waymax sim agent config specifying agent type and controlled | ||
objects' type. | ||
|
||
:returns: Constructed sim agents. | ||
|
||
|
68 changes: 68 additions & 0 deletions
68
_sources/autoapi/waymax/agents/constant_speed/index.rst.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
:py:mod:`waymax.agents.constant_speed` | ||
====================================== | ||
|
||
.. py:module:: waymax.agents.constant_speed | ||
.. autoapi-nested-parse:: | ||
|
||
Constant speed agents. | ||
|
||
|
||
|
||
Module Contents | ||
--------------- | ||
|
||
Classes | ||
~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.constant_speed.ConstantSpeedPolicy | ||
|
||
|
||
|
||
Functions | ||
~~~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.constant_speed.create_constant_speed_actor | ||
|
||
|
||
|
||
.. py:function:: create_constant_speed_actor(dynamics_model: waymax.dynamics.DynamicsModel, is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array], speed: Optional[float] = None) -> waymax.agents.actor_core.WaymaxActorCore | ||
Creates an actor with constant speed without changing objects' heading. | ||
|
||
Note the difference against ConstantSpeedPolicy is that an actor requires | ||
input of a dynamics model, while a policy does not (it assumes to use | ||
StateDynamics). | ||
|
||
:param dynamics_model: The dynamics model the actor is using that defines the | ||
action output by the actor. | ||
:param is_controlled_func: Defines which objects are controlled by this actor. | ||
:param speed: Speed of the actor, if None, speed from previous step is used. | ||
|
||
:returns: An statelss actor that drives the controlled objects with constant speed. | ||
|
||
|
||
.. py:class:: ConstantSpeedPolicy(speed: float = 0.0) | ||
Bases: :py:obj:`waymax.agents.waypoint_following_agent.WaypointFollowingPolicy` | ||
|
||
A policy that maintains a constant speed for all sim agents. | ||
|
||
.. py:method:: update_speed(state: waymax.datatypes.SimulatorState, dt: float = 0.1) -> tuple[jax.Array, jax.Array] | ||
Sets the speed for each agent in the current sim step to a constant. | ||
|
||
:param state: The simulator state of shape (...). | ||
:param dt: Delta between timesteps of the simulator state. | ||
|
||
:returns: A (..., num_objects) float array of constant speeds. | ||
valids: A (..., num_objects) bool array of valids. | ||
:rtype: speeds | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
:py:mod:`waymax.agents.expert` | ||
============================== | ||
|
||
.. py:module:: waymax.agents.expert | ||
.. autoapi-nested-parse:: | ||
|
||
Expert actor which returns an action inferred from logged data. | ||
|
||
|
||
|
||
Module Contents | ||
--------------- | ||
|
||
|
||
Functions | ||
~~~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.expert.infer_expert_action | ||
waymax.agents.expert.create_expert_actor | ||
|
||
|
||
|
||
Attributes | ||
~~~~~~~~~~ | ||
|
||
.. autoapisummary:: | ||
|
||
waymax.agents.expert._EXPERT_NAME | ||
waymax.agents.expert._IS_SDC_FUNC | ||
|
||
|
||
.. py:data:: _EXPERT_NAME | ||
:value: 'expert' | ||
|
||
|
||
|
||
.. py:data:: _IS_SDC_FUNC | ||
.. py:function:: infer_expert_action(simulator_state: waymax.datatypes.SimulatorState, dynamics_model: waymax.dynamics.DynamicsModel) -> waymax.datatypes.Action | ||
Infers an action from sim_traj[timestep] to log_traj[timestep + 1]. | ||
|
||
:param simulator_state: State of the simulator at the current timestep. Will use | ||
the `sim_trajectory` and `log_trajectory` fields to calculate an action. | ||
:param dynamics_model: Dynamics model whose `inverse` function will be used to | ||
infer the expert action given the logged states. | ||
|
||
:returns: | ||
|
||
Action that will take the agent from sim_traj[timestep] to | ||
log_traj[timestep + 1]. | ||
|
||
|
||
.. py:function:: create_expert_actor(dynamics_model: waymax.dynamics.DynamicsModel, is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array] = _IS_SDC_FUNC) -> waymax.agents.actor_core.WaymaxActorCore | ||
Creates an expert agent using the WaymaxActorCore interface. | ||
|
||
This agent infers an action from the `expert` by inferring an action using | ||
the logged data. It does this by calling the `inverse` function on the passed | ||
in `dynamics` parameter. It will return an action in the format returned by | ||
the `dynamics` parameter. | ||
|
||
:param dynamics_model: Dynamics model whose `inverse` function will be used to | ||
infer the expert action given the logged states. | ||
:param is_controlled_func: A function that maps state to a controlled objects mask | ||
of shape (..., num_objects). | ||
|
||
:returns: A Stateless Waymax actor which returns an `expert` action for all controlled | ||
objects (defined by is_controlled_func) by inferring the best-fit action | ||
given the logged state. | ||
|
||
|
Oops, something went wrong.