diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index 110c5033e..1c7041c58 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,4 +1,5 @@ # --- Anakin config --- +architecture_name: anakin # --- Training --- num_envs: 16 # Number of vectorised environments per device. diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml new file mode 100644 index 000000000..278b0592d --- /dev/null +++ b/mava/configs/arch/sebulba.yaml @@ -0,0 +1,26 @@ +# --- Sebulba config --- +architecture_name: sebulba + +# --- Training --- +num_envs: 32 # number of environments per thread. + +# --- Evaluation --- +evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select + # an action which corresponds to the greatest logit. If false, the policy will sample + # from the logits. +num_eval_episodes: 200 # Number of episodes to evaluate per evaluation. +num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. +num_absolute_metric_eval_episodes: 32 # Number of episodes to evaluate the absolute metric (the final evaluation). +absolute_metric: True # Whether the absolute metric should be computed. For more details +# on the absolute metric please see: https://arxiv.org/abs/2209.10485 + +# --- Sebulba devices config --- +n_threads_per_executor: 2 # num of different threads/env batches per actor +actor_device_ids: [0] # ids of actor devices +learner_device_ids: [0] # ids of learner devices +n_learner_accumulate: 1 # Number of envoirnments to accumulate before updating the parameters. This determines the num_envs for learning updates which equals (num_envs * n_learner_accumulate) / len(learner_device_ids). +rollout_queue_size : 5 +# The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training. +# Too large of a value with too many actors will lead to all of the updates getting wasted in old episodes +# Too small of a value and the utility of having multiple actors is lost. +# A value of 1 with a single actor leads to almost strictly on-policy training. diff --git a/mava/configs/default/ff_ippo_sebulba.yaml b/mava/configs/default/ff_ippo_sebulba.yaml new file mode 100644 index 000000000..cc2b4acae --- /dev/null +++ b/mava/configs/default/ff_ippo_sebulba.yaml @@ -0,0 +1,11 @@ +defaults: + - logger: logger + - arch: sebulba + - system: ppo/ff_ippo + - network: mlp # [mlp, continuous_mlp, cnn] + - env: smac_gym # [rware_gym, lbf_gym, smac_gym] + - _self_ + +hydra: + searchpath: + - file://mava/configs diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml new file mode 100644 index 000000000..f001e0913 --- /dev/null +++ b/mava/configs/env/lbf_gym.yaml @@ -0,0 +1,25 @@ +# ---Environment Configs--- +defaults: + - _self_ + +env_name: LevelBasedForaging # Used for logging purposes. +scenario: + name: lbforaging + task_name: Foraging-8x8-2p-1f-v3 + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. 
+log_win_rate: False + +# Weather or not to sum the returned rewards over all of the agents. +use_shared_rewards: True + +kwargs: + max_episode_steps: 100 diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml new file mode 100644 index 000000000..facf7f8d7 --- /dev/null +++ b/mava/configs/env/rware_gym.yaml @@ -0,0 +1,25 @@ +# ---Environment Configs--- +defaults: + - _self_ + +env_name: RobotWarehouse # Used for logging purposes. +scenario: + name: rware + task_name: rware-tiny-2ag-v2 # [rware-tiny-2ag-v2, rware-tiny-4ag-v2, rware-tiny-4ag-easy-v2, rware-small-4ag-v2] + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: False + +# Weather or not to sum the returned rewards over all of the agents. +use_shared_rewards: True + +kwargs: + max_episode_steps: 500 diff --git a/mava/configs/env/smac_gym.yaml b/mava/configs/env/smac_gym.yaml new file mode 100644 index 000000000..1f2f48c89 --- /dev/null +++ b/mava/configs/env/smac_gym.yaml @@ -0,0 +1,25 @@ +# ---Environment Configs--- +defaults: + - _self_ + +env_name: Starcraft # Used for logging purposes. +scenario: + name: smaclite + task_name: smaclite/2s3z-v0 # smaclite/ + ['10m_vs_11m-v0', '27m_vs_30m-v0', '3s5z_vs_3s6z-v0', '2s3z-v0', '3s5z-v0', 'MMM-v0', 'MMM2-v0', '2c_vs_64zg-v0', 'bane_vs_bane-v0', 'corridor-v0', '2s_vs_1sc-v0', '3s_vs_5z-v0'] + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: False + +# Weather or not to sum the returned rewards over all of the agents. +use_shared_rewards: True + +kwargs: + max_episode_steps: 500 diff --git a/mava/evaluator.py b/mava/evaluator.py index ba8132ae8..8e4dd5dee 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp +import numpy as np from chex import Array, PRNGKey from flax.core.frozen_dict import FrozenDict from jax import tree @@ -207,3 +208,115 @@ def eval_act_fn( return action.squeeze(0), {_hidden_state: hidden_state} return eval_act_fn + + +def get_sebulba_eval_fn( + env_maker: Callable, + act_fn: EvalActFn, + config: DictConfig, + np_rng: np.random.Generator, + absolute_metric: bool, +) -> Tuple[EvalFn, Any]: + """Creates a function that can be used to evaluate agents on a given environment. + + Args: + ---- + env: an environment that conforms to the mava environment spec. + act_fn: a function that takes in params, timestep, key and optionally a state + and returns actions and optionally a state (see `EvalActFn`). + config: the system config. + absolute_metric: whether or not this evaluator calculates the absolute_metric. 
+    This determines how many evaluation episodes it does.
+    """
+    n_devices = jax.device_count()
+    eval_episodes = (
+        config.arch.num_absolute_metric_eval_episodes
+        if absolute_metric
+        else config.arch.num_eval_episodes
+    )
+
+    n_parallel_envs = min(eval_episodes, config.arch.num_envs)
+    episode_loops = math.ceil(eval_episodes / n_parallel_envs)
+    env = env_maker(config, n_parallel_envs)
+
+    act_fn = jax.jit(
+        act_fn, device=jax.devices("cpu")[0]
+    )  # cpu so that we don't block actors/learners
+
+    # Warn if the number of eval episodes is not divisible by the number of parallel envs.
+    if eval_episodes % n_parallel_envs != 0:
+        warnings.warn(
+            f"Number of evaluation episodes ({eval_episodes}) is not divisible by the number "
+            f"of parallel evaluation environments ({n_parallel_envs}). Some extra evaluations "
+            f"will be executed. New number of evaluation episodes = "
+            f"{episode_loops * n_parallel_envs}",
+            stacklevel=2,
+        )
+
+    def eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
+        """Evaluates the given params on an environment and returns relevant metrics.
+
+        Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length,
+        also win rate for environments that support it.
+
+        Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode.
+        """
+
+        def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
+            """Simulates `n_parallel_envs` episodes."""
+
+            seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist()
+            ts = env.reset(seed=seeds)
+
+            timesteps = [ts]
+
+            actor_state = init_act_state
+            finished_eps = ts.last()
+
+            while not finished_eps.all():
+                key, act_key = jax.random.split(key)
+                action, actor_state = act_fn(params, ts, act_key, actor_state)
+                cpu_action = jax.device_get(action).swapaxes(0, 1)
+                ts = env.step(cpu_action)
+                timesteps.append(ts)
+
+                finished_eps = np.logical_or(finished_eps, ts.last())
+
+            timesteps = jax.tree.map(lambda *x: np.stack(x), *timesteps)
+
+            metrics = timesteps.extras
+            if config.env.log_win_rate:
+                metrics["won_episode"] = timesteps.extras["won_episode"]
+
+            # Find the first instance of done to get the metrics at that timestep; we don't
+            # care about subsequent steps because we only need the results from the first episode.
+            done_idx = np.argmax(timesteps.last(), axis=0)
+            metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
+            del metrics["is_terminal_step"]  # unneeded for logging
+
+            return key, metrics
+
+        # This loop is important because we don't want too many parallel envs.
+        # So in evaluation we have num_envs parallel envs and loop enough times
+        # so that we do at least `eval_episodes` number of episodes.
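+        # Illustrative example (using the default sebulba arch config, num_eval_episodes=200
+        # and num_envs=32): n_parallel_envs = min(200, 32) = 32 and
+        # episode_loops = ceil(200 / 32) = 7, so 7 * 32 = 224 episodes are evaluated in total.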
+ metrics = [] + for _ in range(episode_loops): + key, metric = _episode(key) + metrics.append(metric) + + metrics: Metrics = jax.tree_map( + lambda *x: np.array(x).reshape(-1), *metrics + ) # flatten metrics + return metrics + + def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics: + """Wrapper around eval function to time it and add in steps per second metric.""" + start_time = time.time() + + metrics = eval_fn(params, key, init_act_state) + + end_time = time.time() + total_timesteps = jnp.sum(metrics["episode_length"]) + metrics["steps_per_second"] = total_timesteps / (end_time - start_time) + return metrics + + return timed_eval_fn, env diff --git a/mava/systems/__init__.py b/mava/systems/__init__.py deleted file mode 100644 index 21db9ec1c..000000000 --- a/mava/systems/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 9fddad868..da3ff1ebd 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -36,13 +36,13 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 8728e3c16..1e335e7f9 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -35,9 +35,9 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index e185e9ea1..f648e12ea 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -49,9 +49,9 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, 
MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index a28ef36eb..cd422a566 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -49,9 +49,9 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py new file mode 100644 index 000000000..1869ba092 --- /dev/null +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -0,0 +1,764 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import queue +import threading +import warnings +from collections import defaultdict +from queue import Queue +from typing import Any, Dict, List, Sequence, Tuple + +import chex +import hydra +import jax +import jax.debug +import jax.numpy as jnp +import numpy as np +import optax +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from jax import tree +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding, PartitionSpec, Sharding +from numpy.typing import NDArray +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from mava.evaluator import get_sebulba_eval_fn as get_eval_fn +from mava.evaluator import make_ff_eval_act_fn +from mava.networks import FeedForwardActor as Actor +from mava.networks import FeedForwardValueNet as Critic +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.types import ( + ActorApply, + CriticApply, + ExperimentOutput, + Observation, + SebulbaLearnerFn, +) +from mava.utils import make_env as environments +from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_sebulba_config, check_total_timesteps +from mava.utils.jax_utils import merge_leading_dims, switch_leading_axes +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime +from mava.utils.training import make_learning_rate +from mava.wrappers.episode_metrics import get_final_step_metrics +from mava.wrappers.gym import GymToJumanji + + +def rollout( + key: chex.PRNGKey, + env: GymToJumanji, + config: DictConfig, + rollout_queue: Pipeline, + params_source: ParamsSource, + apply_fns: Tuple[ActorApply, CriticApply], + actor_device: int, + 
seeds: List[int], + thread_lifetime: ThreadLifetime, +) -> None: + """Runs rollouts to collect trajectories from the environment. + + Args: + key (chex.PRNGKey): The PRNGkey. + config (DictConfig): Configuration settings for the environment and rollout. + rollout_queue (Pipeline): Queue for sending collected rollouts to the learner. + params_source (ParamsSource): Source for fetching the latest network parameters + from the learner. + apply_fns (Tuple): Functions for running the actor and critic networks. + actor_device (Device): Actor device to use for rollout. + seeds (List[int]): Seeds for initializing the environment. + thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. + """ + name = threading.current_thread().name + print(f"{Fore.BLUE}{Style.BRIGHT}Thread {name} started{Style.RESET_ALL}") + actor_apply_fn, critic_apply_fn = apply_fns + num_agents, num_envs = config.system.num_agents, config.arch.num_envs + move_to_device = lambda x: jax.device_put(x, device=actor_device) + + @jax.jit + def act_fn( + params: Params, + observation: Observation, + key: chex.PRNGKey, + ) -> Tuple: + """Get action and value.""" + actor_policy = actor_apply_fn(params.actor_params, observation) + action = actor_policy.sample(seed=key) + log_prob = actor_policy.log_prob(action) + # It may be faster to calculate the values in the learner as + # then we won't need to pass critic params to actors. + value = critic_apply_fn(params.critic_params, observation).squeeze() + return action, log_prob, value + + timestep = env.reset(seed=seeds) + dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + + # Loop till the desired num_updates is reached. + while not thread_lifetime.should_stop(): + # Rollout + traj: List[PPOTransition] = [] + actor_timings: Dict[str, List[float]] = defaultdict(list) + with RecordTimeTo(actor_timings["rollout_time"]): + for _ in range(config.system.rollout_length): + with RecordTimeTo(actor_timings["get_params_time"]): + params = params_source.get() # Get the latest parameters from the learner + + obs_tpu = tree.map(move_to_device, timestep.observation) + + # Get action and value + with RecordTimeTo(actor_timings["compute_action_time"]): + key, act_key = jax.random.split(key) + action, log_prob, value = act_fn(params, obs_tpu, act_key) + cpu_action = jax.device_get(action) + + # Step environment + with RecordTimeTo(actor_timings["env_step_time"]): + timestep = env.step(cpu_action.swapaxes(0, 1)) + + dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + + # Append data to storage + traj.append( + PPOTransition( + dones, + action, + value, + timestep.reward, + log_prob, + obs_tpu, + timestep.extras, + ) + ) + + # send trajectories to learner + with RecordTimeTo(actor_timings["rollout_put_time"]): + try: + rollout_queue.put(traj, timestep, actor_timings) + except queue.Full: + err = "Waited too long to add to the rollout queue, killing the actor thread" + warnings.warn(err, stacklevel=2) + break + + env.close() + + +def get_learner_step_fn( + apply_fns: Tuple[ActorApply, CriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> SebulbaLearnerFn[LearnerState, PPOTransition]: + """Get the learner function.""" + + num_envs = config.arch.num_envs + num_learner_envs = int(num_envs // len(config.arch.learner_device_ids)) + + # Get apply and update functions for actor and critic networks. 
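+    # Illustrative note: with the default sebulba arch config (num_envs=32 and a single
+    # learner device), num_learner_envs = 32 // 1 = 32, i.e. each learner device receives one
+    # full actor rollout per pipeline item (before any n_learner_accumulate accumulation).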
+ actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step( + learner_state: LearnerState, + traj_batch: PPOTransition, + ) -> Tuple[LearnerState, Tuple]: + """A single update of the network. + + This function calculates advantages and targets based on the trajectories + from the actor and updates the actor and critic networks based on the losses. + + Args: + learner_state (LearnerState): contains all the items needed for learning. + traj_batch (PPOTransition): the batch of data to learn with. + """ + + def _calculate_gae( + traj_batch: PPOTransition, last_val: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + """Calculate the GAE.""" + + gamma, gae_lambda = config.system.gamma, config.system.gae_lambda + + def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: + """Calculate the GAE for a single transition.""" + gae, next_value = gae_and_next_value + done, value, reward = transition.done, transition.value, transition.reward + + delta = reward + gamma * next_value * (1 - done) - value + gae = delta + gamma * gae_lambda * (1 - done) * gae + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + # Calculate advantage + params, opt_states, key, _, final_timestep = learner_state + last_val = critic_apply_fn(params.critic_params, final_timestep.observation) + advantages, targets = _calculate_gae(traj_batch, last_val) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # Unpack train state and batch info + params, opt_states, key = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + traj_batch: PPOTransition, + gae: chex.Array, + key: chex.PRNGKey, + ) -> Tuple: + """Calculate the actor loss.""" + # Rerun network + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # Calculate actor loss + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config.system.clip_eps, + 1.0 + config.system.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + # The seed will be used in the TanhTransformedDistribution: + entropy = actor_policy.entropy(seed=key).mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + return total_loss_actor, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: FrozenDict, traj_batch: PPOTransition, targets: chex.Array + ) -> Tuple: + """Calculate the critic loss.""" + # Rerun network + value = critic_apply_fn(critic_params, traj_batch.obs) + + # Calculate value loss + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + critic_total_loss = config.system.vf_coef * value_loss + return critic_total_loss, (value_loss) + + # Calculate actor loss + key, entropy_key = 
jax.random.split(key) + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = actor_grad_fn( + params.actor_params, traj_batch, advantages, entropy_key + ) + + # Calculate critic loss + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # pmean over learner devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), + axis_name="learner_devices", + ) + + # pmean over learner devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="learner_devices" + ) + + # Update actor params and optimiser state + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # Update critic params and optimiser state + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # Pack new params and optimiser state + new_params = Params(actor_new_params, critic_new_params) + new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) + # Pack loss info + actor_total_loss, (actor_loss, entropy) = actor_loss_info + critic_total_loss, (value_loss) = critic_loss_info + total_loss = critic_total_loss + actor_total_loss + loss_info = { + "total_loss": total_loss, + "value_loss": value_loss, + "actor_loss": actor_loss, + "entropy": entropy, + } + return (new_params, new_opt_state, key), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key, shuffle_key, entropy_key = jax.random.split(key, 3) + # Shuffle minibatches + batch_size = config.system.rollout_length * num_learner_envs + permutation = jax.random.permutation(shuffle_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = tree.map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = tree.map(lambda x: jnp.take(x, permutation, axis=0), batch) + minibatches = tree.map( + lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), + shuffled_batch, + ) + # Update minibatches + (params, opt_states, _), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states, entropy_key), minibatches + ) + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + # Update epochs + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn( + learner_state: LearnerState, traj_batch: PPOTransition + ) -> ExperimentOutput[LearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. 
+ + Args: + learner_state (NamedTuple): + - params (Params): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The last timestep of the rollout. + """ + # This function is shard mapped on the batch axis, but `_update_step` needs + # the first axis to be time + traj_batch = tree.map(switch_leading_axes, traj_batch) + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) + + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_thread( + learn_fn: SebulbaLearnerFn[LearnerState, PPOTransition], + learner_state: LearnerState, + config: DictConfig, + eval_queue: Queue, + pipeline: Pipeline, + params_sources: Sequence[ParamsSource], +) -> None: + for _ in range(config.arch.num_evaluation): + # Create the lists to store metrics and timings for this learning iteration. + metrics: List[Tuple[Dict, Dict]] = [] + rollout_times: List[Dict] = [] + learn_times: Dict[str, List[float]] = defaultdict(list) + + with RecordTimeTo(learn_times["learner_time_per_eval"]): + for _ in range(config.system.num_updates_per_eval): + # Accumulate the batches, timesteps, and rollout times + accumulated_traj_batches = [] + accumulated_timesteps = [] + + # Possibly get many rollouts for 1 learn step - allows learning with large batches + for _ in range(config.arch.n_learner_accumulate): + # Get the trajectory batch from the pipeline + # This is blocking so it will wait until the pipeline has data. + with RecordTimeTo(learn_times["rollout_get_time"]): + traj_batch, timestep, rollout_time = pipeline.get(block=True) + + # Store the retrieved data + accumulated_traj_batches.append(traj_batch) + accumulated_timesteps.append(timestep) + rollout_times.append(rollout_time) + + # Concatenate the accumulated timesteps and trajectory batches on the num_envs axis + traj_batches = tree.map(lambda *x: jnp.concat(x, axis=0), *accumulated_traj_batches) + timesteps = tree.map(lambda *x: jnp.concat(x, axis=0), *accumulated_timesteps) + + # Replace the timestep in the learner state with the latest timestep + # This means the learner has access to the entire trajectory as well as + # an additional timestep which it can use to bootstrap. 
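+                # Illustrative example: with num_envs=32 and n_learner_accumulate=2, two
+                # rollouts are concatenated so a single update covers 64 environments' data.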
+ learner_state = learner_state._replace(timestep=timesteps) + # Update the networks + with RecordTimeTo(learn_times["learning_time"]): + learner_state, ep_metrics, train_metrics = learn_fn(learner_state, traj_batches) + + metrics.append((ep_metrics, train_metrics)) + + # Update all the params sources so all actors can get the latest params + params = jax.block_until_ready(learner_state.params) + for source in params_sources: + source.update(params) + + # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation + ep_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) + rollout_times: Dict[str, NDArray] = tree.map(lambda *x: np.mean(x), *rollout_times) + timing_dict = rollout_times | learn_times + timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) + + eval_queue.put((ep_metrics, train_metrics, learner_state, timing_dict)) + + +def learner_setup( + key: chex.PRNGKey, config: DictConfig, learner_devices: List +) -> Tuple[ + SebulbaLearnerFn[LearnerState, PPOTransition], + Tuple[ActorApply, CriticApply], + LearnerState, + Sharding, +]: + """Initialise learner_fn, network and learner state.""" + + # create temporory envoirnments. + env = environments.make_gym_env(config, config.arch.num_envs) + # Get number of agents and actions. + action_space = env.single_action_space + config.system.num_agents = len(action_space) + config.system.num_actions = int(action_space[0].n) + + devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) + mesh = Mesh(devices, axis_names=("learner_devices",)) + model_spec = PartitionSpec() + data_spec = PartitionSpec("learner_devices") + learner_sharding = NamedSharding(mesh, model_spec) + + # PRNG keys. + key, actor_key, critic_key = jax.random.split(key, 3) + + # Define network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_action_head = hydra.utils.instantiate( + config.network.action_head, action_dim=config.system.num_actions + ) + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + critic_network = Critic(torso=critic_torso) + + actor_lr = make_learning_rate(config.system.actor_lr, config) + critic_lr = make_learning_rate(config.system.critic_lr, config) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. + init_obs = jnp.array([env.single_observation_space.sample()]) + init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) + init_x = Observation(init_obs, init_action_mask) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_key, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. + critic_params = critic_network.init(critic_key, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Pack params. + params = Params(actor_params, critic_params) + + # Pack apply and update functions. 
+ apply_fns = (actor_network.apply, critic_network.apply) + update_fns = (actor_optim.update, critic_optim.update) + + # defines how the learner state is sharded: params, opt and key = replicated, timestep = sharded + learn_state_spec = LearnerState(model_spec, model_spec, model_spec, None, data_spec) + learn = get_learner_step_fn(apply_fns, update_fns, config) + learn = jax.jit( + shard_map( + learn, + mesh=mesh, + in_specs=(learn_state_spec, data_spec), + out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec), + ) + ) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(input_params=params) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. + key, step_keys = jax.random.split(key) + opt_states = OptStates(actor_opt_state, critic_opt_state) + + # Duplicate learner across Learner devices. + params, opt_states, step_keys = jax.device_put( + (params, opt_states, step_keys), learner_sharding + ) + + # Initialise learner state. + init_learner_state = LearnerState(params, opt_states, step_keys, None, None) # type: ignore + env.close() + + return learn, apply_fns, init_learner_state, learner_sharding # type: ignore + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + local_devices = jax.local_devices() + devices = jax.devices() + err = "Local and global devices must be the same, we dont support multihost yet" + assert len(local_devices) == len(devices), err + learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] + actor_devices = [local_devices[device_id] for device_id in config.arch.actor_device_ids] + + # JAX and numpy RNGs + key = jax.random.PRNGKey(config.system.seed) + np_rng = np.random.default_rng(config.system.seed) + + # Setup learner. + learn, apply_fns, learner_state, learner_sharding = learner_setup(key, config, learner_devices) + + # Setup evaluator. + # One key per device for evaluation. + eval_act_fn = make_ff_eval_act_fn(apply_fns[0], config) + evaluator, evaluator_envs = get_eval_fn( + environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False + ) + + # Calculate total timesteps. + config = check_total_timesteps(config) + check_sebulba_config(config) + + steps_per_rollout = ( + config.system.rollout_length + * config.arch.num_envs + * config.system.num_updates_per_eval + * config.arch.n_learner_accumulate + ) + + # Logger setup + logger = MavaLogger(config) + print_cfg: Dict = OmegaConf.to_container(config, resolve=True) + print_cfg["arch"]["devices"] = jax.devices() + pprint(print_cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Executor setup and launch. 
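+    # Illustrative note: with the default sebulba arch config (actor_device_ids=[0],
+    # n_threads_per_executor=2, num_envs=32), this section starts 1 * 2 = 2 rollout threads,
+    # each with its own batch of 32 environments, i.e. 64 environments acting in parallel.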
+ inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate + + # the rollout queue/ the pipe between actor and learner + pipe_lifetime = ThreadLifetime() + pipe = Pipeline(config.arch.rollout_queue_size, learner_sharding, pipe_lifetime) + pipe.start() + + params_sources: List[ParamsSource] = [] + actor_threads: List[threading.Thread] = [] + actor_lifetime = ThreadLifetime() + params_sources_lifetime = ThreadLifetime() + + # Create the actor threads + print(f"{Fore.BLUE}{Style.BRIGHT}Starting up actor threads...{Style.RESET_ALL}") + for actor_device in actor_devices: + # Create 1 params source per device + params_source = ParamsSource(inital_params, actor_device, params_sources_lifetime) + params_source.start() + params_sources.append(params_source) + # Create multiple rollout threads per actor device + for thread_id in range(config.arch.n_threads_per_executor): + key, act_key = jax.random.split(key) + seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() + act_key = jax.device_put(key, actor_device) + + actor = threading.Thread( + target=rollout, + args=( + act_key, + # We have to do this here, creating envs inside actor threads causes deadlocks + environments.make_gym_env(config, config.arch.num_envs), + config, + pipe, + params_source, + apply_fns, + actor_device, + seeds, + actor_lifetime, + ), + name=f"Actor-{actor_device}-{thread_id}", + ) + actor_threads.append(actor) + + # Start the actors simultaneously + for actor in actor_threads: + actor.start() + + eval_queue: Queue = Queue() + threading.Thread( + target=learner_thread, + name="Learner", + args=(learn, learner_state, config, eval_queue, pipe, params_sources), + ).start() + + max_episode_return = -np.inf + best_params_cpu = jax.device_get(inital_params.actor_params) + + # This is the main loop, all it does is evaluation and logging. + # Acting and learning is happening in their own threads. + # This loop waits for the learner to finish an update before evaluation and logging. + for eval_step in range(config.arch.num_evaluation): + # Sync with the learner - the get() is blocking so it keeps eval and learning in step. 
+ episode_metrics, train_metrics, learner_state, time_metrics = eval_queue.get() + + t = int(steps_per_rollout * (eval_step + 1)) + time_metrics |= {"timestep": t, "pipline_size": pipe.qsize()} + logger.log(time_metrics, t, eval_step, LogEvent.MISC) + + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / time_metrics["rollout_time"] + if ep_completed: + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + + train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval + train_metrics["learner_steps_per_second"] = ( + config.system.num_updates_per_eval + ) / time_metrics["learner_time_per_eval"] + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + + learner_state_cpu = jax.device_get(learner_state) + key, eval_key = jax.random.split(key, 2) + eval_metrics = evaluator(learner_state_cpu.params.actor_params, eval_key, {}) + logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) + + episode_return = np.mean(eval_metrics["episode_return"]) + + if save_checkpoint: # Save a checkpoint of the learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=learner_state_cpu, + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params_cpu = copy.deepcopy(learner_state_cpu.params.actor_params) + max_episode_return = float(episode_return) + + evaluator_envs.close() + eval_performance = float(np.mean(eval_metrics[config.env.eval_metric])) + + # Measure absolute metric. + if config.arch.absolute_metric: + print(f"{Fore.BLUE}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}") + abs_metric_evaluator, abs_metric_evaluator_envs = get_eval_fn( + environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True + ) + key, eval_key = jax.random.split(key, 2) + eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {}) + + t = int(steps_per_rollout * (eval_step + 1)) + logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) + abs_metric_evaluator_envs.close() + + # Stop all the threads. + logger.stop() + actor_lifetime.stop() + pipe.clear() # We clear the pipeline before stopping the actor threads to avoid deadlock + print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared{Style.RESET_ALL}") + print(f"{Fore.RED}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") + for actor in actor_threads: + actor.join() + print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}") + print(f"{Fore.RED}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") + pipe_lifetime.stop() + pipe.join() + print(f"{Fore.RED}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}") + params_sources_lifetime.stop() + for params_source in params_sources: + params_source.join() + print(f"{Fore.RED}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}") + + return eval_performance + + +@hydra.main( + config_path="../../../configs/default/", + config_name="ff_ippo_sebulba.yaml", + version_base="1.2", +) +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + cfg.logger.system_name = "ff_ippo_sebulba" + + # Run experiment. 
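+    # For example, this entry point can be launched with Hydra overrides (override values
+    # shown are illustrative only):
+    #   python mava/systems/ppo/sebulba/ff_ippo.py env=rware_gym arch.num_envs=16 \
+    #       arch.learner_device_ids=[0] arch.actor_device_ids=[0]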
+ eval_performance = run_experiment(cfg) + print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index f129b89d3..c8145b1a7 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -20,7 +20,7 @@ from optax._src.base import OptState from typing_extensions import NamedTuple -from mava.types import Action, Done, HiddenState, State, Value +from mava.types import Action, Done, HiddenState, Observation, State, Value class Params(NamedTuple): @@ -74,7 +74,7 @@ class PPOTransition(NamedTuple): value: Value reward: chex.Array log_prob: chex.Array - obs: chex.Array + obs: Observation info: Dict diff --git a/mava/systems/q_learning/anakin/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py index 2a3c9783c..5a1d7df34 100644 --- a/mava/systems/q_learning/anakin/rec_iql.py +++ b/mava/systems/q_learning/anakin/rec_iql.py @@ -48,13 +48,13 @@ from mava.types import Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( switch_leading_axes, unreplicate_batch_dim, unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/systems/sac/anakin/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py index a3a4d3430..9e54dde2d 100644 --- a/mava/systems/sac/anakin/ff_isac.py +++ b/mava/systems/sac/anakin/ff_isac.py @@ -49,9 +49,9 @@ from mava.types import MarlEnv, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/systems/sac/anakin/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py index d006bb250..d0c763760 100644 --- a/mava/systems/sac/anakin/ff_masac.py +++ b/mava/systems/sac/anakin/ff_masac.py @@ -50,9 +50,9 @@ from mava.utils import make_env as environments from mava.utils.centralised_training import get_joint_action, get_updated_joint_actions from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/types.py b/mava/types.py index 5d1109785..8a191f5ab 100644 --- a/mava/types.py +++ b/mava/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Generic, Protocol, Tuple, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, Protocol, Tuple, TypeVar import chex import jumanji.specs as specs @@ -118,7 +118,7 @@ class Observation(NamedTuple): agents_view: chex.Array # (num_agents, num_obs_features) action_mask: chex.Array # (num_agents, num_actions) - step_count: chex.Array # (num_agents, ) + step_count: Optional[chex.Array] = None # (num_agents, ) class ObservationGlobalState(NamedTuple): @@ -131,7 +131,7 @@ class ObservationGlobalState(NamedTuple): agents_view: chex.Array # (num_agents, num_obs_features) action_mask: chex.Array # (num_agents, num_actions) global_state: chex.Array # (num_agents, num_agents * num_obs_features) - step_count: chex.Array # (num_agents, ) + step_count: Optional[chex.Array] = None # (num_agents, ) RNNObservation: TypeAlias = Tuple[Observation, Done] @@ -141,6 +141,7 @@ class ObservationGlobalState(NamedTuple): # `MavaState` is the main type passed around in our systems. It is often used as a scan carry. # Types like: `LearnerState` (mava/systems//types.py) are `MavaState`s. MavaState = TypeVar("MavaState") +MavaTransition = TypeVar("MavaTransition") class ExperimentOutput(NamedTuple, Generic[MavaState]): @@ -152,6 +153,7 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]] +SebulbaLearnerFn = Callable[[MavaState, MavaTransition], ExperimentOutput[MavaState]] ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ diff --git a/mava/utils/total_timestep_checker.py b/mava/utils/config.py similarity index 53% rename from mava/utils/total_timestep_checker.py rename to mava/utils/config.py index c2cda8320..c82e3a315 100644 --- a/mava/utils/total_timestep_checker.py +++ b/mava/utils/config.py @@ -18,9 +18,39 @@ from omegaconf import DictConfig +def check_sebulba_config(config: DictConfig) -> None: + """Checks that the given config does not have conflicting values.""" + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + + assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, ( + "Number of environments must be divisible by the number of learner." + + "The output of each actor is equally split across the learners." + ) + + num_eval_samples = ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.system.rollout_length + ) + assert num_eval_samples % config.system.num_minibatches == 0, ( + f"Number of training samples per evaluator ({num_eval_samples})" + + f"must be divisible by num_minibatches ({config.system.num_minibatches})." + ) + + def check_total_timesteps(config: DictConfig) -> DictConfig: """Check if total_timesteps is set, if not, set it based on the other parameters""" - n_devices = len(jax.devices()) + + if config.arch.architecture_name == "anakin": + n_devices = len(jax.devices()) + update_batch_size = config.system.update_batch_size + n_accumulate = 1 # We dont accumulate envs in anakin + else: + n_devices = 1 # We only use a single device's output when updating. 
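+        # Illustrative example (sebulba): with num_envs=32, n_learner_accumulate=1 and a
+        # hypothetical rollout_length=128, each update consumes 1 * 128 * 1 * 32 * 1 = 4096
+        # timesteps, so total_timesteps = num_updates * 4096.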
+ update_batch_size = 1 + n_accumulate = config.arch.n_learner_accumulate if config.system.total_timesteps is None: config.system.num_updates = int(config.system.num_updates) @@ -28,17 +58,19 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: n_devices * config.system.num_updates * config.system.rollout_length - * config.system.update_batch_size + * update_batch_size * config.arch.num_envs + * n_accumulate ) else: config.system.total_timesteps = int(config.system.total_timesteps) config.system.num_updates = int( config.system.total_timesteps // config.system.rollout_length - // config.system.update_batch_size + // update_batch_size // config.arch.num_envs // n_devices + // n_accumulate ) print( f"{Fore.RED}{Style.BRIGHT} Changing the number of updates " diff --git a/mava/utils/jax_utils.py b/mava/utils/jax_utils.py index 3c03455f2..c89c6a4a4 100644 --- a/mava/utils/jax_utils.py +++ b/mava/utils/jax_utils.py @@ -71,5 +71,4 @@ def unreplicate_batch_dim(x: Any) -> Any: def switch_leading_axes(arr: chex.Array) -> chex.Array: """Switches the first two axes, generally used for BT -> TB.""" - arr = tree.map(lambda x: jax.numpy.swapaxes(x, 0, 1), arr) - return arr + return tree.map(lambda x: x.swapaxes(0, 1), arr) diff --git a/mava/utils/logger.py b/mava/utils/logger.py index 1ab519f30..bd090604b 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -153,8 +153,11 @@ class NeptuneLogger(BaseLogger): def __init__(self, cfg: DictConfig, unique_token: str) -> None: tags = list(cfg.logger.kwargs.neptune_tag) project = cfg.logger.kwargs.neptune_project + mode = ( + "async" if cfg.arch.architecture_name == "anakin" else "sync" + ) # async logging leads to deadlocks in sebulba - self.logger = neptune.init_run(project=project, tags=tags) + self.logger = neptune.init_run(project=project, tags=tags, mode=mode) self.logger["config"] = stringify_unsupported(cfg) self.detailed_logging = cfg.logger.kwargs.detailed_neptune_logging @@ -175,6 +178,7 @@ def log_stat(self, key: str, value: float, step: int, eval_step: int, event: Log if not self.detailed_logging and not is_main_metric: return + value = value.item() if isinstance(value, (jax.Array, np.ndarray)) else value self.logger[f"{event.value}/{key}"].log(value, step=step) def stop(self) -> None: @@ -341,7 +345,8 @@ def get_logger_path(config: DictConfig, logger_type: str) -> str: def describe(x: ArrayLike) -> Union[Dict[str, ArrayLike], ArrayLike]: """Generate summary statistics for an array of metrics (mean, std, min, max).""" - if not isinstance(x, jax.Array) or x.size <= 1: + + if not isinstance(x, (jax.Array, np.ndarray)) or x.size <= 1: return x # np instead of jnp because we don't jit here diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 3633a6113..1206d3886 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -14,6 +14,10 @@ from typing import Dict, Tuple, Type +import gymnasium +import gymnasium as gym +import gymnasium.vector +import gymnasium.wrappers import jaxmarl import jumanji import matrax @@ -40,12 +44,18 @@ CleanerWrapper, ConnectorWrapper, GigastepWrapper, + GymAgentIDWrapper, + GymRecordEpisodeMetrics, + GymToJumanji, + GymWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, RecordEpisodeMetrics, RwareWrapper, + SmacWrapper, SmaxWrapper, + async_multiagent_worker, ) from mava.wrappers.jaxmarl import JaxMarlWrapper @@ -65,6 +75,12 @@ _jaxmarl_registry: Dict[str, Type[JaxMarlWrapper]] = {"Smax": SmaxWrapper, "MaBrax": MabraxWrapper} _gigastep_registry = {"Gigastep": GigastepWrapper} 
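+# Registry mapping gymnasium-based environment names (as used in the new env configs) to the
+# base wrapper that `make_gym_env` applies.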
+_gym_registry = { + "RobotWarehouse": GymWrapper, + "LevelBasedForaging": GymWrapper, + "Starcraft": SmacWrapper, +} + def add_extra_wrappers( train_env: MarlEnv, eval_env: MarlEnv, config: DictConfig @@ -212,6 +228,44 @@ def make_gigastep_env( return train_env, eval_env +def make_gym_env( + config: DictConfig, + num_env: int, + add_global_state: bool = False, +) -> GymToJumanji: + """ + Create a gymnasium environment. + + Args: + config (Dict): The configuration of the environment. + num_env (int) : The number of parallel envs to create. + add_global_state (bool): Whether to add the global state to the observation. Default False. + + Returns: + Async environments. + """ + wrapper = _gym_registry[config.env.env_name] + config.system.add_agent_id = config.system.add_agent_id & (~config.env.implicit_agent_id) + + def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnasium.Env: + registered_name = f"{config.env.scenario.name}:{config.env.scenario.task_name}" + env = gym.make(registered_name, disable_env_checker=False, **config.env.kwargs) + wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) + if config.system.add_agent_id: + wrapped_env = GymAgentIDWrapper(wrapped_env) + wrapped_env = GymRecordEpisodeMetrics(wrapped_env) + return wrapped_env + + envs = gymnasium.vector.AsyncVectorEnv( + [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], + worker=async_multiagent_worker, + ) + + envs = GymToJumanji(envs) + + return envs + + def make(config: DictConfig, add_global_state: bool = False) -> Tuple[MarlEnv, MarlEnv]: """ Create environments for training and evaluation. diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py new file mode 100644 index 000000000..0e2e6261d --- /dev/null +++ b/mava/utils/sebulba.py @@ -0,0 +1,190 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
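+
+"""Sebulba (actor-learner) utilities: a trajectory `Pipeline` between actor and learner
+threads, a `ParamsSource` for passing fresh parameters to actors, and small helpers
+(`ThreadLifetime`, `RecordTimeTo`) for thread control and timing."""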
+ + +import queue +import threading +import time +from typing import Any, Dict, List, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp +from colorama import Fore, Style +from jax import tree +from jax.sharding import Sharding +from jumanji.types import TimeStep + +# todo: remove the ppo dependencies when we make sebulba for other systems +from mava.systems.ppo.types import Params, PPOTransition + +QUEUE_PUT_TIMEOUT = 100 + + +class ThreadLifetime: + """Simple class for a mutable boolean that can be used to signal a thread to stop.""" + + def __init__(self) -> None: + self._stop = False + + def should_stop(self) -> bool: + return self._stop + + def stop(self) -> None: + self._stop = True + + +@jax.jit +def _stack_trajectory(trajectory: List[PPOTransition]) -> PPOTransition: + """Stack a list of parallel_env transitions into a single + transition of shape [rollout_len, num_envs, ...].""" + return tree.map(lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), *trajectory) # type: ignore + + +# Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py +class Pipeline(threading.Thread): + """ + The `Pipeline` shards trajectories into learner devices, + ensuring trajectories are consumed in the right order to avoid being off-policy + and limit the max number of samples in device memory at one time to avoid OOM issues. + """ + + def __init__(self, max_size: int, learner_sharding: Sharding, lifetime: ThreadLifetime): + """ + Initializes the pipeline with a maximum size and the devices to shard trajectories across. + + Args: + max_size: The maximum number of trajectories to keep in the pipeline. + learner_sharding: The sharding used for the learner's update function. + lifetime: A `ThreadLifetime` which is used to stop this thread. + """ + super().__init__(name="Pipeline") + + self.sharding = learner_sharding + self.tickets_queue: queue.Queue = queue.Queue() + self._queue: queue.Queue = queue.Queue(maxsize=max_size) + self.lifetime = lifetime + + def run(self) -> None: + """This function ensures that trajectories on the queue are consumed in the right order. The + start_condition and end_condition are used to ensure that only 1 thread is processing an + item from the queue at one time, ensuring predictable memory usage. + """ + while not self.lifetime.should_stop(): + try: + start_condition, end_condition = self.tickets_queue.get(timeout=1) + with end_condition: + with start_condition: + start_condition.notify() + end_condition.wait() + except queue.Empty: + continue + + def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict) -> None: + """Put a trajectory on the queue to be consumed by the learner.""" + start_condition, end_condition = (threading.Condition(), threading.Condition()) + with start_condition: + self.tickets_queue.put((start_condition, end_condition)) + start_condition.wait() # wait to be allowed to start + + # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] + traj = _stack_trajectory(traj) + traj, timestep = jax.device_put((traj, timestep), device=self.sharding) + + # We block on the `put` to ensure that actors wait for the learners to catch up. + # This ensures two things: + # The actors don't get too far ahead of the learners, which could lead to off-policy data. + # The actors don't "waste" samples by generating samples that the learners can't consume. + # However, we put a timeout of 100 seconds to avoid deadlocks in case the learner + # is not consuming the data. 
This is a safety measure and should not normally occur. + # We use a try-finally so the lock is released even if an exception is raised. + try: + self._queue.put( + (traj, timestep, time_dict), + block=True, + timeout=QUEUE_PUT_TIMEOUT, + ) + except queue.Full: + print( + f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, " + f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" + ) + finally: + with end_condition: + end_condition.notify() # notify that we have finished + + def qsize(self) -> int: + """Returns the number of trajectories in the pipeline.""" + return self._queue.qsize() + + def get( + self, block: bool = True, timeout: Union[float, None] = None + ) -> Tuple[PPOTransition, TimeStep, Dict]: + """Get a trajectory from the pipeline.""" + return self._queue.get(block, timeout) # type: ignore + + def clear(self) -> None: + """Clear the pipeline.""" + while not self._queue.empty(): + try: + self._queue.get(block=False) + except queue.Empty: + break + + +class ParamsSource(threading.Thread): + """A `ParamSource` is a component that allows networks params to be passed from a + `Learner` component to `Actor` components. + """ + + def __init__(self, init_value: Params, device: jax.Device, lifetime: ThreadLifetime): + super().__init__(name=f"ParamsSource-{device.id}") + self.value: Params = jax.device_put(init_value, device) + self.device = device + self.new_value: queue.Queue = queue.Queue() + self.lifetime = lifetime + + def run(self) -> None: + """This function is responsible for updating the value of the `ParamSource` when a new value + is available. + """ + while not self.lifetime.should_stop(): + try: + waiting = self.new_value.get(block=True, timeout=1) + self.value = jax.device_put(waiting, self.device) + except queue.Empty: + continue + + def update(self, new_params: Params) -> None: + """Update the value of the `ParamSource` with a new value. + + Args: + new_params: The new value to update the `ParamSource` with. + """ + self.new_value.put(new_params) + + def get(self) -> Params: + """Get the current value of the `ParamSource`.""" + return self.value + + +class RecordTimeTo: + def __init__(self, to: Any): + self.to = to + + def __enter__(self) -> None: + self.start = time.monotonic() + + def __exit__(self, *args: Any) -> None: + end = time.monotonic() + self.to.append(end - self.start) diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 0c6f4753f..f7e89d756 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -16,6 +16,14 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper +from mava.wrappers.gym import ( + GymAgentIDWrapper, + GymRecordEpisodeMetrics, + GymToJumanji, + GymWrapper, + SmacWrapper, + async_multiagent_worker, +) from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, diff --git a/mava/wrappers/episode_metrics.py b/mava/wrappers/episode_metrics.py index e9e130819..f4c34002e 100644 --- a/mava/wrappers/episode_metrics.py +++ b/mava/wrappers/episode_metrics.py @@ -17,6 +17,7 @@ import chex import jax import jax.numpy as jnp +import numpy as np from jax import tree from jumanji.types import TimeStep from jumanji.wrappers import Wrapper @@ -120,12 +121,12 @@ def get_final_step_metrics(metrics: Dict[str, chex.Array]) -> Tuple[Dict[str, ch expects arrays for computing summary statistics on the episode metrics. 
""" is_final_ep = metrics.pop("is_terminal_step") - has_final_ep_step = bool(jnp.any(is_final_ep)) + has_final_ep_step = bool(np.any(is_final_ep)) final_metrics: Dict[str, chex.Array] # If it didn't make it to the final step, return zeros. if not has_final_ep_step: - final_metrics = tree.map(jnp.zeros_like, metrics) + final_metrics = tree.map(np.zeros_like, metrics) else: final_metrics = tree.map(lambda x: x[is_final_ep], metrics) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py new file mode 100644 index 000000000..020abf158 --- /dev/null +++ b/mava/wrappers/gym.py @@ -0,0 +1,392 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import traceback +import warnings +from dataclasses import field +from enum import IntEnum +from multiprocessing import Queue +from multiprocessing.connection import Connection +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +import gymnasium +import gymnasium.vector.async_vector_env +import numpy as np +from gymnasium import spaces +from gymnasium.spaces.utils import is_space_dtype_shape_equiv +from gymnasium.vector.utils import write_to_shared_memory +from numpy.typing import NDArray + +from mava.types import Observation, ObservationGlobalState + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + +# Filter out the warnings +warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") + + +# needed to avoid host -> device transfers when calling TimeStep.last() +class StepType(IntEnum): + """Coppy of Jumanji's step type but with numpy arrays""" + + FIRST = 0 + MID = 1 + LAST = 2 + + +@dataclass +class TimeStep: + step_type: StepType + reward: NDArray + discount: NDArray + observation: Union[Observation, ObservationGlobalState] + extras: Dict = field(default_factory=dict) + + def first(self) -> bool: + return self.step_type == StepType.FIRST + + def mid(self) -> bool: + return self.step_type == StepType.MID + + def last(self) -> bool: + return self.step_type == StepType.LAST + + +class GymWrapper(gymnasium.Wrapper): + """Base wrapper for multi-agent gym environments. + This wrapper works out of the box for RobotWarehouse and level based foraging. + """ + + def __init__( + self, + env: gymnasium.Env, + use_shared_rewards: bool = True, + add_global_state: bool = False, + ): + """Initialise the gym wrapper + Args: + env (gymnasium.env): gymnasium env instance. + use_shared_rewards (bool, optional): Use individual or shared rewards. + Defaults to False. + add_global_state (bool, optional) : Create global observations. Defaults to False. 
+        """
+        super().__init__(env)
+        self._env = env
+        self.use_shared_rewards = use_shared_rewards
+        self.add_global_state = add_global_state
+        self.num_agents = len(self._env.action_space)
+        self.num_actions = self._env.action_space[0].n
+
+    def reset(
+        self, seed: Optional[int] = None, options: Optional[dict] = None
+    ) -> Tuple[NDArray, Dict]:
+        if seed is not None:
+            self.env.unwrapped.seed(seed)
+
+        agents_view, info = self._env.reset()
+
+        info = {"actions_mask": self.get_actions_mask(info)}
+        if self.add_global_state:
+            info["global_obs"] = self.get_global_obs(agents_view)
+
+        return np.array(agents_view), info
+
+    def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
+        agents_view, reward, terminated, truncated, info = self._env.step(actions)
+
+        info = {"actions_mask": self.get_actions_mask(info)}
+        if self.add_global_state:
+            info["global_obs"] = self.get_global_obs(agents_view)
+
+        if self.use_shared_rewards:
+            reward = np.array([np.array(reward).sum()] * self.num_agents)
+        else:
+            reward = np.array(reward)
+
+        return agents_view, reward, terminated, truncated, info
+
+    def get_actions_mask(self, info: Dict) -> NDArray:
+        if "action_mask" in info:
+            return np.array(info["action_mask"])
+        return np.ones((self.num_agents, self.num_actions), dtype=np.float32)
+
+    def get_global_obs(self, obs: NDArray) -> NDArray:
+        global_obs = np.concatenate(obs, axis=0)
+        return np.tile(global_obs, (self.num_agents, 1))
+
+
+class SmacWrapper(GymWrapper):
+    """A wrapper that casts actions to integers before stepping the environment."""
+
+    def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
+        # Convert actions to integers before passing them to the environment
+        actions = [int(action) for action in actions]
+
+        agents_view, reward, terminated, truncated, info = super().step(actions)
+
+        return agents_view, reward, terminated, truncated, info
+
+    def get_actions_mask(self, info: Dict) -> NDArray:
+        return np.array(self._env.unwrapped.get_avail_actions())
+
+
+class GymRecordEpisodeMetrics(gymnasium.Wrapper):
+    """Record the episode returns and lengths."""
+
+    def __init__(self, env: gymnasium.Env):
+        super().__init__(env)
+        self._env = env
+        self.running_count_episode_return = 0.0
+        self.running_count_episode_length = 0.0
+
+    def reset(
+        self, seed: Optional[int] = None, options: Optional[dict] = None
+    ) -> Tuple[NDArray, Dict]:
+        agents_view, info = self._env.reset(seed, options)
+
+        # Create the metrics dict
+        metrics = {
+            "episode_return": self.running_count_episode_return,
+            "episode_length": self.running_count_episode_length,
+            "is_terminal_step": True,
+        }
+
+        # Reset the metrics
+        self.running_count_episode_return = 0.0
+        self.running_count_episode_length = 0
+
+        if "won_episode" in info:
+            metrics["won_episode"] = info["won_episode"]
+
+        info["metrics"] = metrics
+
+        return agents_view, info
+
+    def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
+        agents_view, reward, terminated, truncated, info = self._env.step(actions)
+
+        self.running_count_episode_return += float(np.mean(reward))
+        self.running_count_episode_length += 1
+
+        metrics = {
+            "episode_return": self.running_count_episode_return,
+            "episode_length": self.running_count_episode_length,
+            "is_terminal_step": False,
+        }
+        if "won_episode" in info:
+            metrics["won_episode"] = info["won_episode"]
+
+        info["metrics"] = metrics
+
+        return agents_view, reward, terminated, truncated, info
+
+
+class GymAgentIDWrapper(gymnasium.Wrapper):
+    """Add one-hot agent IDs to
observation."""
+
+    def __init__(self, env: gymnasium.Env):
+        super().__init__(env)
+
+        self.agent_ids = np.eye(self.env.num_agents)
+        self.observation_space = self.modify_space(self.env.observation_space)
+
+    def reset(
+        self, seed: Optional[int] = None, options: Optional[dict] = None
+    ) -> Tuple[NDArray, Dict]:
+        """Reset the environment."""
+        obs, info = self.env.reset(seed, options)
+        obs = np.concatenate([self.agent_ids, obs], axis=1)
+        return obs, info
+
+    def step(self, action: list) -> Tuple[NDArray, float, bool, bool, Dict]:
+        """Step the environment."""
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        obs = np.concatenate([self.agent_ids, obs], axis=1)
+        return obs, reward, terminated, truncated, info
+
+    def modify_space(self, space: spaces.Space) -> spaces.Space:
+        if isinstance(space, spaces.Box):
+            new_shape = (space.shape[0] + len(self.agent_ids),)
+            return spaces.Box(
+                low=space.low[0], high=space.high[0], shape=new_shape, dtype=space.dtype
+            )
+        elif isinstance(space, spaces.Tuple):
+            return spaces.Tuple(self.modify_space(s) for s in space)
+        else:
+            raise ValueError(f"Space {type(space)} is not currently supported.")
+
+
+class GymToJumanji:
+    """Converts from the Gym API to the dm_env API."""
+
+    def __init__(self, env: gymnasium.vector.VectorEnv):
+        self.env = env
+        self.single_action_space = env.unwrapped.single_action_space
+        self.single_observation_space = env.unwrapped.single_observation_space
+
+    def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None) -> TimeStep:
+        obs, info = self.env.reset(seed=seed, options=options)  # type: ignore
+
+        num_agents = len(self.env.single_action_space)  # type: ignore
+        num_envs = self.env.num_envs
+
+        ep_done = np.zeros(num_envs, dtype=float)
+        rewards = np.zeros((num_envs, num_agents), dtype=float)
+        terminated = np.zeros(num_envs, dtype=float)
+
+        timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)
+
+        return timestep
+
+    def step(self, action: list) -> TimeStep:
+        obs, rewards, terminated, truncated, info = self.env.step(action)
+
+        ep_done = np.logical_or(terminated, truncated)
+
+        timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)
+
+        return timestep
+
+    def _format_observation(
+        self, obs: NDArray, info: Dict
+    ) -> Union[Observation, ObservationGlobalState]:
+        """Create an observation from the raw observation and environment state."""
+
+        # (num_agents, num_envs, ...) -> (num_envs, num_agents, ...)
+ obs = np.array(obs).swapaxes(0, 1) + action_mask = np.stack(info["actions_mask"]) + obs_data = {"agents_view": obs, "action_mask": action_mask} + + if "global_obs" in info: + global_obs = np.array(info["global_obs"]).swapaxes(0, 1) + obs_data["global_state"] = global_obs + return ObservationGlobalState(**obs_data) + else: + return Observation(**obs_data) + + def _create_timestep( + self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict + ) -> TimeStep: + observation = self._format_observation(obs, info) + # Filter out the masks and auxiliary data + extras = {key: value for key, value in info["metrics"].items() if key[0] != "_"} + step_type = np.where(ep_done, StepType.LAST, StepType.MID) + + return TimeStep( + step_type=step_type, # type: ignore + reward=rewards, + discount=1.0 - terminated, + observation=observation, + extras=extras, + ) + + def close(self) -> None: + self.env.close() + + +# Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py +# Modified to work with multiple agents +def async_multiagent_worker( # CCR001 + index: int, + env_fn: Callable, + pipe: Connection, + parent_pipe: Connection, + shared_memory: Union[NDArray, dict[str, Any], tuple[Any, ...]], + error_queue: Queue, +) -> None: + env = env_fn() + observation_space = env.observation_space + action_space = env.action_space + parent_pipe.close() + + try: + while True: + command, data = pipe.recv() + + if command == "reset": + observation, info = env.reset(**data) + if shared_memory: + write_to_shared_memory(observation_space, index, observation, shared_memory) + observation = None + pipe.send(((observation, info), True)) + elif command == "step": + # Modified the step function to align with 'AutoResetWrapper'. + # The environment resets immediately upon termination or truncation. + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + if np.logical_or(terminated, truncated).all(): + observation, info = env.reset() + + if shared_memory: + write_to_shared_memory(observation_space, index, observation, shared_memory) + observation = None + + pipe.send(((observation, reward, terminated, truncated, info), True)) + elif command == "close": + pipe.send((None, True)) + break + elif command == "_call": + name, args, kwargs = data + if name in ["reset", "step", "close", "_setattr", "_check_spaces"]: + raise ValueError( + f"Trying to call function `{name}` with \ + `call`, use `{name}` directly instead." + ) + + attr = env.get_wrapper_attr(name) + if callable(attr): + pipe.send((attr(*args, **kwargs), True)) + else: + pipe.send((attr, True)) + elif command == "_setattr": + name, value = data + env.set_wrapper_attr(name, value) + pipe.send((None, True)) + elif command == "_check_spaces": + obs_mode, single_obs_space, single_action_space = data + pipe.send( + ( + ( + ( + single_obs_space == observation_space + if obs_mode == "same" + else is_space_dtype_shape_equiv(single_obs_space, observation_space) + ), + single_action_space == action_space, + ), + True, + ) + ) + else: + raise RuntimeError( + f"Received unknown command `{command}`. Must be one of \ + [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." 
+ ) + except (KeyboardInterrupt, Exception): + error_type, error_message, _ = sys.exc_info() + trace = traceback.format_exc() + + error_queue.put((index, error_type, error_message, trace)) + pipe.send((None, False)) + finally: + env.close() diff --git a/mava/wrappers/jaxmarl.py b/mava/wrappers/jaxmarl.py index 72608f85f..f6ad51558 100644 --- a/mava/wrappers/jaxmarl.py +++ b/mava/wrappers/jaxmarl.py @@ -214,7 +214,6 @@ def reset( def step( self, state: JaxMarlState, action: Array ) -> Tuple[JaxMarlState, TimeStep[Union[Observation, ObservationGlobalState]]]: - # todo: how do you know if it's a truncation with only dones? key, step_key = jax.random.split(state.key) obs, env_state, reward, done, _ = self._env.step( step_key, state.state, unbatchify(action, self.agents) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index ad22099e4..13ff3a050 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,12 +4,14 @@ distrax flashbax~=0.1.0 flax gigastep @ git+https://github.com/mlech26l/gigastep +gymnasium hydra-core==1.3.2 id-marl-eval @ git+https://github.com/instadeepai/marl-eval jax==0.4.30 jaxlib==0.4.30 jaxmarl jumanji @ git+https://github.com/sash-a/jumanji@old_jumanji # Includes a few extra MARL envs +lbforaging matrax @ git+https://github.com/instadeepai/matrax mujoco==3.1.3 mujoco-mjx==3.1.3 @@ -18,7 +20,9 @@ numpy==1.26.4 omegaconf optax protobuf~=3.20 +rware scipy==1.12.0 +smaclite @ git+https://github.com/uoe-agents/smaclite.git tensorboard_logger tensorflow_probability type_enforced # needed because gigastep is missing this dependency
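
For clarity on how the new gym wrappers compose, the following is a minimal sketch that builds by hand the same kind of async environment stack that make_gym_env assembles from the Hydra config. It is illustrative only: the lbforaging task id, the environment count and the wrapper flags are example choices, not values required by this change, and it assumes mava, gymnasium and lbforaging are installed.

# Hand-built version of the async gym env stack (a sketch, not Mava's system code).
import gymnasium

from mava.wrappers import (
    GymRecordEpisodeMetrics,
    GymToJumanji,
    GymWrapper,
    async_multiagent_worker,
)


def make_single_env() -> gymnasium.Env:
    # "module:env-id" makes gymnasium import lbforaging before resolving the id.
    env = gymnasium.make("lbforaging:Foraging-8x8-2p-1f-v3", max_episode_steps=100)
    env = GymWrapper(env, use_shared_rewards=True, add_global_state=False)
    return GymRecordEpisodeMetrics(env)


if __name__ == "__main__":
    num_envs = 4  # illustrative
    envs = GymToJumanji(
        gymnasium.vector.AsyncVectorEnv(
            [make_single_env for _ in range(num_envs)],
            worker=async_multiagent_worker,
        )
    )

    timestep = envs.reset()
    # agents_view should come back batched as (num_envs, num_agents, obs_dim).
    print(timestep.observation.agents_view.shape)
    envs.close()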
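
The array layout produced by _stack_trajectory (and relied on by Pipeline.put) is easiest to see on a toy pytree: a rollout of length rollout_len with leaves of shape (num_envs, ...) becomes a single pytree with leaves of shape (num_envs, rollout_len, ...). A dict stands in for a real PPOTransition here purely to keep the illustration self-contained.

# Toy illustration of the trajectory stacking done in mava/utils/sebulba.py.
import jax.numpy as jnp

from mava.utils.sebulba import _stack_trajectory

rollout_len, num_envs, num_agents = 8, 4, 2
step = {"reward": jnp.zeros((num_envs, num_agents))}  # stand-in for one PPOTransition
trajectory = [step] * rollout_len  # one pytree per rollout step

stacked = _stack_trajectory(trajectory)
print(stacked["reward"].shape)  # (4, 8, 2) -> (num_envs, rollout_len, num_agents)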
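
How an actor thread would typically consume params from ParamsSource while timing itself with RecordTimeTo is sketched below. The params pytree and the sleep in the loop are placeholders, not Mava's real actor code; a real Sebulba actor would run its act function and step the envs inside the timed block.

# Sketch of the ParamsSource / RecordTimeTo / ThreadLifetime lifecycle.
import time

import jax
import jax.numpy as jnp

from mava.utils.sebulba import ParamsSource, RecordTimeTo, ThreadLifetime

device = jax.devices()[0]
lifetime = ThreadLifetime()

params = {"w": jnp.zeros((4, 4)), "b": jnp.zeros(4)}  # stand-in for network params

source = ParamsSource(params, device, lifetime)
source.start()  # background thread that moves newly pushed params onto `device`

act_times: list = []
for _ in range(3):
    with RecordTimeTo(act_times):  # appends the elapsed seconds to act_times
        current_params = source.get()  # the actor always reads the freshest params
        time.sleep(0.01)  # placeholder for action selection / env stepping

# The learner pushes new params after each update; the source thread picks them up.
source.update({"w": params["w"] + 1.0, "b": params["b"] + 1.0})

lifetime.stop()  # signal the thread to exit its run() loop
source.join()
print(f"mean act time: {sum(act_times) / len(act_times):.4f}s")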