From adc21144c04a0ade07fd660948bb8f390dc47578 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 10 Jun 2024 11:28:00 +0100 Subject: [PATCH 001/125] feat: gym wrapper --- mava/configs/arch/sebulba.yaml | 24 +++++++++ mava/utils/make_env.py | 28 +++++++++++ mava/wrappers/__init__.py | 1 + mava/wrappers/gym.py | 92 ++++++++++++++++++++++++++++++++++ requirements/requirements.txt | 1 + 5 files changed, 146 insertions(+) create mode 100644 mava/configs/arch/sebulba.yaml create mode 100644 mava/wrappers/gym.py diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml new file mode 100644 index 000000000..ed1d07dff --- /dev/null +++ b/mava/configs/arch/sebulba.yaml @@ -0,0 +1,24 @@ +# --- Sebulba config --- +arch_name: "sebulba" +num_envs: 16 # number of envs per thread + +# --- Evaluation --- +evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select + # an action which corresponds to the greatest logit. If false, the policy will sample + # from the logits. +num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. +num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. +absolute_metric: True # Whether the absolute metric should be computed. For more details + # on the absolute metric please see: https://arxiv.org/abs/2209.10485 + +# --- Sebulba devices config --- +n_threads_per_executor: 1 # num of different threads/env batches per actor +executor_device_ids: [0] # ids of actor devices +learner_device_ids: [0] # ids of learner devices + +# --- Sebulba rollout and env config --- +concurrency: False # whether actor and learner should run concurrently +async_envs: True # "whether to use async vector or sync vector envs" + +# --- To be defined during training --- +log_frequency: ~ \ No newline at end of file diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 39b348b40..c66d585f5 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -14,9 +14,11 @@ from typing import Tuple +import gym.vector import jaxmarl import jumanji import matrax +import gym from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment @@ -46,6 +48,7 @@ RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, + GymWrapper, ) # Registry mapping environment names to their generator and wrapper classes. @@ -198,6 +201,29 @@ def make_gigastep_env( train_env, eval_env = add_extra_wrappers(train_env, eval_env, config) return train_env, eval_env +def make_gym_env(env_name: str, config: DictConfig, add_global_state: bool = False): + """ + Create a Gym environment. + + Args: + env_name (str): The name of the environment to create. + config (Dict): The configuration of the environment. + add_global_state (bool): Whether to add the global state to the observation. Default False. + + Returns: + A tuple of the environments. + """ + def create_gym_env(config: DictConfig, add_global_state: bool = False, eval_env : bool = False): #todo: add the RecordEpisodeMetrics for gym. + env = gym.make(config.env.scenario) + wrapped_env = GymWrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) + if not config.env.implicit_agent_id: + pass #todo : add agent id wrapper for gym . 
+ return wrapped_env + + num_env = config.arch.num_envs + train_env = gym.vector.async_vector_env([create_gym_env(config, add_global_state) for _ in range(num_env)]) + eval_env = gym.vector.async_vector_env([create_gym_env(config, add_global_state, eval_env=True) for _ in range(num_env)]) + return train_env, eval_env def make(config: DictConfig, add_global_state: bool = False) -> Tuple[Environment, Environment]: """ @@ -220,5 +246,7 @@ def make(config: DictConfig, add_global_state: bool = False) -> Tuple[Environmen return make_matrax_env(env_name, config, add_global_state) elif env_name in _gigastep_registry: return make_gigastep_env(env_name, config, add_global_state) + elif env_name.startswith("gym"): + return make_gym_env(env_name, config, add_global_state) else: raise ValueError(f"{env_name} is not a supported environment.") diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 91bf7b4c4..7fd63ecbc 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -24,3 +24,4 @@ ) from mava.wrappers.matrax import MatraxWrapper from mava.wrappers.observation import AgentIDWrapper +from mava.wrappers.gym import GymWrapper diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py new file mode 100644 index 000000000..f1ea5004b --- /dev/null +++ b/mava/wrappers/gym.py @@ -0,0 +1,92 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +import numpy as np +from gym.spaces import Box, MultiDiscrete +from typing import TYPE_CHECKING, Dict, Tuple, Union + + +class GymWrapper(gym.Wrapper): + """Wrapper for gym environments""" + + def __init__(self, env: gym.env, use_individual_rewards : bool = False,add_global_state : bool = False, eval_env : bool = False): + """Initialize the gym wrapper + + Args: + env (gym.env): gym env instance. + use_individual_rewards (bool, optional): Use individual or group rewards. Defaults to False. + add_global_state (bool, optional) : Create global observations. Defaults to False. + eval_env (bool, optional): Weather the instance is used for training or evaluation. Defaults to False. + """ + super().__init__(env) + self._env = env + self.use_individual_rewards = use_individual_rewards + self.add_global_state = add_global_state #todo : add the global observations + self.eval_env = eval_env + self.num_agents = self._env.n_agents + self.num_actions = self._env.action_space[0].n #todo: all the agents must have the same num_actions, add assertion? + + def reset(self): + + obs, extra = self._env.reset(seed = np.random.randint(), option = {}) #todo: assure reproducibility + reward = np.zeros(self._env.n_agents) + terminated, truncated = np.zeros(self._env.n_agents , dtype=bool), np.zeros(self._env.n_agents , dtype=bool) + actions_mask = self._get_actions_mask(extra) + + + return np.array(obs), actions_mask, reward, terminated, truncated, extra + + def step(self , actions : np.array): + + if self._reset_next_step and not self.eval_env: #only auto-reset in training envs. 
+ return self.reset() + + obs, reward, terminated, truncated, extra = self.env.step(actions) + + terminated, truncated = np.array(terminated), np.array(truncated) + + done = np.logical_or(terminated, truncated).all() + + if done and not self.eval_env: #only auto-reset in training envs, same functionality as the AutoResetWrapper. + return self.reset() + + actions_mask = self._get_actions_mask(extra) + + + + if self.use_individual_rewards: + reward = np.array(reward) + else: + reward = np.array([np.array(reward).mean()] * self.num_agents) + + return np.array(obs), actions_mask, reward, terminated, truncated, extra + + + def _get_actions_mask(self, extra : Dict) -> np.array: + if "action_mask" in extra: + return np.array(extra["action_mask"]) + return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + + + + + + + + + + + + \ No newline at end of file diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5efd3bbe1..88c61ce0f 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -21,3 +21,4 @@ scipy==1.12.0 tensorboard_logger tensorflow_probability type_enforced # needed because gigastep is missing this dependency +rware @ git+https://github.com/RuanJohn/robotic-warehouse.git \ No newline at end of file From ce86d096060f8fad5e4ef1ddd587cc33b06da692 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 10 Jun 2024 11:54:24 +0100 Subject: [PATCH 002/125] chore : pre-commit hooks --- mava/configs/arch/sebulba.yaml | 2 +- mava/utils/make_env.py | 27 +++++++--- mava/wrappers/__init__.py | 2 +- mava/wrappers/gym.py | 94 +++++++++++++++++----------------- requirements/requirements.txt | 2 +- 5 files changed, 69 insertions(+), 58 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index ed1d07dff..98cd4d96d 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -21,4 +21,4 @@ concurrency: False # whether actor and learner should run concurrently async_envs: True # "whether to use async vector or sync vector envs" # --- To be defined during training --- -log_frequency: ~ \ No newline at end of file +log_frequency: ~ diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index c66d585f5..44758b41d 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -14,11 +14,11 @@ from typing import Tuple +import gym import gym.vector import jaxmarl import jumanji import matrax -import gym from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment @@ -42,13 +42,13 @@ CleanerWrapper, ConnectorWrapper, GigastepWrapper, + GymWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, - GymWrapper, ) # Registry mapping environment names to their generator and wrapper classes. @@ -201,7 +201,10 @@ def make_gigastep_env( train_env, eval_env = add_extra_wrappers(train_env, eval_env, config) return train_env, eval_env -def make_gym_env(env_name: str, config: DictConfig, add_global_state: bool = False): + +def make_gym_env( + env_name: str, config: DictConfig, add_global_state: bool = False +) -> Tuple[Environment, Environment]: #todo : create the appropriate annotation for the sync vector """ Create a Gym environment. @@ -213,18 +216,26 @@ def make_gym_env(env_name: str, config: DictConfig, add_global_state: bool = Fal Returns: A tuple of the environments. 
""" - def create_gym_env(config: DictConfig, add_global_state: bool = False, eval_env : bool = False): #todo: add the RecordEpisodeMetrics for gym. + + def create_gym_env( + config: DictConfig, add_global_state: bool = False, eval_env: bool = False + ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. env = gym.make(config.env.scenario) wrapped_env = GymWrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) if not config.env.implicit_agent_id: - pass #todo : add agent id wrapper for gym . + pass # todo : add agent id wrapper for gym . return wrapped_env - + num_env = config.arch.num_envs - train_env = gym.vector.async_vector_env([create_gym_env(config, add_global_state) for _ in range(num_env)]) - eval_env = gym.vector.async_vector_env([create_gym_env(config, add_global_state, eval_env=True) for _ in range(num_env)]) + train_env = gym.vector.async_vector_env( + [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)] + ) + eval_env = gym.vector.async_vector_env( + [create_gym_env(config, add_global_state, eval_env=True) for _ in range(num_env)] + ) return train_env, eval_env + def make(config: DictConfig, add_global_state: bool = False) -> Tuple[Environment, Environment]: """ Create environments for training and evaluation.. diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 7fd63ecbc..14a679cac 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,6 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper +from mava.wrappers.gym import GymWrapper from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, @@ -24,4 +25,3 @@ ) from mava.wrappers.matrax import MatraxWrapper from mava.wrappers.observation import AgentIDWrapper -from mava.wrappers.gym import GymWrapper diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index f1ea5004b..9c4d8b74d 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -12,81 +12,81 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Tuple + import gym import numpy as np -from gym.spaces import Box, MultiDiscrete -from typing import TYPE_CHECKING, Dict, Tuple, Union +from numpy.typing import NDArray class GymWrapper(gym.Wrapper): """Wrapper for gym environments""" - - def __init__(self, env: gym.env, use_individual_rewards : bool = False,add_global_state : bool = False, eval_env : bool = False): + + def __init__( + self, + env: gym.env, + use_individual_rewards: bool = False, + add_global_state: bool = False, + eval_env: bool = False, + ): """Initialize the gym wrapper Args: env (gym.env): gym env instance. - use_individual_rewards (bool, optional): Use individual or group rewards. Defaults to False. + use_individual_rewards (bool, optional): Use individual or group rewards. + Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. - eval_env (bool, optional): Weather the instance is used for training or evaluation. Defaults to False. + eval_env (bool, optional): Weather the instance is used for training or evaluation. + Defaults to False. 
""" super().__init__(env) self._env = env self.use_individual_rewards = use_individual_rewards - self.add_global_state = add_global_state #todo : add the global observations + self.add_global_state = add_global_state # todo : add the global observations self.eval_env = eval_env self.num_agents = self._env.n_agents - self.num_actions = self._env.action_space[0].n #todo: all the agents must have the same num_actions, add assertion? - - def reset(self): - - obs, extra = self._env.reset(seed = np.random.randint(), option = {}) #todo: assure reproducibility + self.num_actions = self._env.action_space[ + 0 + ].n # todo: all the agents must have the same num_actions, add assertion? + + def reset(self) -> Tuple: + obs, extra = self._env.reset( + seed=np.random.randint(1), option={} + ) # todo: assure reproducibility reward = np.zeros(self._env.n_agents) - terminated, truncated = np.zeros(self._env.n_agents , dtype=bool), np.zeros(self._env.n_agents , dtype=bool) + terminated, truncated = np.zeros(self._env.n_agents, dtype=bool), np.zeros( + self._env.n_agents, dtype=bool + ) actions_mask = self._get_actions_mask(extra) - - - return np.array(obs), actions_mask, reward, terminated, truncated, extra - - def step(self , actions : np.array): - - if self._reset_next_step and not self.eval_env: #only auto-reset in training envs. + + return np.array(obs), actions_mask, reward, terminated, truncated, extra + + def step(self, actions: NDArray) -> Tuple: + + if self._reset_next_step and not self.eval_env: # only auto-reset in training envs. return self.reset() - + obs, reward, terminated, truncated, extra = self.env.step(actions) - + terminated, truncated = np.array(terminated), np.array(truncated) - - done = np.logical_or(terminated, truncated).all() - - if done and not self.eval_env: #only auto-reset in training envs, same functionality as the AutoResetWrapper. + + done = np.logical_or(terminated, truncated).all() + + if ( + done and not self.eval_env + ): # only auto-reset in training envs, same functionality as the AutoResetWrapper. 
return self.reset() - + actions_mask = self._get_actions_mask(extra) - - if self.use_individual_rewards: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) - - return np.array(obs), actions_mask, reward, terminated, truncated, extra - - - def _get_actions_mask(self, extra : Dict) -> np.array: + + return np.array(obs), actions_mask, reward, terminated, truncated, extra + + def _get_actions_mask(self, extra: Dict) -> NDArray: if "action_mask" in extra: - return np.array(extra["action_mask"]) + return np.array(extra["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - - - - - - - - - - - - \ No newline at end of file diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 88c61ce0f..3b3bc4c58 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -17,8 +17,8 @@ numpy omegaconf optax protobuf~=3.20 +rware @ git+https://github.com/RuanJohn/robotic-warehouse.git scipy==1.12.0 tensorboard_logger tensorflow_probability type_enforced # needed because gigastep is missing this dependency -rware @ git+https://github.com/RuanJohn/robotic-warehouse.git \ No newline at end of file From d5edf4540092e98c44832863950f23ef976a64b2 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 14 Jun 2024 12:00:56 +0100 Subject: [PATCH 003/125] fix: merged the observations and action mask --- mava/utils/make_env.py | 4 +++- mava/wrappers/gym.py | 20 ++++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 44758b41d..22419a4bb 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -204,7 +204,9 @@ def make_gigastep_env( def make_gym_env( env_name: str, config: DictConfig, add_global_state: bool = False -) -> Tuple[Environment, Environment]: #todo : create the appropriate annotation for the sync vector +) -> Tuple[ + Environment, Environment +]: # todo : create the appropriate annotation for the sync vector """ Create a Gym environment. diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 9c4d8b74d..f634dcc46 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -18,6 +18,8 @@ import numpy as np from numpy.typing import NDArray +from mava.types import Observation + class GymWrapper(gym.Wrapper): """Wrapper for gym environments""" @@ -48,9 +50,10 @@ def __init__( self.num_actions = self._env.action_space[ 0 ].n # todo: all the agents must have the same num_actions, add assertion? + self.step_count = 0 # todo : make sure this implementaion is correct def reset(self) -> Tuple: - obs, extra = self._env.reset( + agents_view, extra = self._env.reset( seed=np.random.randint(1), option={} ) # todo: assure reproducibility reward = np.zeros(self._env.n_agents) @@ -59,14 +62,19 @@ def reset(self) -> Tuple: ) actions_mask = self._get_actions_mask(extra) - return np.array(obs), actions_mask, reward, terminated, truncated, extra + obs = Observation( + agents_view=np.array(agents_view), action_mask=actions_mask, step_count=self.step_count + ) + + return obs, reward, terminated, truncated, extra def step(self, actions: NDArray) -> Tuple: + self.step_count += 1 if self._reset_next_step and not self.eval_env: # only auto-reset in training envs. 
return self.reset() - obs, reward, terminated, truncated, extra = self.env.step(actions) + agents_view, reward, terminated, truncated, extra = self.env.step(actions) terminated, truncated = np.array(terminated), np.array(truncated) @@ -84,7 +92,11 @@ def step(self, actions: NDArray) -> Tuple: else: reward = np.array([np.array(reward).mean()] * self.num_agents) - return np.array(obs), actions_mask, reward, terminated, truncated, extra + obs = Observation( + agents_view=np.array(agents_view), action_mask=actions_mask, step_count=self.step_count + ) + + return obs, reward, terminated, truncated, extra def _get_actions_mask(self, extra: Dict) -> NDArray: if "action_mask" in extra: From f891be555886f0a1ed415683bb499cf32605eb4c Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 14 Jun 2024 12:38:00 +0100 Subject: [PATCH 004/125] fix: Create the gym wrappers directly --- mava/utils/make_env.py | 14 +++++--------- mava/wrappers/gym.py | 3 ++- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 22419a4bb..ed4cec124 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -203,7 +203,7 @@ def make_gigastep_env( def make_gym_env( - env_name: str, config: DictConfig, add_global_state: bool = False + env_name: str, config: DictConfig, add_global_state: bool = False , eval_env : bool = False ) -> Tuple[ Environment, Environment ]: # todo : create the appropriate annotation for the sync vector @@ -229,13 +229,11 @@ def create_gym_env( return wrapped_env num_env = config.arch.num_envs - train_env = gym.vector.async_vector_env( - [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)] - ) - eval_env = gym.vector.async_vector_env( - [create_gym_env(config, add_global_state, eval_env=True) for _ in range(num_env)] + envs = gym.vector.async_vector_env( + [lambda: create_gym_env(config, add_global_state, eval_env=eval_env) for _ in range(num_env)] ) - return train_env, eval_env + + return envs def make(config: DictConfig, add_global_state: bool = False) -> Tuple[Environment, Environment]: @@ -259,7 +257,5 @@ def make(config: DictConfig, add_global_state: bool = False) -> Tuple[Environmen return make_matrax_env(env_name, config, add_global_state) elif env_name in _gigastep_registry: return make_gigastep_env(env_name, config, add_global_state) - elif env_name.startswith("gym"): - return make_gym_env(env_name, config, add_global_state) else: raise ValueError(f"{env_name} is not a supported environment.") diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index f634dcc46..2c06f7e86 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -71,7 +71,7 @@ def reset(self) -> Tuple: def step(self, actions: NDArray) -> Tuple: self.step_count += 1 - if self._reset_next_step and not self.eval_env: # only auto-reset in training envs. + if self._reset_next_step and not self.eval_env: # only auto-reset in training envs. 
todo: turn this into a sepreat wrapper return self.reset() agents_view, reward, terminated, truncated, extra = self.env.step(actions) @@ -102,3 +102,4 @@ def _get_actions_mask(self, extra: Dict) -> NDArray: if "action_mask" in extra: return np.array(extra["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + From 15f486709e6387dddce83900bed95b85521260e4 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 14 Jun 2024 12:39:10 +0100 Subject: [PATCH 005/125] chore: pre-commit --- mava/utils/make_env.py | 13 +++++++------ mava/wrappers/gym.py | 5 +++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index ed4cec124..01d2a2eb0 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -203,10 +203,8 @@ def make_gigastep_env( def make_gym_env( - env_name: str, config: DictConfig, add_global_state: bool = False , eval_env : bool = False -) -> Tuple[ - Environment, Environment -]: # todo : create the appropriate annotation for the sync vector + env_name: str, config: DictConfig, add_global_state: bool = False, eval_env: bool = False +) -> Environment: # todo : create the appropriate annotation for the sync vector """ Create a Gym environment. @@ -230,9 +228,12 @@ def create_gym_env( num_env = config.arch.num_envs envs = gym.vector.async_vector_env( - [lambda: create_gym_env(config, add_global_state, eval_env=eval_env) for _ in range(num_env)] + [ + lambda: create_gym_env(config, add_global_state, eval_env=eval_env) + for _ in range(num_env) + ] ) - + return envs diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 2c06f7e86..0cbfbc751 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -71,7 +71,9 @@ def reset(self) -> Tuple: def step(self, actions: NDArray) -> Tuple: self.step_count += 1 - if self._reset_next_step and not self.eval_env: # only auto-reset in training envs. todo: turn this into a sepreat wrapper + if ( + self._reset_next_step and not self.eval_env + ): # only auto-reset in training envs. 
todo: turn this into a sepreat wrapper return self.reset() agents_view, reward, terminated, truncated, extra = self.env.step(actions) @@ -102,4 +104,3 @@ def _get_actions_mask(self, extra: Dict) -> NDArray: if "action_mask" in extra: return np.array(extra["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - From 82ea827e0e7cf0bcc8ab269877050064ca25b3b7 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 14 Jun 2024 12:47:54 +0100 Subject: [PATCH 006/125] fix: fixed the async env creation --- mava/utils/make_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 01d2a2eb0..d40249c54 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -227,7 +227,7 @@ def create_gym_env( return wrapped_env num_env = config.arch.num_envs - envs = gym.vector.async_vector_env( + envs = gym.vector.AsyncVectorEnv( [ lambda: create_gym_env(config, add_global_state, eval_env=eval_env) for _ in range(num_env) From 4e94df57880b4c6370e2da4489961e5339044eb8 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 14 Jun 2024 14:34:50 +0100 Subject: [PATCH 007/125] fix: gymV26 compatability wrapper --- mava/configs/env/gym.yaml | 21 +++++++++++++++++++++ mava/utils/make_env.py | 4 ++++ 2 files changed, 25 insertions(+) create mode 100644 mava/configs/env/gym.yaml diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym.yaml new file mode 100644 index 000000000..ad8d16b9a --- /dev/null +++ b/mava/configs/env/gym.yaml @@ -0,0 +1,21 @@ +# ---Environment Configs--- + +scenario: rware:rware-tiny-2ag-v1 # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] + +env_name: RobotWarehouse # Used for logging purposes. + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: False + +use_individual_rewards: True + +kwargs: + time_limit: 500 diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index d40249c54..806883786 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -16,6 +16,7 @@ import gym import gym.vector +import gym.wrappers import jaxmarl import jumanji import matrax @@ -221,6 +222,9 @@ def create_gym_env( config: DictConfig, add_global_state: bool = False, eval_env: bool = False ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. env = gym.make(config.env.scenario) + env = gym.wrappers.EnvCompatibility( + env + ) # todo: check if this will break if env is developed for v26 wrapped_env = GymWrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) if not config.env.implicit_agent_id: pass # todo : add agent id wrapper for gym . 
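The two fixes above (the gym.vector.AsyncVectorEnv naming fix and the v26 compatibility wrapper) both serve the standard Gym vectorisation pattern that make_gym_env is building towards: a list of factory callables, one per parallel environment, handed to gym.vector.AsyncVectorEnv, which then exposes the batched v26 reset/step API. A minimal, self-contained sketch of that pattern, assuming gym>=0.26; "CartPole-v1" and num_envs = 4 are illustrative stand-ins for the wrapped rware environment returned by create_gym_env and for config.arch.num_envs, not part of any commit:

import gym

num_envs = 4  # stand-in for config.arch.num_envs

def make_one() -> gym.Env:
    # create_gym_env returns its wrapped rware environment here; a built-in
    # env is used so the sketch runs on its own.
    return gym.make("CartPole-v1")

envs = gym.vector.AsyncVectorEnv([make_one for _ in range(num_envs)])
obs, info = envs.reset(seed=0)  # batched v26 reset -> (obs, info)
obs, reward, terminated, truncated, info = envs.step(envs.action_space.sample())
envs.close()
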
From 8a86be98f4f422bfaa627d10eb27c88bb40557ae Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 15 Jun 2024 15:36:31 +0100 Subject: [PATCH 008/125] fix: various minor fixes --- mava/utils/make_env.py | 6 ++++-- mava/wrappers/gym.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 806883786..1515cca0c 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -17,6 +17,8 @@ import gym import gym.vector import gym.wrappers +import gym.wrappers +import gym.wrappers.compatibility import jaxmarl import jumanji import matrax @@ -204,7 +206,7 @@ def make_gigastep_env( def make_gym_env( - env_name: str, config: DictConfig, add_global_state: bool = False, eval_env: bool = False + config: DictConfig, add_global_state: bool = False, eval_env: bool = False ) -> Environment: # todo : create the appropriate annotation for the sync vector """ Create a Gym environment. @@ -222,7 +224,7 @@ def create_gym_env( config: DictConfig, add_global_state: bool = False, eval_env: bool = False ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. env = gym.make(config.env.scenario) - env = gym.wrappers.EnvCompatibility( + env = gym.wrappers.compatibility.EnvCompatibility( env ) # todo: check if this will break if env is developed for v26 wrapped_env = GymWrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 0cbfbc751..99b56d621 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -26,7 +26,7 @@ class GymWrapper(gym.Wrapper): def __init__( self, - env: gym.env, + env: gym.Env, use_individual_rewards: bool = False, add_global_state: bool = False, eval_env: bool = False, @@ -46,7 +46,7 @@ def __init__( self.use_individual_rewards = use_individual_rewards self.add_global_state = add_global_state # todo : add the global observations self.eval_env = eval_env - self.num_agents = self._env.n_agents + self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[ 0 ].n # todo: all the agents must have the same num_actions, add assertion? 
@@ -54,11 +54,11 @@ def __init__( def reset(self) -> Tuple: agents_view, extra = self._env.reset( - seed=np.random.randint(1), option={} + seed=np.random.randint(1) ) # todo: assure reproducibility - reward = np.zeros(self._env.n_agents) - terminated, truncated = np.zeros(self._env.n_agents, dtype=bool), np.zeros( - self._env.n_agents, dtype=bool + reward = np.zeros(self.num_agents) + terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( + self.num_agents, dtype=bool ) actions_mask = self._get_actions_mask(extra) @@ -103,4 +103,4 @@ def step(self, actions: NDArray) -> Tuple: def _get_actions_mask(self, extra: Dict) -> NDArray: if "action_mask" in extra: return np.array(extra["action_mask"]) - return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + return np.ones((self.num_agents, self.num_actions), dtype=np.float32) \ No newline at end of file From 1da5c15b13c74c8286819cca9b36277bf8030a27 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 15 Jun 2024 16:09:16 +0100 Subject: [PATCH 009/125] fix: handling rware reset function --- mava/utils/make_env.py | 2 +- mava/wrappers/gym.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 1515cca0c..1e2721dc6 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -226,7 +226,7 @@ def create_gym_env( env = gym.make(config.env.scenario) env = gym.wrappers.compatibility.EnvCompatibility( env - ) # todo: check if this will break if env is developed for v26 + ) wrapped_env = GymWrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) if not config.env.implicit_agent_id: pass # todo : add agent id wrapper for gym . diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 99b56d621..fff21a899 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -53,9 +53,9 @@ def __init__( self.step_count = 0 # todo : make sure this implementaion is correct def reset(self) -> Tuple: - agents_view, extra = self._env.reset( + (agents_view, extra), _ = self._env.reset( seed=np.random.randint(1) - ) # todo: assure reproducibility + ) # todo: assure reproducibility, this only works for rware reward = np.zeros(self.num_agents) terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( self.num_agents, dtype=bool From 4466044d07541fb3e48b56f42c26be2a235a3e31 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sun, 16 Jun 2024 18:58:27 +0100 Subject: [PATCH 010/125] feat: async env wrapper , changed the gym wrapper to rware wrapper --- mava/configs/default_ff_ippo_seb.yaml | 7 +++ mava/systems/sebulba/ppo/test.py | 50 ++++++++++++++++++ mava/utils/make_env.py | 19 ++++--- mava/wrappers/__init__.py | 2 +- mava/wrappers/gym.py | 75 +++++++++++++++++---------- 5 files changed, 117 insertions(+), 36 deletions(-) create mode 100644 mava/configs/default_ff_ippo_seb.yaml create mode 100644 mava/systems/sebulba/ppo/test.py diff --git a/mava/configs/default_ff_ippo_seb.yaml b/mava/configs/default_ff_ippo_seb.yaml new file mode 100644 index 000000000..1002d90c4 --- /dev/null +++ b/mava/configs/default_ff_ippo_seb.yaml @@ -0,0 +1,7 @@ +defaults: + - logger: ff_ippo + - arch: sebulba + - system: ppo/ff_ippo + - network: mlp + - env: gym + - _self_ diff --git a/mava/systems/sebulba/ppo/test.py b/mava/systems/sebulba/ppo/test.py new file mode 100644 index 000000000..b868f69b6 --- /dev/null +++ b/mava/systems/sebulba/ppo/test.py @@ -0,0 +1,50 @@ + +import copy +import time +from typing import Any, Dict, Tuple, List 
+import threading +import chex +import flax +import hydra +import jax +import jax.numpy as jnp +import numpy as np +import optax +import queue +from collections import deque +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from omegaconf import DictConfig, OmegaConf +from optax._src.base import OptState +from rich.pretty import pprint + +from mava.evaluator import make_eval_fns +from mava.networks import FeedForwardActor as Actor +from mava.networks import FeedForwardValueNet as Critic +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation +from mava.utils import make_env as environments +from mava.utils.checkpointing import Checkpointer +from mava.utils.jax_utils import ( + merge_leading_dims, + unreplicate_batch_dim, + unreplicate_n_dims, +) +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.training import make_learning_rate +from mava.wrappers.episode_metrics import get_final_step_metrics + + +@hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + env = environments.make_gym_env(cfg) + a = env.reset() + print(a) + +if __name__ == "__main__": + hydra_entry_point() \ No newline at end of file diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 1e2721dc6..61b379fd7 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -45,7 +45,8 @@ CleanerWrapper, ConnectorWrapper, GigastepWrapper, - GymWrapper, + GymRwareWrapper, + AsyncGymWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, @@ -69,6 +70,8 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} +_gym_registry = {"rware" : GymRwareWrapper} + def add_extra_wrappers( train_env: Environment, eval_env: Environment, config: DictConfig @@ -219,27 +222,27 @@ def make_gym_env( Returns: A tuple of the environments. """ + base_env_name = config.env.scenario.split(":")[0] + wrapper = _gym_registry[base_env_name] def create_gym_env( config: DictConfig, add_global_state: bool = False, eval_env: bool = False ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. env = gym.make(config.env.scenario) - env = gym.wrappers.compatibility.EnvCompatibility( - env - ) - wrapped_env = GymWrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) + _gym_registry + wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) if not config.env.implicit_agent_id: pass # todo : add agent id wrapper for gym . 
return wrapped_env - num_env = config.arch.num_envs - envs = gym.vector.AsyncVectorEnv( + num_env = config.arch.num_envs + envs = gym.vector.AsyncVectorEnv( #todo : give them more descriptive names [ lambda: create_gym_env(config, add_global_state, eval_env=eval_env) for _ in range(num_env) ] ) - + envs = AsyncGymWrapper(envs) return envs diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 14a679cac..6210ca6ed 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymWrapper +from mava.wrappers.gym import GymRwareWrapper, AsyncGymWrapper from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index fff21a899..bc71e3e81 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -21,8 +21,8 @@ from mava.types import Observation -class GymWrapper(gym.Wrapper): - """Wrapper for gym environments""" +class GymRwareWrapper(gym.Wrapper): + """Wrapper for rware gym environments""" def __init__( self, @@ -42,7 +42,7 @@ def __init__( Defaults to False. """ super().__init__(env) - self._env = env + self._env = gym.wrappers.compatibility.EnvCompatibility(env) self.use_individual_rewards = use_individual_rewards self.add_global_state = add_global_state # todo : add the global observations self.eval_env = eval_env @@ -50,33 +50,29 @@ def __init__( self.num_actions = self._env.action_space[ 0 ].n # todo: all the agents must have the same num_actions, add assertion? - self.step_count = 0 # todo : make sure this implementaion is correct def reset(self) -> Tuple: - (agents_view, extra), _ = self._env.reset( + (agents_view, info), _ = self._env.reset( seed=np.random.randint(1) ) # todo: assure reproducibility, this only works for rware - reward = np.zeros(self.num_agents) - terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( - self.num_agents, dtype=bool - ) - actions_mask = self._get_actions_mask(extra) - - obs = Observation( - agents_view=np.array(agents_view), action_mask=actions_mask, step_count=self.step_count - ) - return obs, reward, terminated, truncated, extra + info["action_mask"] = self._get_actions_mask(info) + + return np.array(agents_view), info def step(self, actions: NDArray) -> Tuple: - self.step_count += 1 if ( self._reset_next_step and not self.eval_env ): # only auto-reset in training envs. todo: turn this into a sepreat wrapper - return self.reset() + agents_view, info = self.reset() + reward = np.zeros(self.num_agents) + terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( + self.num_agents, dtype=bool + ) + return agents_view, reward, terminated, truncated, info - agents_view, reward, terminated, truncated, extra = self.env.step(actions) + agents_view, reward, terminated, truncated, info = self.env.step(actions) terminated, truncated = np.array(terminated), np.array(truncated) @@ -87,20 +83,45 @@ def step(self, actions: NDArray) -> Tuple: ): # only auto-reset in training envs, same functionality as the AutoResetWrapper. 
return self.reset() - actions_mask = self._get_actions_mask(extra) + info["action_mask"] = self._get_actions_mask(info) if self.use_individual_rewards: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) - obs = Observation( - agents_view=np.array(agents_view), action_mask=actions_mask, step_count=self.step_count - ) - return obs, reward, terminated, truncated, extra + return agents_view, reward, terminated, truncated, info - def _get_actions_mask(self, extra: Dict) -> NDArray: - if "action_mask" in extra: - return np.array(extra["action_mask"]) - return np.ones((self.num_agents, self.num_actions), dtype=np.float32) \ No newline at end of file + def _get_actions_mask(self, info: Dict) -> NDArray: + if "action_mask" in info: + return np.array(info["action_mask"]) + return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + + +class AsyncGymWrapper: + """Wrapper for async gym environments""" + + def __init__(self, env: gym.vector.AsyncVectorEnv): + self._env = env + self.step_count = 0 #todo : make sure this is implemented correctly + + def reset(self) -> Tuple[Observation, Dict]: + agents_view , info = self._env.reset() + obs = self._create_obs(agents_view, info) + return obs, info + + def step(self) -> Tuple[Observation, NDArray, NDArray, NDArray, Dict]: + + self.step_count += 1 + agents_view, reward, terminated, truncated, info = self._env.step() + obs = self._create_obs(agents_view, info) + + return obs, reward, terminated, truncated, info + + + def _create_obs(self, agents_view : NDArray, info: Dict) -> Observation: + """Create the observations""" + agents_view = np.array(agents_view) + return Observation(agents_view=agents_view, action_mask=info["action_mask"], step_count=self.step_count) + \ No newline at end of file From 24d8aaefb596904e5fd9e0be813947405a3ecdaa Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sun, 16 Jun 2024 22:43:55 +0100 Subject: [PATCH 011/125] fix: fixed the async env wrapper --- mava/wrappers/gym.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index bc71e3e81..2c6597830 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -62,26 +62,19 @@ def reset(self) -> Tuple: def step(self, actions: NDArray) -> Tuple: - if ( - self._reset_next_step and not self.eval_env - ): # only auto-reset in training envs. todo: turn this into a sepreat wrapper - agents_view, info = self.reset() - reward = np.zeros(self.num_agents) - terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( - self.num_agents, dtype=bool - ) - return agents_view, reward, terminated, truncated, info - agents_view, reward, terminated, truncated, info = self.env.step(actions) - terminated, truncated = np.array(terminated), np.array(truncated) - done = np.logical_or(terminated, truncated).all() if ( done and not self.eval_env ): # only auto-reset in training envs, same functionality as the AutoResetWrapper. 
- return self.reset() + agents_view, info = self.reset() + reward = np.zeros(self.num_agents) + terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( + self.num_agents, dtype=bool + ) + return agents_view, reward, terminated, truncated, info info["action_mask"] = self._get_actions_mask(info) @@ -99,22 +92,29 @@ def _get_actions_mask(self, info: Dict) -> NDArray: return np.ones((self.num_agents, self.num_actions), dtype=np.float32) -class AsyncGymWrapper: +class AsyncGymWrapper(gym.Wrapper): """Wrapper for async gym environments""" def __init__(self, env: gym.vector.AsyncVectorEnv): + super().__init__(env) self._env = env self.step_count = 0 #todo : make sure this is implemented correctly + action_space = env.single_action_space + self.num_agents = len(action_space) + self.num_actions = action_space[0].n + self.num_envs = env.num_envs + def reset(self) -> Tuple[Observation, Dict]: agents_view , info = self._env.reset() obs = self._create_obs(agents_view, info) - return obs, info + dones = np.zeros((self.num_envs, 1)) + return obs, dones, info - def step(self) -> Tuple[Observation, NDArray, NDArray, NDArray, Dict]: + def step(self, actions : NDArray) -> Tuple[Observation, NDArray, NDArray, NDArray, Dict]: self.step_count += 1 - agents_view, reward, terminated, truncated, info = self._env.step() + agents_view, reward, terminated, truncated, info = self._env.step(actions) obs = self._create_obs(agents_view, info) return obs, reward, terminated, truncated, info From a6deae270fbbd8bbb81c8fc507e5c974f10f66df Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 18 Jun 2024 16:24:16 +0100 Subject: [PATCH 012/125] fix: info only contains the action_mask and reformated (n_agents, n_env) ->(n_env, n_agents) --- mava/utils/make_env.py | 1 - mava/wrappers/gym.py | 15 ++++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 61b379fd7..7f5a5a0fb 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -229,7 +229,6 @@ def create_gym_env( config: DictConfig, add_global_state: bool = False, eval_env: bool = False ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. env = gym.make(config.env.scenario) - _gym_registry wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) if not config.env.implicit_agent_id: pass # todo : add agent id wrapper for gym . 
diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 2c6597830..be4fe40fc 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -56,7 +56,7 @@ def reset(self) -> Tuple: seed=np.random.randint(1) ) # todo: assure reproducibility, this only works for rware - info["action_mask"] = self._get_actions_mask(info) + info = {"action_mask" : self._get_actions_mask(info)} return np.array(agents_view), info @@ -76,7 +76,7 @@ def step(self, actions: NDArray) -> Tuple: ) return agents_view, reward, terminated, truncated, info - info["action_mask"] = self._get_actions_mask(info) + info = {"action_mask" : self._get_actions_mask(info)} if self.use_individual_rewards: reward = np.array(reward) @@ -108,20 +108,21 @@ def __init__(self, env: gym.vector.AsyncVectorEnv): def reset(self) -> Tuple[Observation, Dict]: agents_view , info = self._env.reset() obs = self._create_obs(agents_view, info) - dones = np.zeros((self.num_envs, 1)) + dones = np.zeros((self.num_envs, self.num_agents)) return obs, dones, info def step(self, actions : NDArray) -> Tuple[Observation, NDArray, NDArray, NDArray, Dict]: self.step_count += 1 + actions = actions.swapaxes(0,1) # num_env, num_ags --> num_ags, num_env as expected by the async env agents_view, reward, terminated, truncated, info = self._env.step(actions) obs = self._create_obs(agents_view, info) - - return obs, reward, terminated, truncated, info + dones = np.logical_or(terminated, truncated) + return obs, reward, dones, info def _create_obs(self, agents_view : NDArray, info: Dict) -> Observation: """Create the observations""" - agents_view = np.array(agents_view) - return Observation(agents_view=agents_view, action_mask=info["action_mask"], step_count=self.step_count) + agents_view = np.stack(agents_view, axis = 1) + return Observation(agents_view=agents_view, action_mask=np.stack(info["action_mask"], axis = 0), step_count=self.step_count) \ No newline at end of file From 1475bd0d7ae465dbdce4b86aa02d55df487ae588 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 22 Jun 2024 12:02:19 +0100 Subject: [PATCH 013/125] chore: removed async gym wrapper --- mava/systems/sebulba/ppo/types.py | 99 +++++++++++++++++++++++++++++++ mava/utils/make_env.py | 3 +- mava/wrappers/gym.py | 44 ++------------ 3 files changed, 106 insertions(+), 40 deletions(-) create mode 100644 mava/systems/sebulba/ppo/types.py diff --git a/mava/systems/sebulba/ppo/types.py b/mava/systems/sebulba/ppo/types.py new file mode 100644 index 000000000..13aeb58c1 --- /dev/null +++ b/mava/systems/sebulba/ppo/types.py @@ -0,0 +1,99 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict + +import chex +from flax.core.frozen_dict import FrozenDict +from jumanji.types import TimeStep +from optax._src.base import OptState +from typing_extensions import NamedTuple + +from mava.types import Action, Done, HiddenState, State, Value + + +class Params(NamedTuple): + """Parameters of an actor critic network.""" + + actor_params: FrozenDict + critic_params: FrozenDict + + +class OptStates(NamedTuple): + """OptStates of actor critic learner.""" + + actor_opt_state: OptState + critic_opt_state: OptState + + +class HiddenStates(NamedTuple): + """Hidden states for an actor critic learner.""" + + policy_hidden_state: HiddenState + critic_hidden_state: HiddenState + + +class LearnerState(NamedTuple): + """State of the learner.""" + + params: Params + opt_states: OptStates + key: chex.PRNGKey + env_state: State + timestep: TimeStep + + +class RNNLearnerState(NamedTuple): + """State of the `Learner` for recurrent architectures.""" + + params: Params + opt_states: OptStates + key: chex.PRNGKey + env_state: State + timestep: TimeStep + dones: Done + hstates: HiddenStates + + +class PPOTransition(NamedTuple): + """Transition tuple for PPO.""" + + done: Done + action: Action + value: Value + reward: chex.Array + log_prob: chex.Array + obs: chex.Array + info : Dict + +class RNNPPOTransition(NamedTuple): + """Transition tuple for PPO.""" + + done: Done + action: Action + value: Value + reward: chex.Array + log_prob: chex.Array + obs: chex.Array + hstates: HiddenStates + + +class Observation(NamedTuple): + """The observation that the agent sees. + agents_view: the agent's view of the environment. + action_mask: boolean array specifying, for each agent, which action is legal. + """ + + agents_view: chex.Array # (num_agents, num_obs_features) + action_mask: chex.Array # (num_agents, num_actions) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 7f5a5a0fb..8ee391f0c 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -46,7 +46,6 @@ ConnectorWrapper, GigastepWrapper, GymRwareWrapper, - AsyncGymWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, @@ -241,7 +240,7 @@ def create_gym_env( for _ in range(num_env) ] ) - envs = AsyncGymWrapper(envs) + return envs diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index be4fe40fc..f48c34fcf 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -17,10 +17,14 @@ import gym import numpy as np from numpy.typing import NDArray +import warnings from mava.types import Observation +# Filter out the warnings +warnings.filterwarnings('ignore', module='gym.utils.passive_env_checker') + class GymRwareWrapper(gym.Wrapper): """Wrapper for rware gym environments""" @@ -56,7 +60,7 @@ def reset(self) -> Tuple: seed=np.random.randint(1) ) # todo: assure reproducibility, this only works for rware - info = {"action_mask" : self._get_actions_mask(info)} + info = {"actions_mask" : self._get_actions_mask(info)} return np.array(agents_view), info @@ -76,7 +80,7 @@ def step(self, actions: NDArray) -> Tuple: ) return agents_view, reward, terminated, truncated, info - info = {"action_mask" : self._get_actions_mask(info)} + info = {"actions_mask" : self._get_actions_mask(info)} if self.use_individual_rewards: reward = np.array(reward) @@ -90,39 +94,3 @@ def _get_actions_mask(self, info: Dict) -> NDArray: if "action_mask" in info: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - - -class AsyncGymWrapper(gym.Wrapper): - """Wrapper for async gym 
environments""" - - def __init__(self, env: gym.vector.AsyncVectorEnv): - super().__init__(env) - self._env = env - self.step_count = 0 #todo : make sure this is implemented correctly - - action_space = env.single_action_space - self.num_agents = len(action_space) - self.num_actions = action_space[0].n - self.num_envs = env.num_envs - - def reset(self) -> Tuple[Observation, Dict]: - agents_view , info = self._env.reset() - obs = self._create_obs(agents_view, info) - dones = np.zeros((self.num_envs, self.num_agents)) - return obs, dones, info - - def step(self, actions : NDArray) -> Tuple[Observation, NDArray, NDArray, NDArray, Dict]: - - self.step_count += 1 - actions = actions.swapaxes(0,1) # num_env, num_ags --> num_ags, num_env as expected by the async env - agents_view, reward, terminated, truncated, info = self._env.step(actions) - obs = self._create_obs(agents_view, info) - dones = np.logical_or(terminated, truncated) - return obs, reward, dones, info - - - def _create_obs(self, agents_view : NDArray, info: Dict) -> Observation: - """Create the observations""" - agents_view = np.stack(agents_view, axis = 1) - return Observation(agents_view=agents_view, action_mask=np.stack(info["action_mask"], axis = 0), step_count=self.step_count) - \ No newline at end of file From 9fce9c6a463780103bd5e72279fb8e13121d5351 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 22 Jun 2024 14:08:15 +0100 Subject: [PATCH 014/125] feat: gym metric tracker wrapper --- mava/systems/sebulba/ppo/types.py | 3 +- mava/utils/make_env.py | 11 ++--- mava/wrappers/__init__.py | 2 +- mava/wrappers/gym.py | 70 +++++++++++++++++++++++++++---- 4 files changed, 70 insertions(+), 16 deletions(-) diff --git a/mava/systems/sebulba/ppo/types.py b/mava/systems/sebulba/ppo/types.py index 13aeb58c1..6e02aa904 100644 --- a/mava/systems/sebulba/ppo/types.py +++ b/mava/systems/sebulba/ppo/types.py @@ -75,7 +75,8 @@ class PPOTransition(NamedTuple): reward: chex.Array log_prob: chex.Array obs: chex.Array - info : Dict + info: Dict + class RNNPPOTransition(NamedTuple): """Transition tuple for PPO.""" diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 8ee391f0c..69fc54623 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -17,7 +17,6 @@ import gym import gym.vector import gym.wrappers -import gym.wrappers import gym.wrappers.compatibility import jaxmarl import jumanji @@ -45,6 +44,7 @@ CleanerWrapper, ConnectorWrapper, GigastepWrapper, + GymRecordEpisodeMetrics, GymRwareWrapper, LbfWrapper, MabraxWrapper, @@ -69,7 +69,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"rware" : GymRwareWrapper} +_gym_registry = {"rware": GymRwareWrapper} def add_extra_wrappers( @@ -231,16 +231,17 @@ def create_gym_env( wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) if not config.env.implicit_agent_id: pass # todo : add agent id wrapper for gym . 
+ env = GymRecordEpisodeMetrics(env) return wrapped_env - num_env = config.arch.num_envs - envs = gym.vector.AsyncVectorEnv( #todo : give them more descriptive names + num_env = config.arch.num_envs + envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names [ lambda: create_gym_env(config, add_global_state, eval_env=eval_env) for _ in range(num_env) ] ) - + return envs diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 6210ca6ed..e888d9317 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRwareWrapper, AsyncGymWrapper +from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index f48c34fcf..69632f1bc 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Dict, Tuple import gym import numpy as np from numpy.typing import NDArray -import warnings - -from mava.types import Observation +# Filter out the warnings +warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") -# Filter out the warnings -warnings.filterwarnings('ignore', module='gym.utils.passive_env_checker') class GymRwareWrapper(gym.Wrapper): """Wrapper for rware gym environments""" @@ -60,8 +58,8 @@ def reset(self) -> Tuple: seed=np.random.randint(1) ) # todo: assure reproducibility, this only works for rware - info = {"actions_mask" : self._get_actions_mask(info)} - + info = {"actions_mask": self._get_actions_mask(info)} + return np.array(agents_view), info def step(self, actions: NDArray) -> Tuple: @@ -80,17 +78,71 @@ def step(self, actions: NDArray) -> Tuple: ) return agents_view, reward, terminated, truncated, info - info = {"actions_mask" : self._get_actions_mask(info)} + info = {"actions_mask": self._get_actions_mask(info)} if self.use_individual_rewards: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) - return agents_view, reward, terminated, truncated, info def _get_actions_mask(self, info: Dict) -> NDArray: if "action_mask" in info: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + + +class GymRecordEpisodeMetrics(gym.Wrapper): + """Record the episode returns and lengths.""" + + def __init__(self, env: gym.Env): + super().__init__(env) + self.running_count_episode_return = 0.0 + self.running_count_episode_length = 0 + + def reset(self) -> Tuple: + + # Reset the env + agents_view, info = self.env.reset() + + # Reset the metrics + self.running_count_episode_return = 0.0 + self.running_count_episode_length = 0 + + # Create the metrics dict + metrics = { + "episode_return": self.running_count_episode_return, + "episode_length": self.self.running_count_episode_length, + "is_terminal_step": False, + } + if "won_episode" in info: + metrics["won_episode"] = info["won_episode"] + + return agents_view, metrics + + def step(self, actions: NDArray) -> Tuple: + + # Step the env + agents_view, reward, terminated, truncated, info = self.env.step(actions) + + # Update the metrics + done = 
np.logical_or(terminated, truncated).all() + + if not done: + self.running_count_episode_return += float(np.mean(reward)) + self.running_count_episode_length += 1 + + else: + self.running_count_episode_return = 0.0 + self.running_count_episode_length = 0 + + metrics = { + "episode_return": self.running_count_episode_return, + "episode_length": self.self.running_count_episode_length, + "is_terminal_step": False, + } + if "won_episode" in info: + metrics["won_episode"] = info["won_episode"] + + return agents_view, reward, terminated, truncated, metrics From 055a3266accb82a96808fa95762314dac45646d3 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 10 Jun 2024 20:12:16 +0100 Subject: [PATCH 015/125] feat: init sebulba ippo --- mava/systems/{ => anakin}/__init__.py | 0 mava/systems/{ => anakin}/ppo/__init__.py | 0 mava/systems/{ => anakin}/ppo/ff_ippo.py | 0 mava/systems/{ => anakin}/ppo/ff_mappo.py | 0 mava/systems/{ => anakin}/ppo/rec_ippo.py | 0 mava/systems/{ => anakin}/ppo/rec_mappo.py | 0 mava/systems/{ => anakin}/ppo/types.py | 0 .../{ => anakin}/q_learning/__init__.py | 0 .../{ => anakin}/q_learning/rec_iql.py | 0 mava/systems/{ => anakin}/q_learning/types.py | 0 mava/systems/{ => anakin}/sac/__init__.py | 0 mava/systems/{ => anakin}/sac/ff_isac.py | 0 mava/systems/{ => anakin}/sac/ff_masac.py | 0 mava/systems/{ => anakin}/sac/types.py | 0 mava/systems/sebulba/ppo/ff_ippo.py | 596 +++++++++++++ mava/systems/sebulba/ppo/orig.py | 796 ++++++++++++++++++ 16 files changed, 1392 insertions(+) rename mava/systems/{ => anakin}/__init__.py (100%) rename mava/systems/{ => anakin}/ppo/__init__.py (100%) rename mava/systems/{ => anakin}/ppo/ff_ippo.py (100%) rename mava/systems/{ => anakin}/ppo/ff_mappo.py (100%) rename mava/systems/{ => anakin}/ppo/rec_ippo.py (100%) rename mava/systems/{ => anakin}/ppo/rec_mappo.py (100%) rename mava/systems/{ => anakin}/ppo/types.py (100%) rename mava/systems/{ => anakin}/q_learning/__init__.py (100%) rename mava/systems/{ => anakin}/q_learning/rec_iql.py (100%) rename mava/systems/{ => anakin}/q_learning/types.py (100%) rename mava/systems/{ => anakin}/sac/__init__.py (100%) rename mava/systems/{ => anakin}/sac/ff_isac.py (100%) rename mava/systems/{ => anakin}/sac/ff_masac.py (100%) rename mava/systems/{ => anakin}/sac/types.py (100%) create mode 100644 mava/systems/sebulba/ppo/ff_ippo.py create mode 100644 mava/systems/sebulba/ppo/orig.py diff --git a/mava/systems/__init__.py b/mava/systems/anakin/__init__.py similarity index 100% rename from mava/systems/__init__.py rename to mava/systems/anakin/__init__.py diff --git a/mava/systems/ppo/__init__.py b/mava/systems/anakin/ppo/__init__.py similarity index 100% rename from mava/systems/ppo/__init__.py rename to mava/systems/anakin/ppo/__init__.py diff --git a/mava/systems/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py similarity index 100% rename from mava/systems/ppo/ff_ippo.py rename to mava/systems/anakin/ppo/ff_ippo.py diff --git a/mava/systems/ppo/ff_mappo.py b/mava/systems/anakin/ppo/ff_mappo.py similarity index 100% rename from mava/systems/ppo/ff_mappo.py rename to mava/systems/anakin/ppo/ff_mappo.py diff --git a/mava/systems/ppo/rec_ippo.py b/mava/systems/anakin/ppo/rec_ippo.py similarity index 100% rename from mava/systems/ppo/rec_ippo.py rename to mava/systems/anakin/ppo/rec_ippo.py diff --git a/mava/systems/ppo/rec_mappo.py b/mava/systems/anakin/ppo/rec_mappo.py similarity index 100% rename from mava/systems/ppo/rec_mappo.py rename to mava/systems/anakin/ppo/rec_mappo.py diff --git 
a/mava/systems/ppo/types.py b/mava/systems/anakin/ppo/types.py similarity index 100% rename from mava/systems/ppo/types.py rename to mava/systems/anakin/ppo/types.py diff --git a/mava/systems/q_learning/__init__.py b/mava/systems/anakin/q_learning/__init__.py similarity index 100% rename from mava/systems/q_learning/__init__.py rename to mava/systems/anakin/q_learning/__init__.py diff --git a/mava/systems/q_learning/rec_iql.py b/mava/systems/anakin/q_learning/rec_iql.py similarity index 100% rename from mava/systems/q_learning/rec_iql.py rename to mava/systems/anakin/q_learning/rec_iql.py diff --git a/mava/systems/q_learning/types.py b/mava/systems/anakin/q_learning/types.py similarity index 100% rename from mava/systems/q_learning/types.py rename to mava/systems/anakin/q_learning/types.py diff --git a/mava/systems/sac/__init__.py b/mava/systems/anakin/sac/__init__.py similarity index 100% rename from mava/systems/sac/__init__.py rename to mava/systems/anakin/sac/__init__.py diff --git a/mava/systems/sac/ff_isac.py b/mava/systems/anakin/sac/ff_isac.py similarity index 100% rename from mava/systems/sac/ff_isac.py rename to mava/systems/anakin/sac/ff_isac.py diff --git a/mava/systems/sac/ff_masac.py b/mava/systems/anakin/sac/ff_masac.py similarity index 100% rename from mava/systems/sac/ff_masac.py rename to mava/systems/anakin/sac/ff_masac.py diff --git a/mava/systems/sac/types.py b/mava/systems/anakin/sac/types.py similarity index 100% rename from mava/systems/sac/types.py rename to mava/systems/anakin/sac/types.py diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py new file mode 100644 index 000000000..c9a2069b2 --- /dev/null +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -0,0 +1,596 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import time +from typing import Any, Dict, Tuple + +import chex +import flax +import hydra +import jax +import jax.numpy as jnp +import optax +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from jumanji.env import Environment +from omegaconf import DictConfig, OmegaConf +from optax._src.base import OptState +from rich.pretty import pprint + +from mava.evaluator import make_eval_fns +from mava.networks import FeedForwardActor as Actor +from mava.networks import FeedForwardValueNet as Critic +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn +from mava.utils import make_env as environments +from mava.utils.checkpointing import Checkpointer +from mava.utils.jax_utils import ( + merge_leading_dims, + unreplicate_batch_dim, + unreplicate_n_dims, +) +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.training import make_learning_rate +from mava.wrappers.episode_metrics import get_final_step_metrics + + +def get_learner_fn( + env: Environment, + apply_fns: Tuple[ActorApply, CriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> LearnerFn[LearnerState]: + """Get the learner function.""" + + # Get apply and update functions for actor and critic networks. + actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + """A single update of the network. + + This function steps the environment and records the trajectory batch for + training. It then calculates advantages and targets based on the recorded + trajectory and updates the actor and critic networks based on the calculated + losses. + + Args: + learner_state (NamedTuple): + - params (Params): The current model parameters. + - opt_states (OptStates): The current optimizer states. + - key (PRNGKey): The random number generator state. + - env_state (State): The environment state. + - last_timestep (TimeStep): The last timestep in the current trajectory. + _ (Any): The current metrics info. 
+ """ + + def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + """Step the environment.""" + params, opt_states, key, env_state, last_timestep = learner_state + + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) + value = critic_apply_fn(params.critic_params, last_timestep.observation) + + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), + timestep.last(), + ) + info = timestep.extras["episode_metrics"] + + transition = PPOTransition( + done, action, value, timestep.reward, log_prob, last_timestep.observation, info + ) + learner_state = LearnerState(params, opt_states, key, env_state, timestep) + return learner_state, transition + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + # CALCULATE ADVANTAGE + params, opt_states, key, env_state, last_timestep = learner_state + last_val = critic_apply_fn(params.critic_params, last_timestep.observation) + + def _calculate_gae( + traj_batch: PPOTransition, last_val: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + """Calculate the GAE.""" + + def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: + """Calculate the GAE for a single transition.""" + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + gamma = config.system.gamma + delta = reward + gamma * next_value * (1 - done) - value + gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + advantages, targets = _calculate_gae(traj_batch, last_val) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # UNPACK TRAIN STATE AND BATCH INFO + params, opt_states, key = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + actor_opt_state: OptState, + traj_batch: PPOTransition, + gae: chex.Array, + key: chex.PRNGKey, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config.system.clip_eps, + 1.0 + config.system.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + # The seed will be used in the TanhTransformedDistribution: + entropy = actor_policy.entropy(seed=key).mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + return total_loss_actor, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: 
FrozenDict, + critic_opt_state: OptState, + traj_batch: PPOTransition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN NETWORK + value = critic_apply_fn(critic_params, traj_batch.obs) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + critic_total_loss = config.system.vf_coef * value_loss + return critic_total_loss, (value_loss) + + # CALCULATE ACTOR LOSS + key, entropy_key = jax.random.split(key) + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = actor_grad_fn( + params.actor_params, + opt_states.actor_opt_state, + traj_batch, + advantages, + entropy_key, + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, opt_states.critic_opt_state, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="batch" + ) + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" + ) + + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="batch" + ) + # pmean over devices. 
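
# To build intuition for the pmean calls above: under jax.pmap each device
# contributes its local value and gets back the mean over the named axis. A
# self-contained sketch (the per-device "gradients" are made up, and the
# device count is whatever the local machine exposes, often just one):
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
per_device_grads = jnp.arange(float(n_devices)).reshape(n_devices, 1)
synced = jax.pmap(lambda g: jax.lax.pmean(g, axis_name="device"), axis_name="device")(
    per_device_grads
)
# Every row of `synced` now holds the same averaged value.
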
+ critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = Params(actor_new_params, critic_new_params) + new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + total_loss = actor_loss_info[0] + critic_loss_info[0] + value_loss = critic_loss_info[1] + actor_loss = actor_loss_info[1][0] + entropy = actor_loss_info[1][1] + loss_info = { + "total_loss": total_loss, + "value_loss": value_loss, + "actor_loss": actor_loss, + "entropy": entropy, + } + return (new_params, new_opt_state, entropy_key), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key, shuffle_key, entropy_key = jax.random.split(key, 3) + + # SHUFFLE MINIBATCHES + batch_size = config.system.rollout_length * config.arch.num_envs + permutation = jax.random.permutation(shuffle_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), + shuffled_batch, + ) + + # UPDATE MINIBATCHES + (params, opt_states, entropy_key), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states, entropy_key), minibatches + ) + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (Params): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. 
+ """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.system.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + devices = jax.devices() + learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] + n_devices = len(learner_devices) + + # Get number of agents. + config.system.num_agents = env.num_agents + + # PRNG keys. + key, actor_net_key, critic_net_key = keys + + # Define network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_action_head = hydra.utils.instantiate( + config.network.action_head, action_dim=env.action_dim + ) + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + critic_network = Critic(torso=critic_torso) + + actor_lr = make_learning_rate(config.system.actor_lr, config) + critic_lr = make_learning_rate(config.system.critic_lr, config) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation with obs of all agents. + obs = env.single_observation_space.sample() + init_x = jax.tree_util.tree_map(lambda x: x[jnp.newaxis, ...], obs) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. + critic_params = critic_network.init(critic_net_key, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Pack params. + params = Params(actor_params, critic_params) + + # Pack apply and update functions. + apply_fns = (actor_network.apply, critic_network.apply) + update_fns = (actor_optim.update, critic_optim.update) + + # Get batched iterated update and replicate it to pmap it over cores. + learn = get_learner_fn(env, apply_fns, update_fns, config) + learn = jax.pmap(learn, axis_name="device", devices = learner_devices) + + # Initialise environment states and timesteps: across devices and batches. + key, *env_keys = jax.random.split( + key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + reshape_states = lambda x: x.reshape( + (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + # (devices, update batch size, num_envs, ...) + env_states = jax.tree_map(reshape_states, env_states) + timesteps = jax.tree_map(reshape_states, timesteps) + + # Load model from checkpoint if specified. 
+ if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(input_params=params) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. + key, step_keys = jax.random.split(key) + opt_states = OptStates(actor_opt_state, critic_opt_state) + replicate_learner = (params, opt_states, step_keys) + + # Duplicate learner for update_batch_size. + broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape) + replicate_learner = jax.tree_map(broadcast, replicate_learner) + + # Duplicate learner across devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + + # Initialise learner state. + params, opt_states, step_keys = replicate_learner + init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + + return learn, actor_network, init_learner_state + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + n_devices = len(jax.devices()) + + # Create the enviroments for train and eval. + env, eval_env = environments.make(config) + + # PRNG keys. + key, key_e, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.system.seed), num=4 + ) + + # Setup learner. + learn, actor_network, learner_state = learner_setup( + env, (key, actor_net_key, critic_net_key), config + ) + + # Setup evaluator. + # One key per device for evaluation. + eval_keys = jax.random.split(key_e, n_devices) + evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) + + # Calculate total timesteps. + config = check_total_timesteps(config) + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Calculate number of updates per evaluation. + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.system.num_updates_per_eval + * config.system.rollout_length + * config.system.update_batch_size + * config.arch.num_envs + ) + + # Logger setup + logger = MavaLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Run experiment for a total number of evaluations. + max_episode_return = -jnp.inf + best_params = None + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + + learner_output = learn(learner_state) + jax.block_until_ready(learner_output) + + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. 
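
# Worked example of the timestep bookkeeping above, with made-up numbers: on
# 1 device, with 10 updates per evaluation, a rollout length of 128, an update
# batch size of 2 and 16 envs, each evaluation window covers
# 1 * 10 * 128 * 2 * 16 = 40_960 environment steps, so the timestep logged
# after the third evaluation is 122_880.
n_devices, updates_per_eval, rollout_len, update_batch_size, num_envs = 1, 10, 128, 2, 16
steps_per_rollout = n_devices * updates_per_eval * rollout_len * update_batch_size * num_envs
t = steps_per_rollout * 3
assert (steps_per_rollout, t) == (40_960, 122_880)
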
+ logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + if ep_completed: # only log episode metrics if an episode was completed in the rollout. + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # Prepare for evaluation. + start_time = time.time() + + trained_params = unreplicate_batch_dim(learner_state.params.actor_params) + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + evaluator_output = evaluator(trained_params, eval_keys) + jax.block_until_ready(evaluator_output) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Record the performance for the final evaluation run. + eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + jax.block_until_ready(evaluator_output) + + elapsed_time = time.time() - start_time + steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + t = int(steps_per_rollout * (eval_step + 1)) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. + logger.stop() + + return eval_performance + + +@hydra.main(config_path="../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. + eval_performance = run_experiment(cfg) + print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/mava/systems/sebulba/ppo/orig.py b/mava/systems/sebulba/ppo/orig.py new file mode 100644 index 000000000..85b679305 --- /dev/null +++ b/mava/systems/sebulba/ppo/orig.py @@ -0,0 +1,796 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mava.utils.sebulba_utils import configure_computation_environment + +configure_computation_environment() # noqa: E402 + +import copy +import queue +import threading +import time +from collections import deque +from typing import Any, Dict, List, Tuple + +import chex +import flax +import hydra +import jax +import jax.numpy as jnp +import numpy as np +import optax +from chex import PRNGKey +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from mava.evaluator import get_sebulba_ff_evaluator as evaluator_setup +from mava.logger import Logger +from mava.networks import get_networks +from mava.types import ( + ActorApply, + CriticApply, + LearnerState, + Observation, + OptStates, + Params, +) +from mava.types import PPOTransition as Transition +from mava.types import SebulbaLearnerFn as LearnerFn +from mava.types import SingleDeviceFn +from mava.utils.checkpointing import Checkpointer +from mava.utils.jax import merge_leading_dims +from mava.utils.make_env import make + + +def rollout( # noqa: CCR001 + rng: PRNGKey, + config: DictConfig, + rollout_queue: queue.Queue, + params_queue: queue.Queue, + device_thread_id: int, + apply_fns: Tuple, + logger: Logger, + learner_devices: List, +) -> None: + """Executor rollout loop.""" + # Create envs + envs = make(config)(config.arch.num_envs) # type: ignore + + # Setup + len_executor_device_ids = len(config.arch.executor_device_ids) + t_env = 0 + start_time = time.time() + + # Get the apply functions for the actor and critic networks. + vmap_actor_apply, vmap_critic_apply = apply_fns + + # Define the util functions: select action function and prepare data to share it with learner. 
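
# A minimal stand-in for the "select action" step defined below: sample an
# action from a categorical policy and keep its log-probability. The logits
# are made up; in the real code the actor network returns a distribution
# object that exposes sample_and_log_prob.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([[0.1, 0.5, -0.2]])  # (num_envs=1, num_actions=3), made up
key, subkey = jax.random.split(key)
action = jax.random.categorical(subkey, logits)  # one sampled action per env
log_prob = jnp.take_along_axis(
    jax.nn.log_softmax(logits), action[:, None], axis=1
).squeeze(-1)
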
+ @jax.jit + def get_action_and_value( + params: FrozenDict, + observation: Observation, + key: PRNGKey, + ) -> Tuple: + """Get action and value.""" + key, subkey = jax.random.split(key) + + policy = vmap_actor_apply(params.actor_params, observation) + action, logprob = policy.sample_and_log_prob(seed=subkey) + + value = vmap_critic_apply(params.critic_params, observation).squeeze() + return action, logprob, value, key + + @jax.jit + def prepare_data(storage: List[Transition]) -> Transition: + """Prepare data to share with learner.""" + return jax.tree_map( # type: ignore + lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage + ) + + # Define the episode info + env_id = np.arange(config.arch.num_envs) + # Accumulated episode returns + episode_returns = np.zeros((config.arch.num_envs,), dtype=np.float32) + # Final episode returns + returned_episode_returns = np.zeros((config.arch.num_envs,), dtype=np.float32) + # Accumulated episode lengths + episode_lengths = np.zeros((config.arch.num_envs,), dtype=np.float32) + # Final episode lengths + returned_episode_lengths = np.zeros((config.arch.num_envs,), dtype=np.float32) + + # Define the data structure + params_queue_get_time: deque = deque(maxlen=10) + rollout_time: deque = deque(maxlen=10) + rollout_queue_put_time: deque = deque(maxlen=10) + + # Reset envs + next_obs, infos = envs.reset() + next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) + + # Loop till the learner has finished training + for update in range(1, config.system.num_updates + 2): + # Setup + env_recv_time: float = 0 + inference_time: float = 0 + storage_time: float = 0 + env_send_time: float = 0 + + # Get the latest parameters from the learner + params_queue_get_time_start = time.time() + if config.arch.concurrency: + if update != 2: + params = params_queue.get() + params.network_params["params"]["Dense_0"]["kernel"].block_until_ready() + else: + params = params_queue.get() + params_queue_get_time.append(time.time() - params_queue_get_time_start) + + # Rollout + rollout_time_start = time.time() + storage: List = [] + # Loop over the rollout length + for _ in range(0, config.system.rollout_length): + # Get previous step info + cached_next_obs = next_obs + cached_next_dones = next_dones + cashed_action_mask = np.stack(infos["actions_mask"]) + + # Increment current timestep + t_env += ( + config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs + ) + + # Get action and value + inference_time_start = time.time() + ( + action, + logprob, + value, + rng, + ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), rng) + inference_time += time.time() - inference_time_start + + # Step the environment + env_send_time_start = time.time() + cpu_action = np.array(action) + next_obs, next_reward, terminated, truncated, infos = envs.step(cpu_action) + next_done = terminated + truncated + next_dones = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), + (next_done), + ) + + # Append data to storage + env_send_time += time.time() - env_send_time_start + storage_time_start = time.time() + storage.append( + Transition( + done=cached_next_dones, + action=action, + value=value, + reward=next_reward, + log_prob=logprob, + obs=cached_next_obs, + info=np.stack(infos["actions_mask"]), # Add action mask to info + ) + ) + storage_time += time.time() - storage_time_start + + # Update episode info + episode_returns[env_id] += 
np.mean(next_reward) + returned_episode_returns[env_id] = np.where( + next_done, + episode_returns[env_id], + returned_episode_returns[env_id], + ) + episode_returns[env_id] *= (1 - next_done) * (1 - truncated) + episode_lengths[env_id] += 1 + returned_episode_lengths[env_id] = np.where( + next_done, + episode_lengths[env_id], + returned_episode_lengths[env_id], + ) + episode_lengths[env_id] *= (1 - next_done) * (1 - truncated) + rollout_time.append(time.time() - rollout_time_start) + + # Prepare data to share with learner + partitioned_storage = prepare_data(storage) + sharded_storage = Transition( + *list( # noqa: C417 + map( + lambda x: jax.device_put_sharded(x, devices=learner_devices), # type: ignore + partitioned_storage, + ) + ) + ) + sharded_next_obs = jax.device_put_sharded( + np.split(next_obs, len(learner_devices)), devices=learner_devices + ) + sharded_next_done = jax.device_put_sharded( + np.split(next_dones, len(learner_devices)), devices=learner_devices + ) + sharded_next_action_mask = jax.device_put_sharded( + np.split(np.stack(infos["actions_mask"]), len(learner_devices)), devices=learner_devices + ) + payload = ( + t_env, + sharded_storage, + sharded_next_obs, + sharded_next_done, + sharded_next_action_mask, + np.mean(params_queue_get_time), + ) + + # Put data in the rollout queue to share it with the learner + rollout_queue_put_time_start = time.time() + rollout_queue.put(payload) + rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) + + if (update % config.arch.log_frequency == 0) or (config.system.num_updates + 1 == update): + # Log info + logger.log_executor_metrics( + t_env=t_env, + metrics={ + "episodes_info": { + "episode_return": returned_episode_returns, + "episode_length": returned_episode_lengths, + "steps_per_second": int(t_env / (time.time() - start_time)), + }, + "speed_info": { + "rollout_time": np.mean(rollout_time), + }, + "queue_info": { + "params_queue_get_time": np.mean(params_queue_get_time), + "env_recv_time": env_recv_time, + "inference_time": inference_time, + "storage_time": storage_time, + "env_send_time": env_send_time, + "rollout_queue_put_time": np.mean(rollout_queue_put_time), + }, + }, + device_thread_id=device_thread_id, + ) + + +def get_learner_fn( + apply_fns: Tuple[ActorApply, CriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> LearnerFn: + """Get the learner function.""" + # Get apply and update functions for actor and critic networks. 
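
# The "update functions" referred to above come from optax gradient
# transformations. A self-contained sketch of that pattern with toy params and
# gradients (the learning rate and clip norm here are made up):
import jax.numpy as jnp
import optax

optim = optax.chain(optax.clip_by_global_norm(0.5), optax.adam(3e-4, eps=1e-5))
params = {"w": jnp.ones(3)}
opt_state = optim.init(params)
grads = {"w": jnp.full(3, 2.0)}
updates, opt_state = optim.update(grads, opt_state)
params = optax.apply_updates(params, updates)
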
+ actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def single_device_update( + agents_state: LearnerState, + traj_batch: Transition, + last_observation: Observation, + rng: PRNGKey, + ) -> Tuple[LearnerState, chex.PRNGKey, Tuple]: + params, opt_states, _, _, _ = agents_state + + def _calculate_gae( + traj_batch: Transition, last_val: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + """Calculate the GAE.""" + + def _get_advantages(gae_and_next_value: Tuple, transition: Transition) -> Tuple: + """Calculate the GAE for a single transition.""" + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + gamma = config.system.gamma + delta = reward + gamma * next_value * (1 - done) - value + gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + # Calculate GAE + last_val = critic_apply_fn(params.critic_params, last_observation) + advantages, targets = _calculate_gae(traj_batch, last_val) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # UNPACK TRAIN STATE AND BATCH INFO + params, opt_states = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + actor_opt_state: OptStates, + traj_batch: Transition, + gae: chex.Array, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config.system.clip_eps, + 1.0 + config.system.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = actor_policy.entropy().mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + return total_loss_actor, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: FrozenDict, + critic_opt_state: OptStates, + traj_batch: Transition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN NETWORK + value = critic_apply_fn(critic_params, traj_batch.obs) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + critic_total_loss = config.system.vf_coef * value_loss + return critic_total_loss, (value_loss) + + # CALCULATE ACTOR LOSS + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = actor_grad_fn( + params.actor_params, opt_states.actor_opt_state, traj_batch, advantages + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, 
opt_states.critic_opt_state, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the learner devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="local_devices" + ) + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="local_devices" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = Params(actor_new_params, critic_new_params) + new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + total_loss = actor_loss_info[0] + critic_loss_info[0] + value_loss = critic_loss_info[1] + actor_loss = actor_loss_info[1][0] + entropy = actor_loss_info[1][1] + loss_info = (total_loss, value_loss, actor_loss, entropy) + + return (new_params, new_opt_state), loss_info + + params, opt_states, traj_batch, advantages, targets, rng = update_state + rng, shuffle_rng = jax.random.split(rng) + + # SHUFFLE MINIBATCHES + batch_size = config.system.rollout_length * config.arch.num_envs + permutation = jax.random.permutation(shuffle_rng, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), + shuffled_batch, + ) + + # UPDATE MINIBATCHES + (params, opt_states), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states), minibatches + ) + + update_state = (params, opt_states, traj_batch, advantages, targets, rng) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, rng) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, rng = update_state + learner_state = agents_state._replace(params=params, opt_states=opt_states) + return learner_state, rng, loss_info + + def learner_fn( + agents_state: LearnerState, + sharded_storages: List, + sharded_next_obs: List, + sharded_next_done: List, + sharded_next_action_mask: List, + key: chex.PRNGKey, + ) -> Tuple: + """Single device update.""" + # Horizontal stack all the data from different devices + traj_batch = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages) + traj_batch = traj_batch._replace(obs=Observation(traj_batch.obs, traj_batch.info)) + + # Get last observation + last_obs = jnp.concatenate(sharded_next_obs) + last_action_mask = jnp.concatenate(sharded_next_action_mask) + last_observation = Observation(last_obs, last_action_mask) + + # Update learner + agents_state, key, (total_loss, value_loss, actor_loss, entropy) = single_device_update( + agents_state, traj_batch, last_observation, key + ) + + # Pack loss info + loss_info = { + "total_loss": total_loss, + "loss_actor": actor_loss, + "value_loss": value_loss, + "entropy": entropy, + } + return agents_state, key, loss_info + + 
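
# A standalone sketch of the minibatch construction used in the epoch update
# above: merge the (rollout_length, num_envs) leading axes, shuffle, then carve
# the flat batch into num_minibatches equal chunks (all sizes are made up).
import jax
import jax.numpy as jnp

rollout_length, num_envs, num_minibatches, feat = 4, 6, 3, 2
batch = jnp.arange(rollout_length * num_envs * feat).reshape(rollout_length, num_envs, feat)
flat = batch.reshape((rollout_length * num_envs,) + batch.shape[2:])
perm = jax.random.permutation(jax.random.PRNGKey(0), rollout_length * num_envs)
shuffled = jnp.take(flat, perm, axis=0)
minibatches = shuffled.reshape((num_minibatches, -1) + shuffled.shape[1:])
assert minibatches.shape == (num_minibatches, 8, feat)
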
return learner_fn + + +def learner_setup( + rngs: chex.Array, config: DictConfig, learner_devices: List +) -> Tuple[SingleDeviceFn, LearnerState, Tuple[ActorApply, ActorApply]]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get number of actions and agents. + dummy_envs = make(config)( # type: ignore + config.arch.num_envs # Create dummy_envs to get observation and action spaces + ) + config.system.num_agents = dummy_envs.single_observation_space.shape[0] + config.system.num_actions = int(dummy_envs.single_action_space.nvec[0]) + + # PRNG keys. + actor_net_key, critic_net_key = rngs + + # Define network and optimiser. + actor_network, critic_network = get_networks( + config=config, network="feedforward", centralised_critic=False + ) + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(config.system.actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(config.system.critic_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. + init_obs = np.array([dummy_envs.single_observation_space.sample()[0]]) + init_action_mask = np.ones((1, config.system.num_actions)) + init_x = Observation(init_obs, init_action_mask) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. + critic_params = critic_network.init(critic_net_key, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Vmap network apply function over number of agents. + vmapped_actor_network_apply_fn = jax.vmap( + actor_network.apply, + in_axes=(None, Observation(1, 1, None)), + out_axes=(1), + ) + vmapped_critic_network_apply_fn = jax.vmap( + critic_network.apply, + in_axes=(None, Observation(1, 1, None)), + out_axes=(1), + ) + + # Pack apply and update functions. + apply_fns = (vmapped_actor_network_apply_fn, vmapped_critic_network_apply_fn) + update_fns = (actor_optim.update, critic_optim.update) + + # Define agents state + agents_state = LearnerState( + params=Params( + actor_params=actor_params, + critic_params=critic_params, + ), + opt_states=OptStates( + actor_opt_state=actor_opt_state, + critic_opt_state=critic_opt_state, + ), + ) + # Replicate agents state per learner device + agents_state = flax.jax_utils.replicate(agents_state, devices=learner_devices) + + # Get Learner function: pmap over learner devices. + single_device_update = get_learner_fn(apply_fns, update_fns, config) + multi_device_update = jax.pmap( + single_device_update, + axis_name="local_devices", + devices=learner_devices, + ) + + # Close dummy envs. + dummy_envs.close() + + return multi_device_update, agents_state, apply_fns + + +def run_experiment(_config: DictConfig) -> None: # noqa: CCR001 + """Runs experiment.""" + config = copy.deepcopy(_config) + + # Setup device distribution. + local_devices = jax.local_devices() #why are we using local devices insted of devices? ------------------------------------------------------------------------------------------------------------------------------------ define a ratio insted of the devices to use? + learner_devices = [local_devices[d_id] for d_id in config.arch.learner_device_ids] + + # PRNG keys. 
+ rng, rng_e, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.system.seed), num=4 + ) + learner_keys = jax.device_put_replicated(rng, learner_devices) + + # Sanity check of config + assert ( + config.arch.num_envs % len(config.arch.learner_device_ids) == 0 + ), "local_num_envs must be divisible by len(learner_device_ids)" + #each thread is going to devide needs to give an equal number of traj to each learning device? shound't each actor Thread have a designated N learneres? If we have less actor T than learners then ech actor will devide based on the num_env and gives to N actors, ig to lessen the managment each actor gives to all of the learners? + #this deviates from the paper? + assert ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.arch.n_threads_per_executor + % config.system.num_minibatches + == 0 + ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" #this one makes sense but the assertion is a bit off? + + # Setup learner. + ( + multi_device_update, + agents_state, + apply_fns, + ) = learner_setup((actor_net_key, critic_net_key), config, learner_devices) + + # Setup evaluator. + eval_envs = make(config)(config.arch.num_eval_episodes) # type: ignore + evaluator = evaluator_setup(eval_envs=eval_envs, apply_fn=apply_fns[0], config=config) + + # Calculate total timesteps. + batch_size = int( + config.arch.num_envs + * config.system.rollout_length + * config.arch.n_threads_per_executor + * len(config.arch.executor_device_ids) + ) + config.system.total_timesteps = config.system.num_updates * batch_size + + # Setup logger. + config.arch.log_frequency = config.system.num_updates // config.arch.num_evaluation + logger = Logger(config) + cfg_dict: Dict = OmegaConf.to_container(config, resolve=True) + pprint(cfg_dict) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=cfg_dict, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + if config.logger.checkpointing.load_model: + print( + f"{Fore.RED}{Style.BRIGHT}Loading checkpoint is not supported\ + for sebulba architecture yet{Style.RESET_ALL}" + ) + + # Executor setup and launch. + unreplicated_params = flax.jax_utils.unreplicate(agents_state.params) + params_queues: List = [] + rollout_queues: List = [] + for d_idx, d_id in enumerate( # Loop through each executor device + config.arch.executor_device_ids + ): + # Replicate params per executor device + device_params = jax.device_put(unreplicated_params, local_devices[d_id]) + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + params_queues.append(queue.Queue(maxsize=1)) + rollout_queues.append(queue.Queue(maxsize=1)) + params_queues[-1].put(device_params) + threading.Thread( + target=rollout, + args=( + jax.device_put(rng, local_devices[d_id]), + config, + rollout_queues[-1], + params_queues[-1], + d_idx * config.arch.n_threads_per_executor + thread_id, + apply_fns, + logger, + learner_devices, + ), + ).start() + + # Run experiment for the total number of updates. 
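
# A toy, self-contained version of the executor/learner hand-off driven by the
# queues and threads launched above: the actor blocks on fresh params and
# pushes a rollout; the learner consumes rollouts and publishes updated params.
# The payloads are placeholders, not real trajectories or network weights.
import queue
import threading

rollout_queue: queue.Queue = queue.Queue(maxsize=1)
params_queue: queue.Queue = queue.Queue(maxsize=1)

def toy_actor(num_updates: int) -> None:
    for _ in range(num_updates):
        params = params_queue.get()  # wait for the latest params
        rollout_queue.put({"collected_with": params})  # stand-in for a rollout

def toy_learner(num_updates: int) -> None:
    params = 0
    params_queue.put(params)  # publish initial params
    for _ in range(num_updates):
        _ = rollout_queue.get()  # wait for a rollout
        params += 1  # stand-in for a gradient step
        params_queue.put(params)  # publish refreshed params

actor = threading.Thread(target=toy_actor, args=(3,))
learner = threading.Thread(target=toy_learner, args=(3,))
actor.start()
learner.start()
actor.join()
learner.join()
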
+ rollout_queue_get_time: deque = deque(maxlen=10) + data_transfer_time: deque = deque(maxlen=10) + trainer_update_number = 0 + max_episode_return = jnp.float32(0.0) + best_params = None + while True: + trainer_update_number += 1 + rollout_queue_get_time_start = time.time() + sharded_storages = [] + sharded_next_obss = [] + sharded_next_dones = [] + sharded_next_action_masks = [] + + # Loop through each executor device + for d_idx, _ in enumerate(config.arch.executor_device_ids): + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + # Get data from rollout queue + ( + t_env, + sharded_storage, + sharded_next_obs, + sharded_next_done, + sharded_next_action_mask, + avg_params_queue_get_time, + ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() + sharded_storages.append(sharded_storage) + sharded_next_obss.append(sharded_next_obs) + sharded_next_dones.append(sharded_next_done) + sharded_next_action_masks.append(sharded_next_action_mask) + + rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) + training_time_start = time.time() + + # Update learner + (agents_state, learner_keys, loss_info) = multi_device_update( # type: ignore + agents_state, + sharded_storages, + sharded_next_obss, + sharded_next_dones, + sharded_next_action_masks, + learner_keys, + ) + + # Send updated params to executors + unreplicated_params = flax.jax_utils.unreplicate(agents_state.params) + for d_idx, d_id in enumerate(config.arch.executor_device_ids): + device_params = jax.device_put(unreplicated_params, local_devices[d_id]) + for thread_id in range(config.arch.n_threads_per_executor): + params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( + device_params + ) + + if trainer_update_number % config.arch.log_frequency == 0: + # Logging training info + logger.log_trainer_metrics( + experiment_output={ + "loss_info": loss_info, + "queue_info": { + "rollout_queue_get_time": np.mean(rollout_queue_get_time), + "data_transfer_time": np.mean(data_transfer_time), + "rollout_params_queue_get_time_diff": np.mean(rollout_queue_get_time) + - avg_params_queue_get_time, + "rollout_queue_size": rollout_queues[0].qsize(), + "params_queue_size": params_queues[0].qsize(), + }, + "speed_info": { + "training_time": time.time() - training_time_start, + "trainer_update_number": trainer_update_number, + }, + }, + t_env=t_env, + ) + + # Evaluation + rng_e, _ = jax.random.split(rng_e) + evaluator_output = evaluator(params=unreplicated_params, rng=rng_e) + # Log the results of the evaluation. + episode_return = logger.log_evaluator_metrics( + t_env=t_env, + metrics=evaluator_output, + eval_step=trainer_update_number, + ) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=t_env, + unreplicated_learner_state=flax.jax_utils.unreplicate(agents_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(unreplicated_params) + max_episode_return = episode_return + + # Check if training is finished + if trainer_update_number >= config.system.num_updates: + rng_e, _ = jax.random.split(rng_e) + # Measure absolute metric + evaluator_output = evaluator(params=best_params, rng=rng_e, eval_multiplier=10) + # Log the results of the evaluation. 
+ logger.log_evaluator_metrics( + t_env=t_env, + metrics=evaluator_output, + eval_step=trainer_update_number + 1, + absolute_metric=True, + ) + break + + +@hydra.main(config_path="../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> None: + """Experiment entry point.""" + + # Run experiment. + run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") + + +if __name__ == "__main__": + hydra_entry_point() \ No newline at end of file From a435a0afa12551685255ac25d1332bb2bf21244f Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 13 Jun 2024 23:51:28 +0100 Subject: [PATCH 016/125] feat: initial learner / training loop --- mava/systems/anakin/ppo/ff_ippo.py | 2 +- mava/systems/sebulba/ppo/ff_ippo.py | 480 +++++++++++++++++----------- mava/systems/sebulba/ppo/test.py | 2 +- mava/utils/checkpointing.py | 2 +- 4 files changed, 298 insertions(+), 188 deletions(-) diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py index 7b45fb45f..44e196535 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/anakin/ppo/ff_ippo.py @@ -578,7 +578,7 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index c9a2069b2..95e722546 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -14,14 +14,17 @@ import copy import time -from typing import Any, Dict, Tuple - +from typing import Any, Dict, Tuple, List +import threading import chex import flax import hydra import jax import jax.numpy as jnp +import numpy as np import optax +import queue +from collections import deque from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict from jumanji.env import Environment @@ -32,8 +35,8 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import ( @@ -47,8 +50,157 @@ from mava.wrappers.episode_metrics import get_final_step_metrics +def rollout( + rng: chex.PRNGKey, + config: DictConfig, + rollout_queue: queue.Queue, + params_queue: queue.Queue, + device_thread_id: int, + apply_fns: Tuple, + logger: MavaLogger, + learner_devices: List): + + #create envs + env = environments.make(config) + + #setup + len_executor_device_ids = len(config.arch.executor_device_ids) + t_env = 0 + start_time = time.time() + + actor_apply_fn, critic_apply_fn = apply_fns + + # Define the util functions: select action function and prepare data to share it with learner. 
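
# A standalone sketch of the "prepare data to share it with the learner" step:
# stack per-step transitions, split them along the env axis and place one
# shard on each learner device (here simply all local devices; the shapes and
# the two-step rollout are made up).
import jax
import jax.numpy as jnp
import numpy as np

learner_devices = jax.local_devices()
num_envs, obs_dim = 4 * len(learner_devices), 3
rollout = [np.zeros((num_envs, obs_dim)) for _ in range(2)]  # two timesteps
stacked = jnp.stack(rollout)  # (rollout_length, num_envs, obs_dim)
shards = jnp.split(stacked, len(learner_devices), axis=1)  # one chunk of envs per device
sharded = jax.device_put_sharded(shards, devices=learner_devices)
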
+ @jax.jit + def get_action_and_value( + params: FrozenDict, + observation: Observation, + key: chex.PRNGKey, + ) -> Tuple: + """Get action and value.""" + key, subkey = jax.random.split(key) + + policy = actor_apply_fn(params.actor_params, observation) + action, log_prob = policy.sample_and_log_prob(seed=subkey) + + value = critic_apply_fn(params.critic_params, observation).squeeze() + return action, log_prob, value, key + + @jax.jit + def prepare_data(storage: List[PPOTransition]) -> PPOTransition: + """Prepare data to share with learner.""" + return jax.tree_map( # type: ignore + lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage + ) + + + # Define queues to track time + params_queue_get_time: deque = deque(maxlen=10) + rollout_time: deque = deque(maxlen=10) + rollout_queue_put_time: deque = deque(maxlen=10) + + next_obs, next_rewards, next_dones , extra = env.reset() + + # Loop till the learner has finished training + for update in range(1, config.system.num_updates + 2): + # Setup + env_recv_time: float = 0 + inference_time: float = 0 + storage_time: float = 0 + env_send_time: float = 0 + + # Get the latest parameters from the learner + params_queue_get_time_start = time.time() + params = params_queue.get() + params_queue_get_time.append(time.time() - params_queue_get_time_start) + + # Rollout + rollout_time_start = time.time() + storage: List = [] + # Loop over the rollout length + for _ in range(0, config.system.rollout_length): + # Cached for transition + cached_next_obs = next_obs + cached_next_dones = next_dones + + # Increment current timestep + t_env += ( + config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs + ) + + # Get action and value + inference_time_start = time.time() + + ( + action, + log_prob, + value, + rng, + ) = get_action_and_value(params, cached_next_obs, rng) + inference_time += time.time() - inference_time_start + + # Step the environment + env_send_time_start = time.time() + cpu_action = np.array(action) + next_obs, next_reward, next_dones, extra = env.step(cpu_action) + + next_dones = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), + (next_dones), + ) + + # Append data to storage + env_send_time += time.time() - env_send_time_start + storage_time_start = time.time() + storage.append( + PPOTransition( + done=cached_next_dones, + action=action, + value=value, + reward=next_reward, + log_prob=log_prob, + obs=cached_next_obs, + info=extra, + ) + ) + storage_time += time.time() - storage_time_start + + rollout_time.append(time.time() - rollout_time_start) + + # Prepare data to share with learner + # todo: investigate the thread --> single learning + partitioned_storage = prepare_data(storage) + sharded_storage = PPOTransition( + *list( # noqa: C417 + map( + lambda x: jax.device_put_sharded(x, devices=learner_devices), # type: ignore + partitioned_storage, + ) + ) + ) + + sharded_next_obs = jax.device_put_sharded( + np.split(next_obs, len(learner_devices)), devices=learner_devices + ) + sharded_next_done = jax.device_put_sharded( + np.split(next_dones, len(learner_devices)), devices=learner_devices + ) + + payload = ( + t_env, + sharded_storage, + sharded_next_obs, + sharded_next_done, + np.mean(params_queue_get_time), + ) + + # Put data in the rollout queue to share it with the learner + rollout_queue_put_time_start = time.time() + rollout_queue.put(payload) + rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) + + def 
get_learner_fn( - env: Environment, apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, @@ -59,7 +211,7 @@ def get_learner_fn( actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + def _update_step(learner_state: LearnerState, _: Any, traj_batch : PPOTransition, last_obs: chex.Array, last_done: chex.Array) -> Tuple[LearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -77,71 +229,32 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: - """Step the environment.""" - params, opt_states, key, env_state, last_timestep = learner_state - - # SELECT ACTION - key, policy_key = jax.random.split(key) - actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) - value = critic_apply_fn(params.critic_params, last_timestep.observation) - - action = actor_policy.sample(seed=policy_key) - log_prob = actor_policy.log_prob(action) - - # STEP ENVIRONMENT - env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - - # LOG EPISODE METRICS - done = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] - - transition = PPOTransition( - done, action, value, timestep.reward, log_prob, last_timestep.observation, info - ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) - return learner_state, transition - - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( - _env_step, learner_state, None, config.system.rollout_length - ) - - # CALCULATE ADVANTAGE - params, opt_states, key, env_state, last_timestep = learner_state - last_val = critic_apply_fn(params.critic_params, last_timestep.observation) - - def _calculate_gae( - traj_batch: PPOTransition, last_val: chex.Array + def _calculate_gae( #todo: lake sure this is appropriate + traj_batch: PPOTransition, last_val: chex.Array, last_done: chex.Array ) -> Tuple[chex.Array, chex.Array]: - """Calculate the GAE.""" - - def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: - """Calculate the GAE for a single transition.""" - gae, next_value = gae_and_next_value - done, value, reward = ( - transition.done, - transition.value, - transition.reward, - ) + def _get_advantages( + carry: Tuple[chex.Array, chex.Array, chex.Array], transition: PPOTransition + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: + gae, next_value, next_done = carry + done, value, reward = transition.done, transition.value, transition.reward gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae - return (gae, value), gae + delta = reward + gamma * next_value * (1 - next_done) - value + gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae + return (gae, value, done), gae _, advantages = jax.lax.scan( _get_advantages, - (jnp.zeros_like(last_val), last_val), + (jnp.zeros_like(last_val), last_val, last_done), traj_batch, reverse=True, unroll=16, ) return advantages, advantages + 
traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val) + + # CALCULATE ADVANTAGE + params, opt_states, key, _, _ = learner_state + last_val = critic_apply_fn(params.critic_params, last_obs) + advantages, targets = _calculate_gae(traj_batch, last_val, last_done) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" @@ -304,11 +417,11 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = LearnerState(params, opt_states, key) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_done: chex.Array) -> ExperimentOutput[LearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -325,9 +438,11 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """ batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + partial_batched_update_step = lambda learner_state, xs : batched_update_step(learner_state, xs, traj_batch , last_obs, last_done) learner_state, (episode_info, loss_info) = jax.lax.scan( - batched_update_step, learner_state, None, config.system.num_updates_per_eval + partial_batched_update_step, learner_state, None, config.system.num_updates_per_eval ) return ExperimentOutput( learner_state=learner_state, @@ -339,16 +454,18 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( - env: Environment, keys: chex.Array, config: DictConfig + keys: chex.Array, config: DictConfig, learner_devices: List ) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. - devices = jax.devices() - learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] n_devices = len(learner_devices) - - # Get number of agents. - config.system.num_agents = env.num_agents + + #create temporory envoirnments. + env = environments.make(config) + # Get number of agents and actions. + action_space = env.single_action_space + config.system.num_agents = len(action_space) + config.system.num_actions = action_space[0].n # PRNG keys. key, actor_net_key, critic_net_key = keys @@ -375,9 +492,10 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. - obs = env.single_observation_space.sample() - init_x = jax.tree_util.tree_map(lambda x: x[jnp.newaxis, ...], obs) + # Initialise observation: Select only obs for a single agent. + init_obs = np.array([env.single_observation_space.sample()[0]]) + init_action_mask = np.ones((1, config.system.num_actions)) + init_x = Observation(init_obs, init_action_mask) # Initialise actor params and optimiser state. actor_params = actor_network.init(actor_net_key, init_x) @@ -398,20 +516,6 @@ def learner_setup( learn = get_learner_fn(env, apply_fns, update_fns, config) learn = jax.pmap(learn, axis_name="device", devices = learner_devices) - # Initialise environment states and timesteps: across devices and batches. 
- key, *env_keys = jax.random.split( - key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1 - ) - env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( - jnp.stack(env_keys), - ) - reshape_states = lambda x: x.reshape( - (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:] - ) - # (devices, update batch size, num_envs, ...) - env_states = jax.tree_map(reshape_states, env_states) - timesteps = jax.tree_map(reshape_states, timesteps) - # Load model from checkpoint if specified. if config.logger.checkpointing.load_model: loaded_checkpoint = Checkpointer( @@ -424,50 +528,63 @@ def learner_setup( params = restored_params # Define params to be replicated across devices and batches. - key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, step_keys) + replicate_learner = (params, opt_states) # Duplicate learner for update_batch_size. broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape) replicate_learner = jax.tree_map(broadcast, replicate_learner) - # Duplicate learner across devices. - replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + # Duplicate learner across Learner devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) # Initialise learner state. - params, opt_states, step_keys = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + params, opt_states = replicate_learner + init_learner_state = LearnerState(params, opt_states) + env.close() - return learn, actor_network, init_learner_state + return learn, apply_fns, init_learner_state def run_experiment(_config: DictConfig) -> float: """Runs experiment.""" config = copy.deepcopy(_config) - n_devices = len(jax.devices()) - - # Create the enviroments for train and eval. - env, eval_env = environments.make(config) + devices = jax.devices() # todo: use local devices insted? + learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] # PRNG keys. key, key_e, actor_net_key, critic_net_key = jax.random.split( jax.random.PRNGKey(config.system.seed), num=4 ) + learner_keys = jax.device_put_replicated(key, learner_devices) + + # Sanity check of config + assert ( + config.arch.num_envs % len(config.arch.learner_device_ids) == 0 + ), "The number of environments need to be divisible by the number of learners " + + assert ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.arch.n_threads_per_executor + % config.system.num_minibatches + == 0 + ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" + + # Setup learner. - learn, actor_network, learner_state = learner_setup( - env, (key, actor_net_key, critic_net_key), config + learn, apply_fns , learner_state = learner_setup( + learner_keys, config, learner_devices ) # Setup evaluator. # One key per device for evaluation. - eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) + #eval_keys = jax.random.split(key_e, n_devices) # todo: well add the evaluations :) + #evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) # Calculate total timesteps. 
- config = check_total_timesteps(config) + config = check_total_timesteps(config) #todo: update this for sebulba assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." @@ -475,7 +592,8 @@ def run_experiment(_config: DictConfig) -> float: # Calculate number of updates per evaluation. config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation steps_per_rollout = ( - n_devices + config.arch.executor_device_ids + * config.arch.n_threads_per_executor * config.system.num_updates_per_eval * config.system.rollout_length * config.system.update_batch_size @@ -496,91 +614,83 @@ def run_experiment(_config: DictConfig) -> float: model_name=config.logger.system_name, **config.logger.checkpointing.save_args, # Checkpoint args ) - - # Run experiment for a total number of evaluations. - max_episode_return = -jnp.inf + + # Executor setup and launch. + unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) + params_queues: List = [] + rollout_queues: List = [] + for d_idx, d_id in enumerate( # Loop through each executor device + config.arch.executor_device_ids + ): + # Replicate params per executor device + device_params = jax.device_put(unreplicated_params, devices[d_id]) + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + params_queues.append(queue.Queue(maxsize=1)) + rollout_queues.append(queue.Queue(maxsize=1)) + params_queues[-1].put(device_params) + threading.Thread( + target=rollout, + args=( + jax.device_put(key, devices[d_id]), + config, + rollout_queues[-1], + params_queues[-1], + d_idx * config.arch.n_threads_per_executor + thread_id, + apply_fns, + logger, + learner_devices, + ), + ).start() + + # Run experiment for the total number of updates. + rollout_queue_get_time: deque = deque(maxlen=10) + data_transfer_time: deque = deque(maxlen=10) + trainer_update_number = 0 + max_episode_return = jnp.float32(0.0) best_params = None - for eval_step in range(config.arch.num_evaluation): - # Train. - start_time = time.time() - - learner_output = learn(learner_state) - jax.block_until_ready(learner_output) - - # Log the results of the training. - elapsed_time = time.time() - start_time - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - - # Separately log timesteps, actoring metrics and training metrics. - logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) - if ep_completed: # only log episode metrics if an episode was completed in the rollout. - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) - - # Prepare for evaluation. - start_time = time.time() - - trained_params = unreplicate_batch_dim(learner_state.params.actor_params) - key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - eval_keys = jnp.stack(eval_keys) - eval_keys = eval_keys.reshape(n_devices, -1) - - # Evaluate. - evaluator_output = evaluator(trained_params, eval_keys) - jax.block_until_ready(evaluator_output) - - # Log the results of the evaluation. 
- elapsed_time = time.time() - start_time - episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) - - steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) - - if save_checkpoint: - # Save checkpoint of learner state - checkpointer.save( - timestep=steps_per_rollout * (eval_step + 1), - unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), - episode_return=episode_return, - ) - - if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(trained_params) - max_episode_return = episode_return - - # Update runner state to continue training. - learner_state = learner_output.learner_state - - # Record the performance for the final evaluation run. - eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) - - # Measure absolute metric. - if config.arch.absolute_metric: - start_time = time.time() - - key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - eval_keys = jnp.stack(eval_keys) - eval_keys = eval_keys.reshape(n_devices, -1) - - evaluator_output = absolute_metric_evaluator(best_params, eval_keys) - jax.block_until_ready(evaluator_output) - - elapsed_time = time.time() - start_time - steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - t = int(steps_per_rollout * (eval_step + 1)) - evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) - - # Stop the logger. - logger.stop() + while True: + trainer_update_number += 1 + rollout_queue_get_time_start = time.time() + sharded_storages = [] + sharded_next_obss = [] + sharded_next_dones = [] + sharded_next_action_masks = [] + + # Loop through each executor device + for d_idx, _ in enumerate(config.arch.executor_device_ids): + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + # Get data from rollout queue + ( + t_env, + sharded_storage, + sharded_next_obs, + sharded_next_done, + avg_params_queue_get_time, + ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() + sharded_storages.append(sharded_storage) + sharded_next_obss.append(sharded_next_obs) + sharded_next_dones.append(sharded_next_done) + + rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) + training_time_start = time.time() + + learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones) + + # Send updated params to executors + unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) + for d_idx, d_id in enumerate(config.arch.executor_device_ids): + device_params = jax.device_put(unreplicated_params, devices[d_id]) + for thread_id in range(config.arch.n_threads_per_executor): + params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( + device_params + ) - return eval_performance + return None#eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
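The rollout/run_experiment pair in this patch implements the Sebulba actor-learner split: every executor thread blocks on its params queue, rolls out with the newest parameters it received, and hands the sharded payload to the learner over a rollout queue, while the learner drains one payload per thread per update and publishes fresh parameters back. A minimal, self-contained sketch of that handoff follows (the names executor, run and the integer "params" are stand-ins for illustration, not Mava code):

import queue
import threading

NUM_UPDATES = 3

def executor(params_queue: queue.Queue, rollout_queue: queue.Queue) -> None:
    for _ in range(NUM_UPDATES):
        params = params_queue.get()                  # block until the learner publishes params
        trajectory = [params + i for i in range(4)]  # stand-in for stepping the vectorised envs
        rollout_queue.put(trajectory)                # hand the batch to the learner

def run() -> None:
    params_queue: queue.Queue = queue.Queue(maxsize=1)
    rollout_queue: queue.Queue = queue.Queue(maxsize=1)
    params = 0
    params_queue.put(params)                         # seed the executor with initial params
    threading.Thread(target=executor, args=(params_queue, rollout_queue)).start()
    for update in range(NUM_UPDATES):
        batch = rollout_queue.get()                  # wait for the next rollout payload
        params += sum(batch)                         # stand-in for the pmapped PPO update
        if update < NUM_UPDATES - 1:
            params_queue.put(params)                 # publish refreshed params to the executor

if __name__ == "__main__":
    run()

Bounding both queues at one element, as the patch does with queue.Queue(maxsize=1), keeps an executor from running more than one rollout ahead of the parameters it was handed.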
diff --git a/mava/systems/sebulba/ppo/test.py b/mava/systems/sebulba/ppo/test.py index b868f69b6..fa3798ce5 100644 --- a/mava/systems/sebulba/ppo/test.py +++ b/mava/systems/sebulba/ppo/test.py @@ -21,7 +21,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/utils/checkpointing.py b/mava/utils/checkpointing.py index 8955f76ce..230c4938d 100644 --- a/mava/utils/checkpointing.py +++ b/mava/utils/checkpointing.py @@ -24,7 +24,7 @@ from jax.tree_util import tree_map from omegaconf import DictConfig, OmegaConf -from mava.systems.ppo.types import HiddenStates, Params +from mava.systems.anakin.ppo.types import HiddenStates, Params from mava.types import MavaState # Keep track of the version of the checkpointer From 7e80d7b5f345f5606684bfbc050fca301b700cff Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 14 Jun 2024 12:46:32 +0100 Subject: [PATCH 017/125] fix: changes the env creation --- mava/systems/sebulba/ppo/ff_ippo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 95e722546..779891cfb 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -27,7 +27,6 @@ from collections import deque from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict -from jumanji.env import Environment from omegaconf import DictConfig, OmegaConf from optax._src.base import OptState from rich.pretty import pprint @@ -61,7 +60,7 @@ def rollout( learner_devices: List): #create envs - env = environments.make(config) + env = environments.make_gym_env(config.env.scenario.name, config) #setup len_executor_device_ids = len(config.arch.executor_device_ids) @@ -461,7 +460,7 @@ def learner_setup( n_devices = len(learner_devices) #create temporory envoirnments. - env = environments.make(config) + env = environments.make_gym_env(config.env.scenario.name, config) # Get number of agents and actions. action_space = env.single_action_space config.system.num_agents = len(action_space) From b961336e21e75aa41821047e935a6bb4aa8eb292 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 15 Jun 2024 21:36:36 +0100 Subject: [PATCH 018/125] fix: fixed function calls --- mava/configs/arch/sebulba.yaml | 2 +- mava/systems/sebulba/ppo/ff_ippo.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 98cd4d96d..ac8c4eb75 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,6 +1,6 @@ # --- Sebulba config --- arch_name: "sebulba" -num_envs: 16 # number of envs per thread +num_envs: 2 # number of envs per thread # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. 
If True the policy will select diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 779891cfb..671e6f65c 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -60,7 +60,7 @@ def rollout( learner_devices: List): #create envs - env = environments.make_gym_env(config.env.scenario.name, config) + env = environments.make_gym_env(config) #setup len_executor_device_ids = len(config.arch.executor_device_ids) @@ -460,19 +460,19 @@ def learner_setup( n_devices = len(learner_devices) #create temporory envoirnments. - env = environments.make_gym_env(config.env.scenario.name, config) + env = environments.make_gym_env(config) # Get number of agents and actions. action_space = env.single_action_space config.system.num_agents = len(action_space) config.system.num_actions = action_space[0].n # PRNG keys. - key, actor_net_key, critic_net_key = keys + actor_net_key, critic_net_key = keys # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_action_head = hydra.utils.instantiate( - config.network.action_head, action_dim=env.action_dim + config.network.action_head, action_dim=config.system.num_actions ) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) @@ -494,7 +494,7 @@ def learner_setup( # Initialise observation: Select only obs for a single agent. init_obs = np.array([env.single_observation_space.sample()[0]]) init_action_mask = np.ones((1, config.system.num_actions)) - init_x = Observation(init_obs, init_action_mask) + init_x = Observation(init_obs, init_action_mask, None) # Initialise actor params and optimiser state. actor_params = actor_network.init(actor_net_key, init_x) @@ -512,7 +512,7 @@ def learner_setup( update_fns = (actor_optim.update, critic_optim.update) # Get batched iterated update and replicate it to pmap it over cores. - learn = get_learner_fn(env, apply_fns, update_fns, config) + learn = get_learner_fn(apply_fns, update_fns, config) learn = jax.pmap(learn, axis_name="device", devices = learner_devices) # Load model from checkpoint if specified. @@ -539,7 +539,7 @@ def learner_setup( # Initialise learner state. params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states) + init_learner_state = LearnerState(params, opt_states, None, None, None) env.close() return learn, apply_fns, init_learner_state @@ -574,7 +574,7 @@ def run_experiment(_config: DictConfig) -> float: # Setup learner. learn, apply_fns , learner_state = learner_setup( - learner_keys, config, learner_devices + (actor_net_key, critic_net_key), config, learner_devices ) # Setup evaluator. @@ -591,7 +591,7 @@ def run_experiment(_config: DictConfig) -> float: # Calculate number of updates per evaluation. 
config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation steps_per_rollout = ( - config.arch.executor_device_ids + len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor * config.system.num_updates_per_eval * config.system.rollout_length From 502730d4d82fb62a3d085a30d13f17c3978f6768 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 22 Jun 2024 12:03:38 +0100 Subject: [PATCH 019/125] fix: fixed the training and added training logger --- mava/configs/arch/sebulba.yaml | 4 +- mava/systems/anakin/ppo/ff_ippo.py | 4 +- mava/systems/anakin/ppo/ff_mappo.py | 4 +- mava/systems/anakin/ppo/rec_ippo.py | 4 +- mava/systems/anakin/ppo/rec_mappo.py | 4 +- mava/systems/anakin/q_learning/rec_iql.py | 4 +- mava/systems/anakin/sac/ff_isac.py | 4 +- mava/systems/anakin/sac/ff_masac.py | 4 +- mava/systems/sebulba/ppo/ff_ippo.py | 162 +++++++++++----------- mava/systems/sebulba/ppo/orig.py | 5 +- mava/systems/sebulba/ppo/test.py | 23 ++- mava/utils/total_timestep_checker.py | 32 ++++- 12 files changed, 145 insertions(+), 109 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index ac8c4eb75..cd47dca13 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,6 +1,6 @@ # --- Sebulba config --- arch_name: "sebulba" -num_envs: 2 # number of envs per thread +num_envs: 4 # number of envs per thread # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select @@ -12,7 +12,7 @@ absolute_metric: True # Whether the absolute metric should be computed. For more # on the absolute metric please see: https://arxiv.org/abs/2209.10485 # --- Sebulba devices config --- -n_threads_per_executor: 1 # num of different threads/env batches per actor +n_threads_per_executor: 2 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py index 44e196535..98920428e 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/anakin/ppo/ff_ippo.py @@ -42,7 +42,7 @@ unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -465,7 +465,7 @@ def run_experiment(_config: DictConfig) -> float: evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) # Calculate total timesteps. - config = check_total_timesteps(config) + config = anakin_check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." 
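The _calculate_gae rewrite in the sebulba learner (patch 016 above) no longer walks the environment; it consumes a stored trajectory together with a bootstrap value and a last_done flag, and carries the next step's done flag through a reverse scan so the bootstrap is masked on terminal steps. The same recursion in plain NumPy, with toy numbers that are purely illustrative and not taken from Mava:

import numpy as np

gamma, lam = 0.99, 0.95
rewards = np.array([1.0, 0.0, 1.0])   # r_t for t = 0..T-1
values = np.array([0.5, 0.4, 0.6])    # V(s_t) from the critic
dones = np.array([0.0, 0.0, 1.0])     # done flag stored with each transition
last_value, last_done = 0.7, 0.0      # bootstrap value V(s_T) and its done flag

advantages = np.zeros_like(rewards)
gae, next_value, next_done = 0.0, last_value, last_done
for t in reversed(range(len(rewards))):
    delta = rewards[t] + gamma * next_value * (1.0 - next_done) - values[t]
    gae = delta + gamma * lam * (1.0 - next_done) * gae
    advantages[t] = gae
    next_value, next_done = values[t], dones[t]

targets = advantages + values         # value targets, matching advantages + traj_batch.value

Carrying next_done rather than the current step's done flag is what lets the learner bootstrap from last_obs only when the final transition did not terminate.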
diff --git a/mava/systems/anakin/ppo/ff_mappo.py b/mava/systems/anakin/ppo/ff_mappo.py index 519fa4f39..dda1ef14b 100644 --- a/mava/systems/anakin/ppo/ff_mappo.py +++ b/mava/systems/anakin/ppo/ff_mappo.py @@ -41,7 +41,7 @@ unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -462,7 +462,7 @@ def run_experiment(_config: DictConfig) -> float: evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) # Calculate total timesteps. - config = check_total_timesteps(config) + config = anakin_check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." diff --git a/mava/systems/anakin/ppo/rec_ippo.py b/mava/systems/anakin/ppo/rec_ippo.py index e70a59f07..5aff93ee6 100644 --- a/mava/systems/anakin/ppo/rec_ippo.py +++ b/mava/systems/anakin/ppo/rec_ippo.py @@ -45,7 +45,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -622,7 +622,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 ) # Calculate total timesteps. - config = check_total_timesteps(config) + config = anakin_check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." diff --git a/mava/systems/anakin/ppo/rec_mappo.py b/mava/systems/anakin/ppo/rec_mappo.py index 14284cedb..7efbad9d2 100644 --- a/mava/systems/anakin/ppo/rec_mappo.py +++ b/mava/systems/anakin/ppo/rec_mappo.py @@ -45,7 +45,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -614,7 +614,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 ) # Calculate total timesteps. - config = check_total_timesteps(config) + config = anakin_check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." 
diff --git a/mava/systems/anakin/q_learning/rec_iql.py b/mava/systems/anakin/q_learning/rec_iql.py index 6be8e61a4..60fd98d5c 100644 --- a/mava/systems/anakin/q_learning/rec_iql.py +++ b/mava/systems/anakin/q_learning/rec_iql.py @@ -52,7 +52,7 @@ unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.wrappers import episode_metrics @@ -528,7 +528,7 @@ def update_step( def run_experiment(cfg: DictConfig) -> float: # Add runtime variables to config cfg.arch.n_devices = len(jax.devices()) - cfg = check_total_timesteps(cfg) + cfg = anakin_check_total_timesteps(cfg) # Number of env steps before evaluating/logging. steps_per_rollout = int(cfg.system.total_timesteps // cfg.arch.num_evaluation) diff --git a/mava/systems/anakin/sac/ff_isac.py b/mava/systems/anakin/sac/ff_isac.py index 2c33028d1..7e4e20335 100644 --- a/mava/systems/anakin/sac/ff_isac.py +++ b/mava/systems/anakin/sac/ff_isac.py @@ -51,7 +51,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.wrappers import episode_metrics @@ -483,7 +483,7 @@ def update_step(carry: LearnerState, _: Any) -> Tuple[LearnerState, Tuple[Metric def run_experiment(cfg: DictConfig) -> float: # Add runtime variables to config cfg.arch.n_devices = len(jax.devices()) - cfg = check_total_timesteps(cfg) + cfg = anakin_check_total_timesteps(cfg) # Number of env steps before evaluating/logging. steps_per_rollout = int(cfg.system.total_timesteps // cfg.arch.num_evaluation) diff --git a/mava/systems/anakin/sac/ff_masac.py b/mava/systems/anakin/sac/ff_masac.py index 4401906ee..d5fb9172d 100644 --- a/mava/systems/anakin/sac/ff_masac.py +++ b/mava/systems/anakin/sac/ff_masac.py @@ -52,7 +52,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.wrappers import episode_metrics @@ -501,7 +501,7 @@ def update_step(carry: LearnerState, _: Any) -> Tuple[LearnerState, Tuple[Metric def run_experiment(cfg: DictConfig) -> float: # Add runtime variables to config cfg.arch.n_devices = len(jax.devices()) - cfg = check_total_timesteps(cfg) + cfg = anakin_check_total_timesteps(cfg) # Number of env steps before evaluating/logging. 
steps_per_rollout = int(cfg.system.total_timesteps // cfg.arch.num_evaluation) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 671e6f65c..f5a97b807 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -20,6 +20,7 @@ import flax import hydra import jax +import jax.debug import jax.numpy as jnp import numpy as np import optax @@ -34,8 +35,8 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation +from mava.systems.sebulba.ppo.types import LearnerState, OptStates, Params, PPOTransition, Observation #todo: change this Observation to use the origial one +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import ( @@ -44,26 +45,28 @@ unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import sebulba_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics def rollout( - rng: chex.PRNGKey, + key: chex.PRNGKey, config: DictConfig, rollout_queue: queue.Queue, params_queue: queue.Queue, device_thread_id: int, apply_fns: Tuple, logger: MavaLogger, - learner_devices: List): + learner_devices: List, + actor_device_id : int): #create envs env = environments.make_gym_env(config) #setup len_executor_device_ids = len(config.arch.executor_device_ids) + current_actor_device = jax.devices()[actor_device_id] t_env = 0 start_time = time.time() @@ -78,9 +81,10 @@ def get_action_and_value( ) -> Tuple: """Get action and value.""" key, subkey = jax.random.split(key) - - policy = actor_apply_fn(params.actor_params, observation) - action, log_prob = policy.sample_and_log_prob(seed=subkey) + + actor_policy = actor_apply_fn(params.actor_params, observation) + action = actor_policy.sample(seed=subkey) + log_prob = actor_policy.log_prob(action) value = critic_apply_fn(params.critic_params, observation).squeeze() return action, log_prob, value, key @@ -89,7 +93,7 @@ def get_action_and_value( def prepare_data(storage: List[PPOTransition]) -> PPOTransition: """Prepare data to share with learner.""" return jax.tree_map( # type: ignore - lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage + lambda *xs: jnp.stack(xs), *storage ) @@ -98,7 +102,10 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: rollout_time: deque = deque(maxlen=10) rollout_queue_put_time: deque = deque(maxlen=10) - next_obs, next_rewards, next_dones , extra = env.reset() + next_obs , info = env.reset() + next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) + + move_to_device = lambda x : jax.device_put(x, device = current_actor_device) # Loop till the learner has finished training for update in range(1, config.system.num_updates + 2): @@ -113,15 +120,16 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: params = params_queue.get() params_queue_get_time.append(time.time() - params_queue_get_time_start) - # Rollout + # Rollout 
rollout_time_start = time.time() storage: List = [] # Loop over the rollout length for _ in range(0, config.system.rollout_length): # Cached for transition - cached_next_obs = next_obs - cached_next_dones = next_dones - + cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) + cached_next_dones = move_to_device(next_dones) + cashed_action_mask = move_to_device(jnp.stack([*info["actions_mask"]], axis = 0) ) #unpack the numpy object, find a more pythonic way? + # Increment current timestep t_env += ( config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs @@ -129,24 +137,20 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: # Get action and value inference_time_start = time.time() - + # ( action, log_prob, value, - rng, - ) = get_action_and_value(params, cached_next_obs, rng) + key, + ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) inference_time += time.time() - inference_time_start # Step the environment env_send_time_start = time.time() - cpu_action = np.array(action) - next_obs, next_reward, next_dones, extra = env.step(cpu_action) - - next_dones = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - (next_dones), - ) + cpu_action = jax.device_get(action) + next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) #num_env, num_agents --> num_agents, num_env + next_dones = np.logical_or(terminated, truncated) # Append data to storage env_send_time += time.time() - env_send_time_start @@ -158,38 +162,32 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: value=value, reward=next_reward, log_prob=log_prob, - obs=cached_next_obs, - info=extra, - ) + obs=Observation(cached_next_obs, cashed_action_mask), + info={"win_rate" : info.get("win_rate")}, + )#todo: use a threadsafe alt https://github.com/instadeepai/CityLearn/blob/27e69f8ebdf1789c55ffab5c326bfaa50733a5e7/power_systems/sax_sebulba.py#L39 ) storage_time += time.time() - storage_time_start rollout_time.append(time.time() - rollout_time_start) # Prepare data to share with learner - # todo: investigate the thread --> single learning + # todo: investigate te thread --> single learning partitioned_storage = prepare_data(storage) - sharded_storage = PPOTransition( - *list( # noqa: C417 - map( - lambda x: jax.device_put_sharded(x, devices=learner_devices), # type: ignore - partitioned_storage, - ) - ) - ) + #sorage has shape rollout_len, num_agents, num_envs, .... while the other vectors have num_agents, num_envs, ... 
-> their split axis is diffrent + shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) + + sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , partitioned_storage) - sharded_next_obs = jax.device_put_sharded( - np.split(next_obs, len(learner_devices)), devices=learner_devices - ) - sharded_next_done = jax.device_put_sharded( - np.split(next_dones, len(learner_devices)), devices=learner_devices - ) + sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) + sharded_next_action_mask = shard_split_payload(jnp.stack([*info["actions_mask"]], axis = 0), 0) + sharded_next_done = shard_split_payload(next_dones, 0) payload = ( t_env, sharded_storage, sharded_next_obs, sharded_next_done, + sharded_next_action_mask, np.mean(params_queue_get_time), ) @@ -210,7 +208,7 @@ def get_learner_fn( actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any, traj_batch : PPOTransition, last_obs: chex.Array, last_done: chex.Array) -> Tuple[LearnerState, Tuple]: + def _update_step(learner_state: LearnerState, _: Any, traj_batch : PPOTransition, last_obs: chex.Array, last_action_mask : chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -252,8 +250,8 @@ def _get_advantages( # CALCULATE ADVANTAGE params, opt_states, key, _, _ = learner_state - last_val = critic_apply_fn(params.critic_params, last_obs) - advantages, targets = _calculate_gae(traj_batch, last_val, last_done) + last_val = critic_apply_fn(params.critic_params, Observation(last_obs, last_action_mask)) + advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" @@ -338,18 +336,11 @@ def _critic_loss_fn( # Compute the parallel mean (pmean) over the batch. # This calculation is inspired by the Anakin architecture demo notebook. # available at https://tinyurl.com/26tdzs5x - # This pmean could be a regular mean as the batch axis is on the same device. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="batch" - ) # pmean over devices. actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" - ) # pmean over devices. 
critic_grads, critic_loss_info = jax.lax.pmean( (critic_grads, critic_loss_info), axis_name="device" @@ -370,7 +361,6 @@ def _critic_loss_fn( # PACK NEW PARAMS AND OPTIMISER STATE new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO total_loss = actor_loss_info[0] + critic_loss_info[0] value_loss = critic_loss_info[1] @@ -386,9 +376,8 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES - batch_size = config.system.rollout_length * config.arch.num_envs + batch_size = config.system.rollout_length * config.arch.num_envs * len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) @@ -399,7 +388,6 @@ def _critic_loss_fn( lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), shuffled_batch, ) - # UPDATE MINIBATCHES (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches @@ -409,18 +397,17 @@ def _critic_loss_fn( return update_state, loss_info update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key) - metric = traj_batch.info + learner_state = LearnerState(params, opt_states, key, None, None) + metric = traj_batch.info #todo: metrci calcualtions return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_done: chex.Array) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_action_mask : chex.Array, last_dones : chex.Array) -> ExperimentOutput[LearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -435,14 +422,13 @@ def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs - env_state (LogEnvState): The environment state. - timesteps (TimeStep): The initial timestep in the initial trajectory. """ - - batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + # Broadcast static parameters for scan + partial_update_step = lambda learner_state, xs : _update_step(learner_state, xs, traj_batch , last_obs, last_action_mask, last_dones) - partial_batched_update_step = lambda learner_state, xs : batched_update_step(learner_state, xs, traj_batch , last_obs, last_done) - learner_state, (episode_info, loss_info) = jax.lax.scan( - partial_batched_update_step, learner_state, None, config.system.num_updates_per_eval + partial_update_step, learner_state, None, config.system.num_updates_per_eval ) + return ExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, @@ -467,7 +453,7 @@ def learner_setup( config.system.num_actions = action_space[0].n # PRNG keys. - actor_net_key, critic_net_key = keys + key, actor_net_key, critic_net_key = keys # Define network and optimiser. 
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) @@ -492,9 +478,9 @@ def learner_setup( ) # Initialise observation: Select only obs for a single agent. - init_obs = np.array([env.single_observation_space.sample()[0]]) - init_action_mask = np.ones((1, config.system.num_actions)) - init_x = Observation(init_obs, init_action_mask, None) + init_obs = jnp.array([env.single_observation_space.sample()]) + init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) + init_x = Observation(init_obs, init_action_mask) # Initialise actor params and optimiser state. actor_params = actor_network.init(actor_net_key, init_x) @@ -527,19 +513,16 @@ def learner_setup( params = restored_params # Define params to be replicated across devices and batches. + key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states) - - # Duplicate learner for update_batch_size. - broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape) - replicate_learner = jax.tree_map(broadcast, replicate_learner) + replicate_learner = (params, opt_states, step_keys) # Duplicate learner across Learner devices. replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) # Initialise learner state. - params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, None, None, None) + params, opt_states, step_keys = replicate_learner + init_learner_state = LearnerState(params, opt_states, step_keys, None, None) env.close() return learn, apply_fns, init_learner_state @@ -562,7 +545,7 @@ def run_experiment(_config: DictConfig) -> float: # Sanity check of config assert ( config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "The number of environments need to be divisible by the number of learners " + ), "The number of environments must to be divisible by the number of learners " assert ( int(config.arch.num_envs / len(config.arch.learner_device_ids)) @@ -574,7 +557,7 @@ def run_experiment(_config: DictConfig) -> float: # Setup learner. learn, apply_fns , learner_state = learner_setup( - (actor_net_key, critic_net_key), config, learner_devices + (key ,actor_net_key, critic_net_key), config, learner_devices ) # Setup evaluator. @@ -583,7 +566,7 @@ def run_experiment(_config: DictConfig) -> float: #evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) # Calculate total timesteps. - config = check_total_timesteps(config) #todo: update this for sebulba + config = sebulba_check_total_timesteps(config) #todo: update this for sebulba assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." 
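sebulba_check_total_timesteps, added to mava/utils/total_timestep_checker.py later in this patch, ties total_timesteps and num_updates together through the executor-device count, threads per executor, rollout length and envs per thread. With the values this series uses in its configs (one executor device, two threads, 4 envs per thread, rollout length 128, 1000 updates; the numbers below are illustrative only):

# Illustrative numbers, taken from sebulba.yaml / ff_ippo.yaml in this series.
n_executor_devices = 1        # len(config.arch.executor_device_ids)
n_threads_per_executor = 2
num_envs = 4                  # envs per executor thread
rollout_length = 128
num_updates = 1000

total_timesteps = (
    n_executor_devices * n_threads_per_executor * num_updates * rollout_length * num_envs
)
print(total_timesteps)        # 1_024_000 env steps across all executor threads

When total_timesteps is set instead, the helper floors num_updates by the same product and warns that it overrode the configured value.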
@@ -595,7 +578,6 @@ def run_experiment(_config: DictConfig) -> float: * config.arch.n_threads_per_executor * config.system.num_updates_per_eval * config.system.rollout_length - * config.system.update_batch_size * config.arch.num_envs ) @@ -639,6 +621,7 @@ def run_experiment(_config: DictConfig) -> float: apply_fns, logger, learner_devices, + d_id, ), ).start() @@ -648,7 +631,7 @@ def run_experiment(_config: DictConfig) -> float: trainer_update_number = 0 max_episode_return = jnp.float32(0.0) best_params = None - while True: + for eval_step in range(config.arch.num_evaluation): #todo : place holder trainer_update_number += 1 rollout_queue_get_time_start = time.time() sharded_storages = [] @@ -666,25 +649,36 @@ def run_experiment(_config: DictConfig) -> float: sharded_storage, sharded_next_obs, sharded_next_done, + sharded_next_action_mask, avg_params_queue_get_time, ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() sharded_storages.append(sharded_storage) sharded_next_obss.append(sharded_next_obs) sharded_next_dones.append(sharded_next_done) - + sharded_next_action_masks.append(sharded_next_action_mask) rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) training_time_start = time.time() - learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones) + #Concatinate the returned trajectories on the n_env axis + sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) + sharded_next_obss = jnp.concatenate(sharded_next_obss, axis = 1) + sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) + sharded_next_action_masks = jnp.concatenate(sharded_next_action_masks, axis = 1) + + learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_action_masks, sharded_next_dones) # Send updated params to executors - unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) + unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) for d_idx, d_id in enumerate(config.arch.executor_device_ids): device_params = jax.device_put(unreplicated_params, devices[d_id]) for thread_id in range(config.arch.n_threads_per_executor): params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( device_params ) + + t = int(steps_per_rollout * (eval_step + 1)) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + return None#eval_performance diff --git a/mava/systems/sebulba/ppo/orig.py b/mava/systems/sebulba/ppo/orig.py index 85b679305..dde0add30 100644 --- a/mava/systems/sebulba/ppo/orig.py +++ b/mava/systems/sebulba/ppo/orig.py @@ -43,7 +43,6 @@ ActorApply, CriticApply, LearnerState, - Observation, OptStates, Params, ) @@ -189,8 +188,8 @@ def prepare_data(storage: List[Transition]) -> Transition: ) storage_time += time.time() - storage_time_start - # Update episode info - episode_returns[env_id] += np.mean(next_reward) + # Update episode info ---------------------------------------------------------------------------------------------------------- this is kinda cringe? 
+ episode_returns[env_id] += np.mean(next_reward, axis = 1) returned_episode_returns[env_id] = np.where( next_done, episode_returns[env_id], diff --git a/mava/systems/sebulba/ppo/test.py b/mava/systems/sebulba/ppo/test.py index fa3798ce5..adc15dcc7 100644 --- a/mava/systems/sebulba/ppo/test.py +++ b/mava/systems/sebulba/ppo/test.py @@ -31,20 +31,33 @@ unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.total_timestep_checker import anakin_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics - - +from flax import linen as nn +import gym +from mava.wrappers import GymRwareWrapper @hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. OmegaConf.set_struct(cfg, False) + + base = gym.make(cfg.env.scenario) + base = GymRwareWrapper(base, cfg.env.use_individual_rewards, False, True) + base.reset() + ree = base.step([0,0]) + print(ree) env = environments.make_gym_env(cfg) a = env.reset() print(a) + b = env.step([[0,0], [0,0], [0,0], [0,0]]) + #print(b) + #r = 1+1 + # Create a sample input + #env = gym.make(cfg.env.scenario) + #env.reset() + #a = env.step(jnp.ones((4))) -if __name__ == "__main__": - hydra_entry_point() \ No newline at end of file +hydra_entry_point() \ No newline at end of file diff --git a/mava/utils/total_timestep_checker.py b/mava/utils/total_timestep_checker.py index c2cda8320..fd90b7436 100644 --- a/mava/utils/total_timestep_checker.py +++ b/mava/utils/total_timestep_checker.py @@ -18,7 +18,7 @@ from omegaconf import DictConfig -def check_total_timesteps(config: DictConfig) -> DictConfig: +def anakin_check_total_timesteps(config: DictConfig) -> DictConfig: """Check if total_timesteps is set, if not, set it based on the other parameters""" n_devices = len(jax.devices()) @@ -47,3 +47,33 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: + f"{Style.RESET_ALL}" ) return config + + +def sebulba_check_total_timesteps(config: DictConfig) -> DictConfig: + """Check if total_timesteps is set, if not, set it based on the other parameters""" + + if config.system.total_timesteps is None: + config.system.num_updates = int(config.system.num_updates) + config.system.total_timesteps = int( + len(config.arch.executor_device_ids) + * config.arch.n_threads_per_executor + * config.system.num_updates + * config.system.rollout_length + * config.arch.num_envs + ) + else: + config.system.total_timesteps = int(config.system.total_timesteps) + config.system.num_updates = int( + config.system.total_timesteps + // config.system.rollout_length + // config.arch.num_envs + // config.arch.n_threads_per_executor + // len(config.arch.executor_device_ids) + ) + print( + f"{Fore.RED}{Style.BRIGHT} Changing the number of updates " + + f"to {config.system.num_updates}: If you want to train" + + " for a specific number of updates, please set total_timesteps to None!" 
+ + f"{Style.RESET_ALL}" + ) + return config \ No newline at end of file From 1985729cab347716153d3f5e00713b08eeb96f1b Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 22 Jun 2024 12:37:46 +0100 Subject: [PATCH 020/125] fix: changed the anakin ppo type import --- mava/systems/anakin/ppo/ff_ippo.py | 2 +- mava/systems/anakin/ppo/ff_mappo.py | 2 +- mava/systems/anakin/ppo/rec_ippo.py | 2 +- mava/systems/anakin/ppo/rec_mappo.py | 2 +- mava/systems/sebulba/ppo/ff_ippo.py | 16 ++++++++++++++-- 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py index 98920428e..d8cd0e9b4 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/anakin/ppo/ff_ippo.py @@ -32,7 +32,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/anakin/ppo/ff_mappo.py b/mava/systems/anakin/ppo/ff_mappo.py index dda1ef14b..a4ddfdaa5 100644 --- a/mava/systems/anakin/ppo/ff_mappo.py +++ b/mava/systems/anakin/ppo/ff_mappo.py @@ -31,7 +31,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/anakin/ppo/rec_ippo.py b/mava/systems/anakin/ppo/rec_ippo.py index 5aff93ee6..512a09301 100644 --- a/mava/systems/anakin/ppo/rec_ippo.py +++ b/mava/systems/anakin/ppo/rec_ippo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.ppo.types import ( +from mava.systems.anakin.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/anakin/ppo/rec_mappo.py b/mava/systems/anakin/ppo/rec_mappo.py index 7efbad9d2..529a0505b 100644 --- a/mava/systems/anakin/ppo/rec_mappo.py +++ b/mava/systems/anakin/ppo/rec_mappo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.ppo.types import ( +from mava.systems.anakin.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index f5a97b807..0ce93cda0 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -667,6 +667,12 @@ def run_experiment(_config: DictConfig) -> float: learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_action_masks, sharded_next_dones) + # Log the results of the training. 
+ elapsed_time = time.time() - rollout_queue_get_time_start + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + # Send updated params to executors unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) for d_idx, d_id in enumerate(config.arch.executor_device_ids): @@ -675,8 +681,11 @@ def run_experiment(_config: DictConfig) -> float: params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( device_params ) - - t = int(steps_per_rollout * (eval_step + 1)) + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + if ep_completed: # only log episode metrics if an episode was completed in the rollout. + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) @@ -697,3 +706,6 @@ def hydra_entry_point(cfg: DictConfig) -> float: if __name__ == "__main__": hydra_entry_point() + +#learner_output.episode_metrics.keys() +#dict_keys(['episode_length', 'episode_return']) \ No newline at end of file From 89ed2466e8a3bbaff26eb60145a6dbb85e5e929c Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 25 Jun 2024 15:43:31 +0100 Subject: [PATCH 021/125] feat: fulll sebulba functional --- .../ff_ippo_store_experience.py | 4 +- mava/configs/arch/sebulba.yaml | 2 +- mava/configs/env/gym.yaml | 2 +- mava/configs/system/ppo/ff_ippo.yaml | 4 +- mava/evaluator.py | 129 +++++++++++- mava/systems/anakin/ppo/ff_ippo.py | 4 +- mava/systems/anakin/ppo/ff_mappo.py | 4 +- mava/systems/anakin/ppo/rec_ippo.py | 4 +- mava/systems/anakin/ppo/rec_mappo.py | 4 +- mava/systems/anakin/q_learning/rec_iql.py | 4 +- mava/systems/anakin/sac/ff_isac.py | 4 +- mava/systems/anakin/sac/ff_masac.py | 4 +- mava/systems/sebulba/ppo/ff_ippo.py | 168 ++++++++------- mava/systems/sebulba/ppo/test.py | 46 +++-- mava/utils/logger.py | 2 +- mava/utils/make_env.py | 10 +- mava/wrappers/__init__.py | 2 +- mava/wrappers/episode_metrics.py | 2 +- mava/wrappers/gym.py | 193 +++++++++++++----- 19 files changed, 424 insertions(+), 168 deletions(-) diff --git a/mava/advanced_usage/ff_ippo_store_experience.py b/mava/advanced_usage/ff_ippo_store_experience.py index 4bd94040c..4236bc641 100644 --- a/mava/advanced_usage/ff_ippo_store_experience.py +++ b/mava/advanced_usage/ff_ippo_store_experience.py @@ -30,7 +30,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition @@ -469,7 +469,7 @@ def run_experiment(_config: DictConfig) -> None: # noqa: CCR001 # Setup evaluator. 
eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network, config) + evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor_network, config) config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation steps_per_rollout = ( diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index cd47dca13..02ae56bb3 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,6 +1,6 @@ # --- Sebulba config --- arch_name: "sebulba" -num_envs: 4 # number of envs per thread +num_envs: 64 # number of envs per thread # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym.yaml index ad8d16b9a..44c9c624a 100644 --- a/mava/configs/env/gym.yaml +++ b/mava/configs/env/gym.yaml @@ -10,7 +10,7 @@ eval_metric: episode_return # Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. # This should not be changed. -implicit_agent_id: False +implicit_agent_id: True # Whether or not to log the winrate of this environment. This should not be changed as not all # environments have a winrate metric. log_win_rate: False diff --git a/mava/configs/system/ppo/ff_ippo.yaml b/mava/configs/system/ppo/ff_ippo.yaml index 9efb0611a..b8d0573b4 100644 --- a/mava/configs/system/ppo/ff_ippo.yaml +++ b/mava/configs/system/ppo/ff_ippo.yaml @@ -1,6 +1,6 @@ # --- Defaults FF-IPPO --- -total_timesteps: ~ # Set the total environment steps. +total_timesteps: 20_000_000 # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. num_updates: 1000 # Number of updates seed: 42 @@ -14,7 +14,7 @@ critic_lr: 2.5e-4 # Learning rate for critic network update_batch_size: 2 # Number of vectorised gradient updates per device. rollout_length: 128 # Number of environment steps per vectorised environment. ppo_epochs: 4 # Number of ppo epochs per training data batch. -num_minibatches: 2 # Number of minibatches per ppo epoch. +num_minibatches: 1 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. gae_lambda: 0.95 # Lambda value for GAE computation. clip_eps: 0.2 # Clipping value for PPO updates and value function. 
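[Editor's aside, not part of the patches] The three config hunks above (arch/sebulba.yaml, env/gym.yaml and system/ppo/ff_ippo.yaml) jointly determine how much experience the Sebulba system collects between evaluations, via the num_updates_per_eval and steps_per_rollout bookkeeping that appears in the system files in this patch. A minimal sketch of that arithmetic follows; the field names mirror the configs above, but the numeric values are illustrative assumptions, not values taken from the patch.

    from omegaconf import OmegaConf

    # Illustrative config values; field names follow sebulba.yaml / ff_ippo.yaml above.
    config = OmegaConf.create(
        {
            "arch": {
                "num_envs": 64,                # envs per executor thread
                "num_evaluation": 200,         # evenly spaced evaluations
                "n_threads_per_executor": 1,
                "executor_device_ids": [0],
            },
            "system": {"num_updates": 1000, "rollout_length": 128},
        }
    )

    # Derived the same way as in the systems in this patch.
    config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation
    steps_per_rollout = (
        len(config.arch.executor_device_ids)
        * config.arch.n_threads_per_executor
        * config.system.rollout_length
        * config.arch.num_envs
        * config.system.num_updates_per_eval
    )
    print(config.system.num_updates_per_eval, steps_per_rollout)  # 5 updates, 40960 env steps per eval

With these assumed numbers, each evaluation point corresponds to 5 learner updates and 40,960 environment steps gathered by the executor threads.
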
diff --git a/mava/evaluator.py b/mava/evaluator.py index 201544338..066890ed9 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -31,8 +31,10 @@ RNNEvalState, ) +from mava.systems.sebulba.ppo.types import Observation +import numpy as np -def get_ff_evaluator_fn( +def get_anakin_ff_evaluator_fn( env: Environment, apply_fn: ActorApply, config: DictConfig, @@ -282,7 +284,7 @@ def evaluator_fn( return evaluator_fn -def make_eval_fns( +def make_anakin_eval_fns( eval_env: Environment, network_apply_fn: Union[ActorApply, RecActorApply], config: DictConfig, @@ -327,10 +329,10 @@ def make_eval_fns( 10, ) else: - evaluator = get_ff_evaluator_fn( + evaluator = get_anakin_ff_evaluator_fn( eval_env, network_apply_fn, config, log_win_rate # type: ignore ) - absolute_metric_evaluator = get_ff_evaluator_fn( + absolute_metric_evaluator = get_anakin_ff_evaluator_fn( eval_env, network_apply_fn, config, log_win_rate, 10 # type: ignore ) @@ -338,3 +340,122 @@ def make_eval_fns( absolute_metric_evaluator = jax.pmap(absolute_metric_evaluator, axis_name="device") return evaluator, absolute_metric_evaluator + + +def get_sebulba_ff_evaluator_fn( + env: Environment, + apply_fn: ActorApply, + config: DictConfig, + log_win_rate: bool = False, +) -> EvalFn: + """Get the evaluator function for feedforward networks. + + Args: + env (Environment): An evironment instance for evaluation. + apply_fn (callable): Network forward pass method. + config (dict): Experiment configuration. + """ + @jax.jit + def get_action( #todo explicetly put these on the learner? they should already be there + params: FrozenDict, + observation: Observation, + key: chex.PRNGKey, + ) -> Tuple: + """Get action.""" + + pi = apply_fn(params, observation) + + if config.arch.evaluation_greedy: + action = pi.mode() + else: + action = pi.sample(seed=key) + + return action + def eval_episodes(params: FrozenDict, key : chex.PRNGKey) -> Dict: + + dones = np.zeros(env.num_envs) # todo: jnp or np? + + obs, info = env.reset() + eval_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) + + while not dones.all(): + + key, policy_key = jax.random.split(key) + + obs = jax.device_put(jnp.stack(obs, axis = 1)) + action_mask = jax.device_put(jnp.stack([*info["actions_mask"]], axis = 0)) + + actions = get_action(params, Observation(obs, action_mask), policy_key) + cpu_action = jax.device_get(actions) + + obs, reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) + + next_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) + + next_dones = next_metrics["is_terminal_step"] + + update_metric = lambda old_metric, new_metric : np.where(np.logical_and(next_dones, dones == False), new_metric, old_metric) + eval_metrics = jax.tree_map(update_metric, eval_metrics, next_metrics) + + dones = np.logical_or(dones, next_dones) + eval_metrics.pop("is_terminal_step") + + return eval_metrics + + return eval_episodes + + +def make_sebulba_eval_fns( + eval_env_fn: callable, + network_apply_fn: Union[ActorApply, RecActorApply], + config: DictConfig, + use_recurrent_net: bool = False, + scanned_rnn: Optional[nn.Module] = None, +) -> Tuple[EvalFn, EvalFn]: + """Initialize evaluator functions for reinforcement learning. + + Args: + eval_env_fn (Environment): The function to Create the eval envs. + network_apply_fn (Union[ActorApply,RecActorApply]): Creates a policy to sample. + config (DictConfig): The configuration settings for the evaluation. + use_recurrent_net (bool, optional): Whether to use a rnn. Defaults to False. 
+ scanned_rnn (Optional[nn.Module], optional): The rnn module. + Required if `use_recurrent_net` is True. Defaults to None. + + Returns: + Tuple[EvalFn, EvalFn]: A tuple of two evaluation functions: + one for use during training and one for absolute metrics. + + Raises: + AssertionError: If `use_recurrent_net` is True but `scanned_rnn` is not provided. + """ + eval_env, absolute_eval_env = eval_env_fn(config, config.arch.num_eval_episodes), eval_env_fn(config, config.arch.num_eval_episodes * 10) + + # Check if win rate is required for evaluation. + log_win_rate = config.env.log_win_rate + # Vmap it over number of agents and create evaluator_fn. + if use_recurrent_net: + assert scanned_rnn is not None + evaluator = get_rnn_evaluator_fn( + eval_env, + network_apply_fn, # type: ignore + config, + scanned_rnn, + log_win_rate, + ) + absolute_metric_evaluator = get_rnn_evaluator_fn( + absolute_eval_env, + network_apply_fn, # type: ignore + config, + scanned_rnn, + log_win_rate, + ) + else: + evaluator = get_sebulba_ff_evaluator_fn( + eval_env, network_apply_fn, config, log_win_rate # type: ignore + ) + absolute_metric_evaluator = get_sebulba_ff_evaluator_fn( + absolute_eval_env, network_apply_fn, config, log_win_rate # type: ignore + ) + + return evaluator, absolute_metric_evaluator \ No newline at end of file diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py index d8cd0e9b4..f0803de4d 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/anakin/ppo/ff_ippo.py @@ -29,7 +29,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition @@ -462,7 +462,7 @@ def run_experiment(_config: DictConfig) -> float: # Setup evaluator. # One key per device for evaluation. eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) + evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor_network.apply, config) # Calculate total timesteps. config = anakin_check_total_timesteps(config) diff --git a/mava/systems/anakin/ppo/ff_mappo.py b/mava/systems/anakin/ppo/ff_mappo.py index a4ddfdaa5..90fad5767 100644 --- a/mava/systems/anakin/ppo/ff_mappo.py +++ b/mava/systems/anakin/ppo/ff_mappo.py @@ -28,7 +28,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition @@ -459,7 +459,7 @@ def run_experiment(_config: DictConfig) -> float: # Setup evaluator. # One key per device for evaluation. eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) + evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor_network.apply, config) # Calculate total timesteps. 
config = anakin_check_total_timesteps(config) diff --git a/mava/systems/anakin/ppo/rec_ippo.py b/mava/systems/anakin/ppo/rec_ippo.py index 512a09301..583cd7acc 100644 --- a/mava/systems/anakin/ppo/rec_ippo.py +++ b/mava/systems/anakin/ppo/rec_ippo.py @@ -29,7 +29,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN @@ -613,7 +613,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 # Setup evaluator. # One key per device for evaluation. eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_eval_fns( + evaluator, absolute_metric_evaluator = make_anakin_eval_fns( eval_env=eval_env, network_apply_fn=actor_network.apply, config=config, diff --git a/mava/systems/anakin/ppo/rec_mappo.py b/mava/systems/anakin/ppo/rec_mappo.py index 529a0505b..74179ab34 100644 --- a/mava/systems/anakin/ppo/rec_mappo.py +++ b/mava/systems/anakin/ppo/rec_mappo.py @@ -29,7 +29,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN @@ -605,7 +605,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 # Setup evaluator. # One key per device for evaluation. eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_eval_fns( + evaluator, absolute_metric_evaluator = make_anakin_eval_fns( eval_env=eval_env, network_apply_fn=actor_network.apply, config=config, diff --git a/mava/systems/anakin/q_learning/rec_iql.py b/mava/systems/anakin/q_learning/rec_iql.py index 60fd98d5c..d3566a8d5 100644 --- a/mava/systems/anakin/q_learning/rec_iql.py +++ b/mava/systems/anakin/q_learning/rec_iql.py @@ -32,7 +32,7 @@ from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import RecQNetwork, ScannedRNN from mava.systems.q_learning.types import ( ActionSelectionState, @@ -548,7 +548,7 @@ def run_experiment(cfg: DictConfig) -> float: cfg.system.num_agents = env.num_agents key, eval_key = jax.random.split(key) - evaluator, absolute_metric_evaluator = make_eval_fns( + evaluator, absolute_metric_evaluator = make_anakin_eval_fns( eval_env=eval_env, network_apply_fn=q_net.apply, config=cfg, diff --git a/mava/systems/anakin/sac/ff_isac.py b/mava/systems/anakin/sac/ff_isac.py index 7e4e20335..a3b2e5c47 100644 --- a/mava/systems/anakin/sac/ff_isac.py +++ b/mava/systems/anakin/sac/ff_isac.py @@ -31,7 +31,7 @@ from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardQNet as QNetwork from mava.systems.sac.types import ( @@ -502,7 +502,7 @@ def run_experiment(cfg: DictConfig) -> float: actor, _ = networks key, eval_key = jax.random.split(key) - evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor.apply, cfg) + evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor.apply, cfg) if 
cfg.logger.checkpointing.save_model: checkpointer = Checkpointer( diff --git a/mava/systems/anakin/sac/ff_masac.py b/mava/systems/anakin/sac/ff_masac.py index d5fb9172d..a319731ab 100644 --- a/mava/systems/anakin/sac/ff_masac.py +++ b/mava/systems/anakin/sac/ff_masac.py @@ -31,7 +31,7 @@ from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_anakin_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardQNet as QNetwork from mava.systems.sac.types import ( @@ -520,7 +520,7 @@ def run_experiment(cfg: DictConfig) -> float: actor, _ = networks key, eval_key = jax.random.split(key) - evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor.apply, cfg) + evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor.apply, cfg) if cfg.logger.checkpointing.save_model: checkpointer = Checkpointer( diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 0ce93cda0..229e268d0 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -32,7 +32,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_eval_fns +from mava.evaluator import make_sebulba_eval_fns as make_eval_fns #todo: make a standered eval function from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.sebulba.ppo.types import LearnerState, OptStates, Params, PPOTransition, Observation #todo: change this Observation to use the origial one @@ -62,7 +62,7 @@ def rollout( actor_device_id : int): #create envs - env = environments.make_gym_env(config) + env = environments.make_gym_env(config, config.arch.num_envs) #setup len_executor_device_ids = len(config.arch.executor_device_ids) @@ -93,7 +93,7 @@ def get_action_and_value( def prepare_data(storage: List[PPOTransition]) -> PPOTransition: """Prepare data to share with learner.""" return jax.tree_map( # type: ignore - lambda *xs: jnp.stack(xs), *storage + lambda *xs : jnp.stack(xs), *storage ) @@ -102,73 +102,75 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: rollout_time: deque = deque(maxlen=10) rollout_queue_put_time: deque = deque(maxlen=10) - next_obs , info = env.reset() + next_obs , info = env.reset() #todo : the first info is discarded , is that a problem? next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) move_to_device = lambda x : jax.device_put(x, device = current_actor_device) # Loop till the learner has finished training - for update in range(1, config.system.num_updates + 2): - # Setup - env_recv_time: float = 0 - inference_time: float = 0 - storage_time: float = 0 - env_send_time: float = 0 - - # Get the latest parameters from the learner - params_queue_get_time_start = time.time() - params = params_queue.get() - params_queue_get_time.append(time.time() - params_queue_get_time_start) - - # Rollout - rollout_time_start = time.time() - storage: List = [] - # Loop over the rollout length - for _ in range(0, config.system.rollout_length): - # Cached for transition - cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) - cached_next_dones = move_to_device(next_dones) - cashed_action_mask = move_to_device(jnp.stack([*info["actions_mask"]], axis = 0) ) #unpack the numpy object, find a more pythonic way? 
- - # Increment current timestep - t_env += ( - config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs - ) + for eval_step in range(config.arch.num_evaluation): + for update in range(1, config.system.num_updates_per_eval + 2): + # Setup + env_recv_time: float = 0 + inference_time: float = 0 + storage_time: float = 0 + env_send_time: float = 0 - # Get action and value - inference_time_start = time.time() - # - ( - action, - log_prob, - value, - key, - ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) - inference_time += time.time() - inference_time_start + # Get the latest parameters from the learner + params_queue_get_time_start = time.time() + params = params_queue.get() + params_queue_get_time.append(time.time() - params_queue_get_time_start) - # Step the environment - env_send_time_start = time.time() - cpu_action = jax.device_get(action) - next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) #num_env, num_agents --> num_agents, num_env - next_dones = np.logical_or(terminated, truncated) - - # Append data to storage - env_send_time += time.time() - env_send_time_start - storage_time_start = time.time() - storage.append( - PPOTransition( - done=cached_next_dones, - action=action, - value=value, - reward=next_reward, - log_prob=log_prob, - obs=Observation(cached_next_obs, cashed_action_mask), - info={"win_rate" : info.get("win_rate")}, - )#todo: use a threadsafe alt https://github.com/instadeepai/CityLearn/blob/27e69f8ebdf1789c55ffab5c326bfaa50733a5e7/power_systems/sax_sebulba.py#L39 - ) - storage_time += time.time() - storage_time_start + # Rollout + rollout_time_start = time.time() + storage: List = [] + # Loop over the rollout length + for _ in range(0, config.system.rollout_length): + # Cached for transition + cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) + cached_next_dones = move_to_device(next_dones) + cashed_action_mask = move_to_device(jnp.stack([*info["actions_mask"]], axis = 0) ) #unpack the numpy object, find a more pythonic way? 
+ + # Increment current timestep + t_env += ( + config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs + ) + + # Get action and value + inference_time_start = time.time() + # + ( + action, + log_prob, + value, + key, + ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) + inference_time += time.time() - inference_time_start + + # Step the environment + env_send_time_start = time.time() + cpu_action = jax.device_get(action) + next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) #num_env, num_agents --> num_agents, num_env + next_dones = np.logical_or(terminated, truncated) + + metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics (N_envs , N_metrics) -- > (N_metrics, N_envs) + # Append data to storage + env_send_time += time.time() - env_send_time_start + storage_time_start = time.time() + storage.append( + PPOTransition( + done=cached_next_dones, + action=action, + value=value, + reward=next_reward, + log_prob=log_prob, + obs=Observation(cached_next_obs, cashed_action_mask), + info=metrics, + )#todo: use a threadsafe alt https://github.com/instadeepai/CityLearn/blob/27e69f8ebdf1789c55ffab5c326bfaa50733a5e7/power_systems/sax_sebulba.py#L39 + ) + storage_time += time.time() - storage_time_start - rollout_time.append(time.time() - rollout_time_start) + rollout_time.append(time.time() - rollout_time_start) # Prepare data to share with learner # todo: investigate te thread --> single learning @@ -446,7 +448,7 @@ def learner_setup( n_devices = len(learner_devices) #create temporory envoirnments. - env = environments.make_gym_env(config) + env = environments.make_gym_env(config, config.arch.num_envs) # Get number of agents and actions. action_space = env.single_action_space config.system.num_agents = len(action_space) @@ -562,8 +564,7 @@ def run_experiment(_config: DictConfig) -> float: # Setup evaluator. # One key per device for evaluation. - #eval_keys = jax.random.split(key_e, n_devices) # todo: well add the evaluations :) - #evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor_network.apply, config) + evaluator, absolute_metric_evaluator = make_eval_fns(environments.make_gym_env, apply_fns[0], config) #todo: make this more generic # Calculate total timesteps. 
config = sebulba_check_total_timesteps(config) #todo: update this for sebulba @@ -576,9 +577,9 @@ def run_experiment(_config: DictConfig) -> float: steps_per_rollout = ( len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor - * config.system.num_updates_per_eval * config.system.rollout_length * config.arch.num_envs + * config.system.num_updates_per_eval ) # Logger setup @@ -633,7 +634,7 @@ def run_experiment(_config: DictConfig) -> float: best_params = None for eval_step in range(config.arch.num_evaluation): #todo : place holder trainer_update_number += 1 - rollout_queue_get_time_start = time.time() + start_time = time.time() sharded_storages = [] sharded_next_obss = [] sharded_next_dones = [] @@ -656,23 +657,17 @@ def run_experiment(_config: DictConfig) -> float: sharded_next_obss.append(sharded_next_obs) sharded_next_dones.append(sharded_next_done) sharded_next_action_masks.append(sharded_next_action_mask) - rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) + rollout_queue_get_time.append(time.time() - start_time) training_time_start = time.time() #Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) + sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) #todo: check if this breaks the explicet array device placment sharded_next_obss = jnp.concatenate(sharded_next_obss, axis = 1) sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) sharded_next_action_masks = jnp.concatenate(sharded_next_action_masks, axis = 1) learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_action_masks, sharded_next_dones) - # Log the results of the training. - elapsed_time = time.time() - rollout_queue_get_time_start - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - # Send updated params to executors unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) for d_idx, d_id in enumerate(config.arch.executor_device_ids): @@ -682,13 +677,36 @@ def run_experiment(_config: DictConfig) -> float: device_params ) + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) # todo: these shapes are not as expected + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + # Separately log timesteps, actoring metrics and training metrics. logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) if ep_completed: # only log episode metrics if an episode was completed in the rollout. logger.log(episode_metrics, t, eval_step, LogEvent.ACT) logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + # Evaluation on the learner + key_e, eval_key = jax.random.split(key_e, 2) + episode_metrics = evaluator(unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1 ), eval_key) + + # Log the results of the evaluation. 
+ elapsed_time = time.time() - start_time + episode_return = jnp.mean(episode_metrics["episode_return"]) + steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) + episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) + + #todo: add saving + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(learner_output.learner_state.params) + max_episode_return = episode_return + #todo: abs metric return None#eval_performance diff --git a/mava/systems/sebulba/ppo/test.py b/mava/systems/sebulba/ppo/test.py index adc15dcc7..5e45544f1 100644 --- a/mava/systems/sebulba/ppo/test.py +++ b/mava/systems/sebulba/ppo/test.py @@ -5,6 +5,8 @@ import threading import chex import flax +import gym.vector +import gym.vector.async_vector_env import hydra import jax import jax.numpy as jnp @@ -18,7 +20,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_eval_fns +#from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this @@ -36,23 +38,41 @@ from mava.wrappers.episode_metrics import get_final_step_metrics from flax import linen as nn import gym -from mava.wrappers import GymRwareWrapper +import rware +from mava.wrappers import GymRwareWrapper, GymRecordEpisodeMetrics, _multiagent_worker_shared_memory @hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
- OmegaConf.set_struct(cfg, False) - + - base = gym.make(cfg.env.scenario) - base = GymRwareWrapper(base, cfg.env.use_individual_rewards, False, True) + OmegaConf.set_struct(cfg, False) + def f(): + base = gym.make(cfg.env.scenario) + base = GymRwareWrapper(base, cfg.env.use_individual_rewards, False, True) + return GymRecordEpisodeMetrics(base) + + base = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names + [ + lambda: f() + for _ in range(3) + ], + worker=_multiagent_worker_shared_memory + ) base.reset() - ree = base.step([0,0]) - print(ree) - env = environments.make_gym_env(cfg) - a = env.reset() - print(a) - b = env.step([[0,0], [0,0], [0,0], [0,0]]) + n = 0 + done = False + while not done: + n+= 1 + agents_view, reward, terminated, truncated, info = base.step([[0,0,0], [0,0,0]]) + done = np.logical_or(terminated, truncated).all() + metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) + print(n, done, terminated, np.logical_or(terminated, truncated).shape, metrics) + done = True + base.close() + print(done) + + #print(b) #r = 1+1 # Create a sample input @@ -60,4 +80,4 @@ def hydra_entry_point(cfg: DictConfig) -> float: #env.reset() #a = env.step(jnp.ones((4))) -hydra_entry_point() \ No newline at end of file +hydra_entry_point() diff --git a/mava/utils/logger.py b/mava/utils/logger.py index 4edad361e..8273e44a2 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -337,7 +337,7 @@ def get_logger_path(config: DictConfig, logger_type: str) -> str: def describe(x: ArrayLike) -> Union[Dict[str, ArrayLike], ArrayLike]: """Generate summary statistics for an array of metrics (mean, std, min, max).""" - if not isinstance(x, jax.Array) or x.size <= 1: + if not (isinstance(x, jax.Array) or isinstance(x, np.ndarray)) or x.size <= 1: return x # np instead of jnp because we don't jit here diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 69fc54623..cab649880 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -46,6 +46,7 @@ GigastepWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, + _multiagent_worker_shared_memory, LbfWrapper, MabraxWrapper, MatraxWrapper, @@ -208,7 +209,7 @@ def make_gigastep_env( def make_gym_env( - config: DictConfig, add_global_state: bool = False, eval_env: bool = False + config: DictConfig, num_env : int, add_global_state: bool = False, eval_env: bool = False ) -> Environment: # todo : create the appropriate annotation for the sync vector """ Create a Gym environment. @@ -230,8 +231,8 @@ def create_gym_env( env = gym.make(config.env.scenario) wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) if not config.env.implicit_agent_id: - pass # todo : add agent id wrapper for gym . - env = GymRecordEpisodeMetrics(env) + wrapped_env = AgentIDWrapper(wrapped_env) # todo : add agent id wrapper for gym . 
+ wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env num_env = config.arch.num_envs @@ -239,7 +240,8 @@ def create_gym_env( [ lambda: create_gym_env(config, add_global_state, eval_env=eval_env) for _ in range(num_env) - ] + ], + worker=_multiagent_worker_shared_memory ) return envs diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index e888d9317..3608b1d10 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper +from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, _multiagent_worker_shared_memory from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, diff --git a/mava/wrappers/episode_metrics.py b/mava/wrappers/episode_metrics.py index a2b0fdb37..a46dc1b91 100644 --- a/mava/wrappers/episode_metrics.py +++ b/mava/wrappers/episode_metrics.py @@ -75,7 +75,7 @@ def step( # Previous episode return/length until done and then the next episode return. episode_return_info = state.episode_return * not_done + new_episode_return * done episode_length_info = state.episode_length * not_done + new_episode_length * done - + timestep.extras["episode_metrics"] = { "episode_return": episode_return_info, "episode_length": episode_length_info, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 69632f1bc..546e05614 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -13,17 +13,21 @@ # limitations under the License. import warnings -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional import gym import numpy as np from numpy.typing import NDArray +from gym.spaces import Box +from gym.vector.utils import write_to_shared_memory +import sys + # Filter out the warnings warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") -class GymRwareWrapper(gym.Wrapper): +class GymRwareWrapper(gym.Wrapper): """Wrapper for rware gym environments""" def __init__( @@ -44,7 +48,7 @@ def __init__( Defaults to False. """ super().__init__(env) - self._env = gym.wrappers.compatibility.EnvCompatibility(env) + self._env = env #not having _env leaded tp self.env getting replaced --> circular called self.use_individual_rewards = use_individual_rewards self.add_global_state = add_global_state # todo : add the global observations self.eval_env = eval_env @@ -52,42 +56,33 @@ def __init__( self.num_actions = self._env.action_space[ 0 ].n # todo: all the agents must have the same num_actions, add assertion? 
- - def reset(self) -> Tuple: - (agents_view, info), _ = self._env.reset( - seed=np.random.randint(1) - ) # todo: assure reproducibility, this only works for rware - - info = {"actions_mask": self._get_actions_mask(info)} + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple: + + if seed is not None: + self.env.seed(seed) + + agents_view, info = self._env.reset() + + info = {"actions_mask": self.get_actions_mask(info)} return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple: - - agents_view, reward, terminated, truncated, info = self.env.step(actions) + def step(self, actions: NDArray) -> Tuple: #Vect auto rest - done = np.logical_or(terminated, truncated).all() + agents_view, reward, terminated, truncated, info = self._env.step(actions) - if ( - done and not self.eval_env - ): # only auto-reset in training envs, same functionality as the AutoResetWrapper. - agents_view, info = self.reset() - reward = np.zeros(self.num_agents) - terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( - self.num_agents, dtype=bool - ) - return agents_view, reward, terminated, truncated, info - - info = {"actions_mask": self._get_actions_mask(info)} + info = {"actions_mask": self.get_actions_mask(info)} if self.use_individual_rewards: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) - return agents_view, reward, terminated, truncated, info - def _get_actions_mask(self, info: Dict) -> NDArray: + def get_actions_mask(self, info: Dict) -> NDArray: if "action_mask" in info: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) @@ -98,51 +93,151 @@ class GymRecordEpisodeMetrics(gym.Wrapper): def __init__(self, env: gym.Env): super().__init__(env) + self._env = env self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0 + self.running_count_episode_length = 0.0 def reset(self) -> Tuple: # Reset the env - agents_view, info = self.env.reset() + agents_view, info = self._env.reset() - # Reset the metrics - self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0 + # Handle the Done when the auto reset happens + done = self.running_count_episode_length != -1 # Avoid setting the first ever done to True # Create the metrics dict metrics = { "episode_return": self.running_count_episode_return, - "episode_length": self.self.running_count_episode_length, - "is_terminal_step": False, + "episode_length": self.running_count_episode_length, + "is_terminal_step": done, } + + # Reset the metrics + self.running_count_episode_return = 0.0 + self.running_count_episode_length = 0 + if "won_episode" in info: metrics["won_episode"] = info["won_episode"] + + info["metrics"] = metrics - return agents_view, metrics + return agents_view, info def step(self, actions: NDArray) -> Tuple: # Step the env - agents_view, reward, terminated, truncated, info = self.env.step(actions) + agents_view, reward, terminated, truncated, info = self._env.step(actions) - # Update the metrics - done = np.logical_or(terminated, truncated).all() - - if not done: - self.running_count_episode_return += float(np.mean(reward)) - self.running_count_episode_length += 1 - - else: - self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0 + self.running_count_episode_return += float(np.mean(reward)) + self.running_count_episode_length += 1 metrics = { "episode_return": self.running_count_episode_return, - "episode_length": 
self.self.running_count_episode_length, - "is_terminal_step": False, + "episode_length": self.running_count_episode_length, + "is_terminal_step": False, # We handle the True case in the reset function since this gets overwritten } if "won_episode" in info: metrics["won_episode"] = info["won_episode"] + + info["metrics"] = metrics + + return agents_view, reward, terminated, truncated, info + +class AgentIDWrapper(gym.Wrapper): + """Add onehot agent IDs to observation.""" + + def __init__(self, env: gym.Env): + super().__init__(env) - return agents_view, reward, terminated, truncated, metrics + self.agent_ids = np.eye(self.env.num_agents) + _obs_low, _obs_high, _obs_dtype, _obs_shape = ( + self.env.observation_space.low[0][0], + self.env.observation_space.high[0][0], + self.env.observation_space.dtype, + self.env.observation_space.shape, + ) + _new_obs_shape = (self.env.num_agents, _obs_shape[1] + self.env.num_agents) + self._observation_space = Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype) + + def reset(self) -> Tuple[np.ndarray, Dict]: + """Reset the environment.""" + obs, info = self.env.reset() + obs = np.concatenate([self.agent_ids, obs], axis=1) + return obs, info + + def step(self, action: list) -> Tuple[np.ndarray, float, bool, bool, Dict]: + """Step the environment.""" + obs, reward, terminated, truncated, info = self.env.step(action) + obs = np.concatenate([self.agent_ids, obs], axis=1) + return obs, reward, terminated, truncated, info + + +def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): + assert shared_memory is not None + env = env_fn() + observation_space = env.observation_space + parent_pipe.close() + try: + while True: + command, data = pipe.recv() + if command == "reset": + observation, info = env.reset(**data) + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + pipe.send(((None, info), True)) + + elif command == "step": + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + if np.logical_or(terminated, truncated).all(): + old_observation, old_info = observation, info + observation, info = env.reset() + info["final_observation"] = old_observation + info["final_info"] = old_info + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + pipe.send(((None, reward, terminated, truncated, info), True)) + elif command == "seed": + env.seed(data) + pipe.send((None, True)) + elif command == "close": + pipe.send((None, True)) + break + elif command == "_call": + name, args, kwargs = data + if name in ["reset", "step", "seed", "close"]: + raise ValueError( + f"Trying to call function `{name}` with " + f"`_call`. Use `{name}` directly instead." + ) + function = getattr(env, name) + if callable(function): + pipe.send((function(*args, **kwargs), True)) + else: + pipe.send((function, True)) + elif command == "_setattr": + name, value = data + setattr(env, name, value) + pipe.send((None, True)) + elif command == "_check_spaces": + pipe.send( + ((data[0] == observation_space, data[1] == env.action_space), True) + ) + else: + raise RuntimeError( + f"Received unknown command `{command}`. Must " + "be one of {`reset`, `step`, `seed`, `close`, `_call`, " + "`_setattr`, `_check_spaces`}." 
+ ) + except (KeyboardInterrupt, Exception): + error_queue.put((index,) + sys.exc_info()[:2]) + pipe.send((None, False)) + finally: + env.close() \ No newline at end of file From 7f43a33b63a63fbab41f4ce5673374ff76d4667f Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 2 Jul 2024 14:47:24 +0100 Subject: [PATCH 022/125] fix: logging and added LBF --- mava/configs/arch/sebulba.yaml | 10 +- mava/configs/env/gym.yaml | 6 +- mava/configs/system/ppo/ff_ippo.yaml | 8 +- mava/systems/sebulba/ppo/ff_ippo.py | 325 ++++++++++++++++----------- mava/systems/sebulba/ppo/test.py | 15 +- mava/utils/make_env.py | 17 +- mava/wrappers/__init__.py | 2 +- mava/wrappers/gym.py | 86 +++++-- 8 files changed, 291 insertions(+), 178 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 02ae56bb3..617e54134 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,13 +1,13 @@ # --- Sebulba config --- arch_name: "sebulba" -num_envs: 64 # number of envs per thread +num_envs: 3 # number of envs per thread # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. -num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. +num_evaluation: 10 # Number of evenly spaced evaluations to perform during training. absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 @@ -16,9 +16,3 @@ n_threads_per_executor: 2 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices -# --- Sebulba rollout and env config --- -concurrency: False # whether actor and learner should run concurrently -async_envs: True # "whether to use async vector or sync vector envs" - -# --- To be defined during training --- -log_frequency: ~ diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym.yaml index 44c9c624a..9ddd16d41 100644 --- a/mava/configs/env/gym.yaml +++ b/mava/configs/env/gym.yaml @@ -1,8 +1,8 @@ # ---Environment Configs--- -scenario: rware:rware-tiny-2ag-v1 # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] +scenario: rware:rware-tiny-4ag-v1 #Foraging-8x8-2p-1f-v2 #rware:rware-tiny-2ag-v1 # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] -env_name: RobotWarehouse # Used for logging purposes. +env_name: RobotWarehouse #LevelBasedForaging # Used for logging purposes. # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. @@ -10,7 +10,7 @@ eval_metric: episode_return # Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. # This should not be changed. -implicit_agent_id: True +implicit_agent_id: False # Whether or not to log the winrate of this environment. This should not be changed as not all # environments have a winrate metric. 
log_win_rate: False diff --git a/mava/configs/system/ppo/ff_ippo.yaml b/mava/configs/system/ppo/ff_ippo.yaml index b8d0573b4..0c93c2683 100644 --- a/mava/configs/system/ppo/ff_ippo.yaml +++ b/mava/configs/system/ppo/ff_ippo.yaml @@ -1,16 +1,16 @@ # --- Defaults FF-IPPO --- -total_timesteps: 20_000_000 # Set the total environment steps. +total_timesteps: ~ # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. -num_updates: 1000 # Number of updates +num_updates: 12 # Number of updates seed: 42 # --- Agent observations --- add_agent_id: True # --- RL hyperparameters --- -actor_lr: 2.5e-4 # Learning rate for actor network -critic_lr: 2.5e-4 # Learning rate for critic network +actor_lr: 1.0e-3 # Learning rate for actor network +critic_lr: 1.0e-3 # Learning rate for critic network update_batch_size: 2 # Number of vectorised gradient updates per device. rollout_length: 128 # Number of environment steps per vectorised environment. ppo_epochs: 4 # Number of ppo epochs per training data batch. diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 229e268d0..5df32bf5d 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -68,7 +68,7 @@ def rollout( len_executor_device_ids = len(config.arch.executor_device_ids) current_actor_device = jax.devices()[actor_device_id] t_env = 0 - start_time = time.time() + actor_apply_fn, critic_apply_fn = apply_fns @@ -98,9 +98,9 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: # Define queues to track time - params_queue_get_time: deque = deque(maxlen=10) - rollout_time: deque = deque(maxlen=10) - rollout_queue_put_time: deque = deque(maxlen=10) + params_queue_get_time: deque = deque(maxlen=1) + rollout_time: deque = deque(maxlen=1) + rollout_queue_put_time: deque = deque(maxlen=1) next_obs , info = env.reset() #todo : the first info is discarded , is that a problem? 
next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) @@ -108,70 +108,77 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: move_to_device = lambda x : jax.device_put(x, device = current_actor_device) # Loop till the learner has finished training - for eval_step in range(config.arch.num_evaluation): - for update in range(1, config.system.num_updates_per_eval + 2): - # Setup - env_recv_time: float = 0 - inference_time: float = 0 - storage_time: float = 0 - env_send_time: float = 0 + for update in range(config.system.num_updates): + print(update) + # Setup todo: double check tracking times + inference_time: float = 0 + storage_time: float = 0 + env_send_time: float = 0 + setup = 0 - # Get the latest parameters from the learner - params_queue_get_time_start = time.time() - params = params_queue.get() - params_queue_get_time.append(time.time() - params_queue_get_time_start) + # Get the latest parameters from the learner + params_queue_get_time_start = time.time() + params = params_queue.get() + params_queue_get_time.append(time.time() - params_queue_get_time_start) + + # Rollout + rollout_time_start = time.time() + storage: List = [] + + # Loop over the rollout length + for _ in range(0, config.system.rollout_length): - # Rollout - rollout_time_start = time.time() - storage: List = [] - # Loop over the rollout length - for _ in range(0, config.system.rollout_length): - # Cached for transition - cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) - cached_next_dones = move_to_device(next_dones) - cashed_action_mask = move_to_device(jnp.stack([*info["actions_mask"]], axis = 0) ) #unpack the numpy object, find a more pythonic way? - - # Increment current timestep - t_env += ( - config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs - ) - - # Get action and value - inference_time_start = time.time() - # - ( - action, - log_prob, - value, - key, - ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) - inference_time += time.time() - inference_time_start - - # Step the environment - env_send_time_start = time.time() - cpu_action = jax.device_get(action) - next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) #num_env, num_agents --> num_agents, num_env - next_dones = np.logical_or(terminated, truncated) - - metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics (N_envs , N_metrics) -- > (N_metrics, N_envs) - # Append data to storage - env_send_time += time.time() - env_send_time_start - storage_time_start = time.time() - storage.append( - PPOTransition( - done=cached_next_dones, - action=action, - value=value, - reward=next_reward, - log_prob=log_prob, - obs=Observation(cached_next_obs, cashed_action_mask), - info=metrics, - )#todo: use a threadsafe alt https://github.com/instadeepai/CityLearn/blob/27e69f8ebdf1789c55ffab5c326bfaa50733a5e7/power_systems/sax_sebulba.py#L39 - ) - storage_time += time.time() - storage_time_start + # Cached for transition + cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) + cached_next_dones = move_to_device(next_dones) + setup_start = time.time() + cashed_action_mask = move_to_device(np.stack(info["actions_mask"]) ) + setup += time.time() - setup_start + # Increment current timestep + t_env += ( + config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs + ) + + # Get action and value + inference_time_start = time.time() + # + ( + 
action, + log_prob, + value, + key, + ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) + inference_time += time.time() - inference_time_start + + # Step the environment + env_send_time_start = time.time() + cpu_action = jax.device_get(action) - rollout_time.append(time.time() - rollout_time_start) - + next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) #num_env, num_agents --> num_agents, num_env + env_send_time += time.time() - env_send_time_start + + + storage_time_start = time.time() + # Prepare the data + next_dones = np.logical_or(terminated, truncated) + metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics (N_envs , N_metrics) -- > (N_metrics, N_envs) + + # Append data to storage + storage.append( + PPOTransition( + done=cached_next_dones, + action=action, + value=value, + reward=next_reward, + log_prob=log_prob, + obs=Observation(cached_next_obs, cashed_action_mask), + info=metrics, + )#todo: use a threadsafe alt https://github.com/instadeepai/CityLearn/blob/27e69f8ebdf1789c55ffab5c326bfaa50733a5e7/power_systems/sax_sebulba.py#L39 + ) + storage_time += time.time() - storage_time_start + rollout_time.append(time.time() - rollout_time_start) + + parse_timer = time.time() # Prepare data to share with learner # todo: investigate te thread --> single learning partitioned_storage = prepare_data(storage) @@ -184,15 +191,27 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: sharded_next_action_mask = shard_split_payload(jnp.stack([*info["actions_mask"]], axis = 0), 0) sharded_next_done = shard_split_payload(next_dones, 0) + + speed_info = { + "rollout_time": np.mean(rollout_time), + "params_queue_get_time": np.mean(params_queue_get_time), + "action_inference": inference_time, + "storage_time": storage_time, + "env_step_time": env_send_time, + "rollout_queue_put_time": np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0, + "parse_time" : time.time() - parse_timer, + "setup_time" : setup, + } + #print(speed_info) + payload = ( t_env, sharded_storage, sharded_next_obs, sharded_next_done, - sharded_next_action_mask, - np.mean(params_queue_get_time), + sharded_next_action_mask ) - + # Put data in the rollout queue to share it with the learner rollout_queue_put_time_start = time.time() rollout_queue.put(payload) @@ -210,7 +229,7 @@ def get_learner_fn( actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any, traj_batch : PPOTransition, last_obs: chex.Array, last_action_mask : chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: + def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_action_mask : chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -340,7 +359,7 @@ def _critic_loss_fn( # available at https://tinyurl.com/26tdzs5x # pmean over devices. actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="device" + (actor_grads, actor_loss_info), axis_name="device" #todo: pmean over learner devices not all ) # pmean over devices. 
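[Editor's aside, not part of the patches] The hunk above ends inside the PPO update: _actor_loss_fn produces actor gradients that are then averaged across devices with jax.lax.pmean (the todo notes this should eventually be restricted to the learner devices). The patch does not show the body of _actor_loss_fn here, so the following is only a textbook sketch of the clipped PPO surrogate such a loss typically computes, assuming standardised GAE advantages; all function and argument names are illustrative, and only clip_eps corresponds to a config field shown earlier (ff_ippo.yaml).

    import jax.numpy as jnp

    def clipped_ppo_actor_loss(log_prob, old_log_prob, gae, clip_eps=0.2):
        """Textbook PPO clipped surrogate; arrays are shaped (batch, num_agents)."""
        ratio = jnp.exp(log_prob - old_log_prob)        # importance ratio new/old policy
        gae = (gae - gae.mean()) / (gae.std() + 1e-8)   # standardise advantages
        unclipped = ratio * gae
        clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
        # Maximise the clipped surrogate, i.e. minimise its negation.
        return -jnp.minimum(unclipped, clipped).mean()

The gradients of a loss like this are what the pmean call in the hunk above synchronises before the optimiser step; a learner-devices-only pmean, as the todo suggests, would simply use an axis name mapped over the learner devices instead of all devices.
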
@@ -406,7 +425,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_states, key, None, None) - metric = traj_batch.info #todo: metrci calcualtions + metric = traj_batch.info return learner_state, (metric, loss_info) def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_action_mask : chex.Array, last_dones : chex.Array) -> ExperimentOutput[LearnerState]: @@ -424,12 +443,9 @@ def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs - env_state (LogEnvState): The environment state. - timesteps (TimeStep): The initial timestep in the initial trajectory. """ - # Broadcast static parameters for scan - partial_update_step = lambda learner_state, xs : _update_step(learner_state, xs, traj_batch , last_obs, last_action_mask, last_dones) + - learner_state, (episode_info, loss_info) = jax.lax.scan( - partial_update_step, learner_state, None, config.system.num_updates_per_eval - ) + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_action_mask, last_dones) return ExperimentOutput( learner_state=learner_state, @@ -534,15 +550,13 @@ def run_experiment(_config: DictConfig) -> float: """Runs experiment.""" config = copy.deepcopy(_config) - devices = jax.devices() # todo: use local devices insted? + devices = jax.devices() learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] # PRNG keys. key, key_e, actor_net_key, critic_net_key = jax.random.split( jax.random.PRNGKey(config.system.seed), num=4 ) - - learner_keys = jax.device_put_replicated(key, learner_devices) # Sanity check of config assert ( @@ -624,77 +638,94 @@ def run_experiment(_config: DictConfig) -> float: learner_devices, d_id, ), - ).start() + ).start() #todo : this is techinically only multu threaded not multi processepr? # Run experiment for the total number of updates. 
- rollout_queue_get_time: deque = deque(maxlen=10) - data_transfer_time: deque = deque(maxlen=10) - trainer_update_number = 0 max_episode_return = jnp.float32(0.0) best_params = None - for eval_step in range(config.arch.num_evaluation): #todo : place holder - trainer_update_number += 1 - start_time = time.time() - sharded_storages = [] - sharded_next_obss = [] - sharded_next_dones = [] - sharded_next_action_masks = [] - - # Loop through each executor device - for d_idx, _ in enumerate(config.arch.executor_device_ids): - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - # Get data from rollout queue - ( - t_env, - sharded_storage, - sharded_next_obs, - sharded_next_done, - sharded_next_action_mask, - avg_params_queue_get_time, - ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() - sharded_storages.append(sharded_storage) - sharded_next_obss.append(sharded_next_obs) - sharded_next_dones.append(sharded_next_done) - sharded_next_action_masks.append(sharded_next_action_mask) - rollout_queue_get_time.append(time.time() - start_time) - training_time_start = time.time() + for eval_step in range(config.arch.num_evaluation): + training_start_time = time.time() + learner_speeds = [] + rollout_times = [] - #Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) #todo: check if this breaks the explicet array device placment - sharded_next_obss = jnp.concatenate(sharded_next_obss, axis = 1) - sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) - sharded_next_action_masks = jnp.concatenate(sharded_next_action_masks, axis = 1) - - learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_action_masks, sharded_next_dones) + episode_metrics = [] + train_metrics = [] - # Send updated params to executors - unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) - for d_idx, d_id in enumerate(config.arch.executor_device_ids): - device_params = jax.device_put(unreplicated_params, devices[d_id]) - for thread_id in range(config.arch.n_threads_per_executor): - params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( - device_params - ) - + for update in range(config.system.num_updates_per_eval): + sharded_storages = [] + sharded_next_obss = [] + sharded_next_dones = [] + sharded_next_action_masks = [] + + rollout_start_time = time.time() + # Loop through each executor device + for d_idx, _ in enumerate(config.arch.executor_device_ids): + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + # Get data from rollout queue + ( + t_env, + sharded_storage, + sharded_next_obs, + sharded_next_done, + sharded_next_action_mask + ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() + sharded_storages.append(sharded_storage) + sharded_next_obss.append(sharded_next_obs) + sharded_next_dones.append(sharded_next_done) + sharded_next_action_masks.append(sharded_next_action_mask) + + rollout_times.append(time.time() - rollout_start_time) + + + # Concatinate the returned trajectories on the n_env axis + sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) #todo: check if this breaks the explicet array device placment + sharded_next_obss = jnp.concatenate(sharded_next_obss, axis = 1) + sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) + 
sharded_next_action_masks = jnp.concatenate(sharded_next_action_masks, axis = 1) + + + learner_start_time = time.time() + learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_action_masks, sharded_next_dones) + learner_speeds.append(time.time() - learner_start_time) + + # Stack the metrics + episode_metrics.append(learner_output.episode_metrics) + train_metrics.append(learner_output.train_metrics) + + # Send updated params to executors + unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) + for d_idx, d_id in enumerate(config.arch.executor_device_ids): + device_params = jax.device_put(unreplicated_params, devices[d_id]) + for thread_id in range(config.arch.n_threads_per_executor): + params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( + device_params + ) + + + # Log the results of the training. - elapsed_time = time.time() - start_time + elapsed_time = time.time() - training_start_time t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) # todo: these shapes are not as expected - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + episode_metrics = jax.tree_map(lambda *x : np.asarray(x), *episode_metrics) + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time # Separately log timesteps, actoring metrics and training metrics. - logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + speed_info = {"total_time" : elapsed_time, "rollout_time" : np.sum(rollout_times), "learner_time" : np.sum(learner_speeds), "timestep" : t} + logger.log(speed_info , t, eval_step, LogEvent.MISC) if ep_completed: # only log episode metrics if an episode was completed in the rollout. - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + train_metrics = jax.tree_map(lambda *x : np.asarray(x), *train_metrics) + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) # Evaluation on the learner + evaluation_start_timer = time.time() key_e, eval_key = jax.random.split(key_e, 2) episode_metrics = evaluator(unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1 ), eval_key) # Log the results of the evaluation. - elapsed_time = time.time() - start_time + elapsed_time = time.time() - evaluation_start_timer episode_return = jnp.mean(episode_metrics["episode_return"]) steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) @@ -706,8 +737,32 @@ def run_experiment(_config: DictConfig) -> float: if config.arch.absolute_metric and max_episode_return <= episode_return: best_params = copy.deepcopy(learner_output.learner_state.params) max_episode_return = episode_return - #todo: abs metric - return None#eval_performance + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Record the performance for the final evaluation run. + eval_performance = float(jnp.mean(episode_metrics[config.env.eval_metric])) + + # Measure absolute metric. 
+ if config.arch.absolute_metric: + start_time = time.time() + + key_e, eval_key = jax.random.split(key_e, 2) + episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params.actor_params, 1), eval_key) + + elapsed_time = time.time() - start_time + steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) + + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. + logger.stop() + + return eval_performance + @hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") diff --git a/mava/systems/sebulba/ppo/test.py b/mava/systems/sebulba/ppo/test.py index 5e45544f1..d1f34fccf 100644 --- a/mava/systems/sebulba/ppo/test.py +++ b/mava/systems/sebulba/ppo/test.py @@ -39,7 +39,8 @@ from flax import linen as nn import gym import rware -from mava.wrappers import GymRwareWrapper, GymRecordEpisodeMetrics, _multiagent_worker_shared_memory +import lbforaging +from mava.wrappers import GymRwareWrapper, GymRecordEpisodeMetrics, _multiagent_worker_shared_memory, GymAgentIDWrapper, GymLBFWrapper @hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" @@ -49,7 +50,8 @@ def hydra_entry_point(cfg: DictConfig) -> float: OmegaConf.set_struct(cfg, False) def f(): base = gym.make(cfg.env.scenario) - base = GymRwareWrapper(base, cfg.env.use_individual_rewards, False, True) + base = GymLBFWrapper(base, cfg.env.use_individual_rewards, True) + base = GymAgentIDWrapper(base) return GymRecordEpisodeMetrics(base) base = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names @@ -62,13 +64,14 @@ def f(): base.reset() n = 0 done = False + r = [0] * 3 while not done: n+= 1 - agents_view, reward, terminated, truncated, info = base.step([[0,0,0], [0,0,0]]) + agents_view, reward, terminated, truncated, info = base.step([r, r]) + print(terminated, truncated) done = np.logical_or(terminated, truncated).all() - metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) - print(n, done, terminated, np.logical_or(terminated, truncated).shape, metrics) - done = True + print(n, done) + #metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) base.close() print(done) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index cab649880..c23e40820 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -22,6 +22,7 @@ import jumanji import matrax from gigastep import ScenarioBuilder +import lbforaging from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment from jumanji.environments.routing.cleaner.generator import ( @@ -46,7 +47,9 @@ GigastepWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, + GymAgentIDWrapper, _multiagent_worker_shared_memory, + GymLBFWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, @@ -70,7 +73,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"rware": GymRwareWrapper} +_gym_registry = {"RobotWarehouse": GymRwareWrapper, "LevelBasedForaging" : GymLBFWrapper} def add_extra_wrappers( @@ -209,7 +212,7 @@ def make_gigastep_env( def make_gym_env( - config: DictConfig, num_env : int, add_global_state: bool = False, eval_env: bool = False + config: DictConfig, num_env : int, add_global_state: bool = False, ) -> Environment: # todo : create the appropriate annotation for the 
sync vector """ Create a Gym environment. @@ -222,23 +225,23 @@ def make_gym_env( Returns: A tuple of the environments. """ - base_env_name = config.env.scenario.split(":")[0] + base_env_name = config.env.env_name wrapper = _gym_registry[base_env_name] def create_gym_env( - config: DictConfig, add_global_state: bool = False, eval_env: bool = False + config: DictConfig, add_global_state: bool = False ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. env = gym.make(config.env.scenario) - wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) + wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state) if not config.env.implicit_agent_id: - wrapped_env = AgentIDWrapper(wrapped_env) # todo : add agent id wrapper for gym . + wrapped_env = GymAgentIDWrapper(wrapped_env) # todo : add agent id wrapper for gym . wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env num_env = config.arch.num_envs envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names [ - lambda: create_gym_env(config, add_global_state, eval_env=eval_env) + lambda: create_gym_env(config, add_global_state) for _ in range(num_env) ], worker=_multiagent_worker_shared_memory diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 3608b1d10..64a5affec 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, _multiagent_worker_shared_memory +from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, GymLBFWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 546e05614..31146e29a 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -19,7 +19,7 @@ import numpy as np from numpy.typing import NDArray -from gym.spaces import Box +from gym import spaces from gym.vector.utils import write_to_shared_memory import sys @@ -51,7 +51,6 @@ def __init__( self._env = env #not having _env leaded tp self.env getting replaced --> circular called self.use_individual_rewards = use_individual_rewards self.add_global_state = add_global_state # todo : add the global observations - self.eval_env = eval_env self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[ 0 @@ -88,6 +87,66 @@ def get_actions_mask(self, info: Dict) -> NDArray: return np.ones((self.num_agents, self.num_actions), dtype=np.float32) +class GymLBFWrapper(gym.Wrapper): + """Wrapper for rware gym environments""" + + def __init__( + self, + env: gym.Env, + use_individual_rewards: bool = False, + add_global_state: bool = False, + ): + """Initialize the gym wrapper + + Args: + env (gym.env): gym env instance. + use_individual_rewards (bool, optional): Use individual or group rewards. + Defaults to False. + add_global_state (bool, optional) : Create global observations. Defaults to False. 
+ """ + super().__init__(env) + self._env = env #not having _env leaded tp self.env getting replaced --> circular called + self.use_individual_rewards = use_individual_rewards + self.add_global_state = add_global_state # todo : add the global observations + self.num_agents = len(self._env.action_space) + self.num_actions = self._env.action_space[ + 0 + ].n # todo: all the agents must have the same num_actions, add assertion? + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple: + + if seed is not None: + self.env.seed(seed) + + agents_view, info = self._env.reset() + + info = {"actions_mask": self.get_actions_mask(info)} + + return np.array(agents_view), info + + def step(self, actions: NDArray) -> Tuple: #Vect auto rest + + agents_view, reward, terminated, truncated, info = self._env.step(actions) + + info = {"actions_mask": self.get_actions_mask(info)} + + if self.use_individual_rewards: + reward = np.array(reward) + else: + reward = np.array([np.array(reward).mean()] * self.num_agents) + + + truncated = [truncated] * self.num_agents + + return agents_view, reward, terminated, truncated, info + + def get_actions_mask(self, info: Dict) -> NDArray: + if "action_mask" in info: + return np.array(info["action_mask"]) + return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + class GymRecordEpisodeMetrics(gym.Wrapper): """Record the episode returns and lengths.""" @@ -102,14 +161,11 @@ def reset(self) -> Tuple: # Reset the env agents_view, info = self._env.reset() - # Handle the Done when the auto reset happens - done = self.running_count_episode_length != -1 # Avoid setting the first ever done to True - # Create the metrics dict metrics = { "episode_return": self.running_count_episode_return, "episode_length": self.running_count_episode_length, - "is_terminal_step": done, + "is_terminal_step": True, } # Reset the metrics @@ -140,24 +196,26 @@ def step(self, actions: NDArray) -> Tuple: metrics["won_episode"] = info["won_episode"] info["metrics"] = metrics - + return agents_view, reward, terminated, truncated, info -class AgentIDWrapper(gym.Wrapper): +class GymAgentIDWrapper(gym.Wrapper): """Add onehot agent IDs to observation.""" def __init__(self, env: gym.Env): super().__init__(env) self.agent_ids = np.eye(self.env.num_agents) + observation_space = self.env.observation_space[0] _obs_low, _obs_high, _obs_dtype, _obs_shape = ( - self.env.observation_space.low[0][0], - self.env.observation_space.high[0][0], - self.env.observation_space.dtype, - self.env.observation_space.shape, + observation_space.low[0], + observation_space.high[0], + observation_space.dtype, + observation_space.shape, ) - _new_obs_shape = (self.env.num_agents, _obs_shape[1] + self.env.num_agents) - self._observation_space = Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype) + _new_obs_shape = (_obs_shape[0] + self.env.num_agents,) + _observation_boxs = [spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype)] * self.env.num_agents + self.observation_space = spaces.Tuple(_observation_boxs) def reset(self) -> Tuple[np.ndarray, Dict]: """Reset the environment.""" From 8a872587571b88da959aaea86802645cde827bfc Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 4 Jul 2024 10:02:43 +0100 Subject: [PATCH 023/125] fix: batch size calc for multiple devices --- mava/systems/sebulba/ppo/ff_ippo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py 
b/mava/systems/sebulba/ppo/ff_ippo.py index 5df32bf5d..7ff158536 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -398,7 +398,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) # SHUFFLE MINIBATCHES - batch_size = config.system.rollout_length * config.arch.num_envs * len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor + batch_size = config.system.rollout_length * (config.arch.num_envs // len(config.arch.learner_device_ids)) * len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) From 7f0acd9eb878a54f0c8a0af9c450d3543bebf911 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 5 Jul 2024 11:16:06 +0100 Subject: [PATCH 024/125] fix: num_updates and code refactoring --- mava/systems/sebulba/ppo/ff_ippo.py | 47 ++++++++++++----------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 7ff158536..8998de5f3 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -89,14 +89,6 @@ def get_action_and_value( value = critic_apply_fn(params.critic_params, observation).squeeze() return action, log_prob, value, key - @jax.jit - def prepare_data(storage: List[PPOTransition]) -> PPOTransition: - """Prepare data to share with learner.""" - return jax.tree_map( # type: ignore - lambda *xs : jnp.stack(xs), *storage - ) - - # Define queues to track time params_queue_get_time: deque = deque(maxlen=1) rollout_time: deque = deque(maxlen=1) @@ -109,12 +101,9 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: # Loop till the learner has finished training for update in range(config.system.num_updates): - print(update) - # Setup todo: double check tracking times inference_time: float = 0 storage_time: float = 0 env_send_time: float = 0 - setup = 0 # Get the latest parameters from the learner params_queue_get_time_start = time.time() @@ -131,9 +120,8 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: # Cached for transition cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) cached_next_dones = move_to_device(next_dones) - setup_start = time.time() cashed_action_mask = move_to_device(np.stack(info["actions_mask"]) ) - setup += time.time() - setup_start + # Increment current timestep t_env += ( config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs @@ -141,15 +129,14 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: # Get action and value inference_time_start = time.time() - # ( action, log_prob, value, key, ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) - inference_time += time.time() - inference_time_start + inference_time += time.time() - inference_time_start # Step the environment env_send_time_start = time.time() cpu_action = jax.device_get(action) @@ -161,7 +148,7 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: storage_time_start = time.time() # Prepare the data next_dones = np.logical_or(terminated, truncated) - metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics (N_envs , N_metrics) -- > (N_metrics, N_envs) + metrics = 
jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics # Append data to storage storage.append( @@ -173,22 +160,23 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: log_prob=log_prob, obs=Observation(cached_next_obs, cashed_action_mask), info=metrics, - )#todo: use a threadsafe alt https://github.com/instadeepai/CityLearn/blob/27e69f8ebdf1789c55ffab5c326bfaa50733a5e7/power_systems/sax_sebulba.py#L39 + ) ) storage_time += time.time() - storage_time_start rollout_time.append(time.time() - rollout_time_start) parse_timer = time.time() + # Prepare data to share with learner - # todo: investigate te thread --> single learning - partitioned_storage = prepare_data(storage) + stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) + #sorage has shape rollout_len, num_agents, num_envs, .... while the other vectors have num_agents, num_envs, ... -> their split axis is diffrent shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) - sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , partitioned_storage) + sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) - sharded_next_action_mask = shard_split_payload(jnp.stack([*info["actions_mask"]], axis = 0), 0) + sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) sharded_next_done = shard_split_payload(next_dones, 0) @@ -200,7 +188,6 @@ def prepare_data(storage: List[PPOTransition]) -> PPOTransition: "env_step_time": env_send_time, "rollout_queue_put_time": np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0, "parse_time" : time.time() - parse_timer, - "setup_time" : setup, } #print(speed_info) @@ -581,13 +568,14 @@ def run_experiment(_config: DictConfig) -> float: evaluator, absolute_metric_evaluator = make_eval_fns(environments.make_gym_env, apply_fns[0], config) #todo: make this more generic # Calculate total timesteps. - config = sebulba_check_total_timesteps(config) #todo: update this for sebulba + config = sebulba_check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." # Calculate number of updates per evaluation. - config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + config.system.num_updates_per_eval, remaining_updates = divmod(config.system.num_updates , config.arch.num_evaluation) + config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation if the num_updates is not a multiple of num_evaluation steps_per_rollout = ( len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor @@ -638,8 +626,9 @@ def run_experiment(_config: DictConfig) -> float: learner_devices, d_id, ), - ).start() #todo : this is techinically only multu threaded not multi processepr? - + ).start() #todo : Use a process insted of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) + + # Run experiment for the total number of updates. 
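# Minimal sketch of the params/rollout queue handshake that the threads launched above
# rely on, with strings standing in for the real parameter and trajectory pytrees and a
# hypothetical actor_thread helper; bounded queues (maxsize=1) keep actor and learner in
# lock-step.
import queue
import threading

params_q: queue.Queue = queue.Queue(maxsize=1)
rollout_q: queue.Queue = queue.Queue(maxsize=1)

def actor_thread(n_updates: int) -> None:
    for _ in range(n_updates):
        params = params_q.get()                # block until the learner publishes params
        trajectory = f"rollout with {params}"  # stand-in for stepping the envs
        rollout_q.put(trajectory)              # block until the learner consumes it

threading.Thread(target=actor_thread, args=(3,), daemon=True).start()

params_q.put("params-0")                       # seed the first rollout
for update in range(3):
    batch = rollout_q.get()                    # wait for the actor's payload
    new_params = f"params-{update + 1}"        # stand-in for one learn step
    if update < 2:
        params_q.put(new_params)               # hand fresh params back to the actor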
max_episode_return = jnp.float32(0.0) best_params = None @@ -651,7 +640,9 @@ def run_experiment(_config: DictConfig) -> float: episode_metrics = [] train_metrics = [] - for update in range(config.system.num_updates_per_eval): + # Make sure that the + num_updates_in_eval = config.system.num_updates_per_eva if eval_step != config.arch.num_evaluation - 1 else remaining_updates + for update in range(num_updates_in_eval): sharded_storages = [] sharded_next_obss = [] sharded_next_dones = [] @@ -679,7 +670,7 @@ def run_experiment(_config: DictConfig) -> float: # Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) #todo: check if this breaks the explicet array device placment + sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) sharded_next_obss = jnp.concatenate(sharded_next_obss, axis = 1) sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) sharded_next_action_masks = jnp.concatenate(sharded_next_action_masks, axis = 1) From 3e352cffc37db558ec4e324a4afe6e56dd6fa1c8 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 8 Jul 2024 11:41:15 +0100 Subject: [PATCH 025/125] chore : code cleanup + comments + added checkpoint save --- mava/systems/sebulba/ppo/ff_ippo.py | 71 ++++++++++++----------------- mava/systems/sebulba/ppo/types.py | 1 + 2 files changed, 31 insertions(+), 41 deletions(-) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 8998de5f3..f2168cf63 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -32,7 +32,7 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_sebulba_eval_fns as make_eval_fns #todo: make a standered eval function +from mava.evaluator import make_sebulba_eval_fns as make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.sebulba.ppo.types import LearnerState, OptStates, Params, PPOTransition, Observation #todo: change this Observation to use the origial one @@ -55,21 +55,13 @@ def rollout( config: DictConfig, rollout_queue: queue.Queue, params_queue: queue.Queue, - device_thread_id: int, apply_fns: Tuple, - logger: MavaLogger, learner_devices: List, actor_device_id : int): - - #create envs - env = environments.make_gym_env(config, config.arch.num_envs) - + #setup - len_executor_device_ids = len(config.arch.executor_device_ids) + env = environments.make_gym_env(config, config.arch.num_envs) current_actor_device = jax.devices()[actor_device_id] - t_env = 0 - - actor_apply_fn, critic_apply_fn = apply_fns # Define the util functions: select action function and prepare data to share it with learner. @@ -94,7 +86,7 @@ def get_action_and_value( rollout_time: deque = deque(maxlen=1) rollout_queue_put_time: deque = deque(maxlen=1) - next_obs , info = env.reset() #todo : the first info is discarded , is that a problem? 
+ next_obs , info = env.reset() next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) move_to_device = lambda x : jax.device_put(x, device = current_actor_device) @@ -118,14 +110,9 @@ def get_action_and_value( for _ in range(0, config.system.rollout_length): # Cached for transition - cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) - cached_next_dones = move_to_device(next_dones) - cashed_action_mask = move_to_device(np.stack(info["actions_mask"]) ) - - # Increment current timestep - t_env += ( - config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs - ) + cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) # (num_envs, num_agents, ...) + cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) + cashed_action_mask = move_to_device(np.stack(info["actions_mask"])) # (num_envs, num_agents, num_actions) # Get action and value inference_time_start = time.time() @@ -136,17 +123,16 @@ def get_action_and_value( key, ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) - inference_time += time.time() - inference_time_start + # Step the environment + inference_time += time.time() - inference_time_start env_send_time_start = time.time() cpu_action = jax.device_get(action) - - next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) #num_env, num_agents --> num_agents, num_env + next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) # (num_env, num_agents) --> (num_agents, num_env) env_send_time += time.time() - env_send_time_start - - storage_time_start = time.time() # Prepare the data + storage_time_start = time.time() next_dones = np.logical_or(terminated, truncated) metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics @@ -168,18 +154,21 @@ def get_action_and_value( parse_timer = time.time() # Prepare data to share with learner - stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) + #[PPOTransition() * rollout_len] --> PPOTransition[done = (rollout_len, num_envs, num_agents), action = (rollout_len, num_envs, num_agents, num_actions), ...] + stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) - #sorage has shape rollout_len, num_agents, num_envs, .... while the other vectors have num_agents, num_envs, ... -> their split axis is diffrent + + # Split the arrays over the different learner_devices on the num_envs axis shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) - sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) + sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) - sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) + # (num_learner_devices, num_envs, num_agents, ...) 
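# Sketch of the shard_split_payload pattern used here: split an array on its env axis and
# place one shard per learner device, which adds a leading device axis for the pmapped
# learner. Shapes and the single-device setup are illustrative assumptions.
import jax
import jax.numpy as jnp

learner_devices = jax.local_devices()[:1]   # one device is enough to run the sketch
x = jnp.arange(4 * 3).reshape(4, 3)         # (num_envs=4, num_agents=3)

shards = jnp.split(x, len(learner_devices), axis=0)
sharded = jax.device_put_sharded(shards, devices=learner_devices)
print(sharded.shape)  # (len(learner_devices), num_envs // len(learner_devices), num_agents)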
+ sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) sharded_next_done = shard_split_payload(next_dones, 0) - + # For debugging speed_info = { "rollout_time": np.mean(rollout_time), "params_queue_get_time": np.mean(params_queue_get_time), @@ -192,7 +181,6 @@ def get_action_and_value( #print(speed_info) payload = ( - t_env, sharded_storage, sharded_next_obs, sharded_next_done, @@ -447,8 +435,6 @@ def learner_setup( keys: chex.Array, config: DictConfig, learner_devices: List ) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" - # Get available TPU cores. - n_devices = len(learner_devices) #create temporory envoirnments. env = environments.make_gym_env(config, config.arch.num_envs) @@ -502,7 +488,7 @@ def learner_setup( apply_fns = (actor_network.apply, critic_network.apply) update_fns = (actor_optim.update, critic_optim.update) - # Get batched iterated update and replicate it to pmap it over cores. + # Get batched iterated update and replicate it to pmap it over learner cores. learn = get_learner_fn(apply_fns, update_fns, config) learn = jax.pmap(learn, axis_name="device", devices = learner_devices) @@ -575,7 +561,7 @@ def run_experiment(_config: DictConfig) -> float: # Calculate number of updates per evaluation. config.system.num_updates_per_eval, remaining_updates = divmod(config.system.num_updates , config.arch.num_evaluation) - config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation if the num_updates is not a multiple of num_evaluation + config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation step if the num_updates is not a multiple of num_evaluation steps_per_rollout = ( len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor @@ -620,13 +606,11 @@ def run_experiment(_config: DictConfig) -> float: config, rollout_queues[-1], params_queues[-1], - d_idx * config.arch.n_threads_per_executor + thread_id, apply_fns, - logger, learner_devices, d_id, ), - ).start() #todo : Use a process insted of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) + ).start() #todo : Use a process instead of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) # Run experiment for the total number of updates. 
@@ -641,7 +625,7 @@ def run_experiment(_config: DictConfig) -> float: train_metrics = [] # Make sure that the - num_updates_in_eval = config.system.num_updates_per_eva if eval_step != config.arch.num_evaluation - 1 else remaining_updates + num_updates_in_eval = config.system.num_updates_per_eval if eval_step != config.arch.num_evaluation - 1 else remaining_updates for update in range(num_updates_in_eval): sharded_storages = [] sharded_next_obss = [] @@ -655,7 +639,6 @@ def run_experiment(_config: DictConfig) -> float: for thread_id in range(config.arch.n_threads_per_executor): # Get data from rollout queue ( - t_env, sharded_storage, sharded_next_obs, sharded_next_done, @@ -723,7 +706,13 @@ def run_experiment(_config: DictConfig) -> float: episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) - #todo: add saving + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state, 1), + episode_return=episode_return, + ) if config.arch.absolute_metric and max_episode_return <= episode_return: best_params = copy.deepcopy(learner_output.learner_state.params) diff --git a/mava/systems/sebulba/ppo/types.py b/mava/systems/sebulba/ppo/types.py index 6e02aa904..c27dcace5 100644 --- a/mava/systems/sebulba/ppo/types.py +++ b/mava/systems/sebulba/ppo/types.py @@ -88,6 +88,7 @@ class RNNPPOTransition(NamedTuple): log_prob: chex.Array obs: chex.Array hstates: HiddenStates + info: Dict class Observation(NamedTuple): From bcdaa381096b8c843127b051020af8c99d139c52 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 8 Jul 2024 14:53:56 +0100 Subject: [PATCH 026/125] feat: mappo + removed sebulba specifique types and made the rware wrapper generic --- mava/evaluator.py | 8 +- mava/systems/sebulba/ppo/ff_ippo.py | 28 +- mava/systems/sebulba/ppo/ff_mappo.py | 768 +++++++++++++++++++++++++++ mava/types.py | 6 +- mava/utils/make_env.py | 6 +- mava/wrappers/__init__.py | 2 +- mava/wrappers/gym.py | 80 +-- 7 files changed, 807 insertions(+), 91 deletions(-) create mode 100644 mava/systems/sebulba/ppo/ff_mappo.py diff --git a/mava/evaluator.py b/mava/evaluator.py index 066890ed9..f44a8d55b 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -31,7 +31,8 @@ RNNEvalState, ) -from mava.systems.sebulba.ppo.types import Observation +from mava.types import Observation + import numpy as np def get_anakin_ff_evaluator_fn( @@ -383,7 +384,7 @@ def eval_episodes(params: FrozenDict, key : chex.PRNGKey) -> Dict: key, policy_key = jax.random.split(key) obs = jax.device_put(jnp.stack(obs, axis = 1)) - action_mask = jax.device_put(jnp.stack([*info["actions_mask"]], axis = 0)) + action_mask = jax.device_put(np.stack(info["actions_mask"]) ) actions = get_action(params, Observation(obs, action_mask), policy_key) cpu_action = jax.device_get(actions) @@ -409,6 +410,7 @@ def make_sebulba_eval_fns( eval_env_fn: callable, network_apply_fn: Union[ActorApply, RecActorApply], config: DictConfig, + add_global_state : bool = False, use_recurrent_net: bool = False, scanned_rnn: Optional[nn.Module] = None, ) -> Tuple[EvalFn, EvalFn]: @@ -429,7 +431,7 @@ def make_sebulba_eval_fns( Raises: AssertionError: If `use_recurrent_net` is True but `scanned_rnn` is not provided. 
""" - eval_env, absolute_eval_env = eval_env_fn(config, config.arch.num_eval_episodes), eval_env_fn(config, config.arch.num_eval_episodes * 10) + eval_env, absolute_eval_env = eval_env_fn(config, config.arch.num_eval_episodes, add_global_state = add_global_state), eval_env_fn(config, config.arch.num_eval_episodes * 10, add_global_state = add_global_state) # Check if win rate is required for evaluation. log_win_rate = config.env.log_win_rate diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index f2168cf63..30e5bacbf 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -35,8 +35,8 @@ from mava.evaluator import make_sebulba_eval_fns as make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.sebulba.ppo.types import LearnerState, OptStates, Params, PPOTransition, Observation #todo: change this Observation to use the origial one -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import ( @@ -167,6 +167,9 @@ def get_action_and_value( sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) sharded_next_done = shard_split_payload(next_dones, 0) + + # Pack the obs and action mask + payload_obs = Observation(sharded_next_obs, sharded_next_action_mask) # For debugging speed_info = { @@ -182,9 +185,8 @@ def get_action_and_value( payload = ( sharded_storage, - sharded_next_obs, + payload_obs, sharded_next_done, - sharded_next_action_mask ) # Put data in the rollout queue to share it with the learner @@ -204,7 +206,7 @@ def get_learner_fn( actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_action_mask : chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: + def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -246,7 +248,7 @@ def _get_advantages( # CALCULATE ADVANTAGE params, opt_states, key, _, _ = learner_state - last_val = critic_apply_fn(params.critic_params, Observation(last_obs, last_action_mask)) + last_val = critic_apply_fn(params.critic_params, last_obs) advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: @@ -403,7 +405,7 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_action_mask : chex.Array, last_dones : chex.Array) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> ExperimentOutput[LearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -420,7 +422,7 @@ def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs """ - learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_action_mask, last_dones) + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_dones) return ExperimentOutput( learner_state=learner_state, @@ -630,7 +632,6 @@ def run_experiment(_config: DictConfig) -> float: sharded_storages = [] sharded_next_obss = [] sharded_next_dones = [] - sharded_next_action_masks = [] rollout_start_time = time.time() # Loop through each executor device @@ -642,25 +643,22 @@ def run_experiment(_config: DictConfig) -> float: sharded_storage, sharded_next_obs, sharded_next_done, - sharded_next_action_mask ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() sharded_storages.append(sharded_storage) sharded_next_obss.append(sharded_next_obs) sharded_next_dones.append(sharded_next_done) - sharded_next_action_masks.append(sharded_next_action_mask) - + rollout_times.append(time.time() - rollout_start_time) # Concatinate the returned trajectories on the n_env axis sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) - sharded_next_obss = jnp.concatenate(sharded_next_obss, axis = 1) + sharded_next_obss = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_obss) sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) - sharded_next_action_masks = jnp.concatenate(sharded_next_action_masks, axis = 1) learner_start_time = time.time() - learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_action_masks, sharded_next_dones) + learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones) learner_speeds.append(time.time() - learner_start_time) # Stack the metrics diff --git a/mava/systems/sebulba/ppo/ff_mappo.py b/mava/systems/sebulba/ppo/ff_mappo.py new file mode 100644 index 000000000..5f84fd0d0 --- /dev/null +++ b/mava/systems/sebulba/ppo/ff_mappo.py @@ -0,0 +1,768 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
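# The MAPPO system introduced below mirrors the IPPO file but feeds a centralised critic
# with a global state. A rough sketch of the two observation containers involved,
# simplified from mava.types, with made-up shapes:
from typing import NamedTuple, Optional
import numpy as np

class Observation(NamedTuple):
    agents_view: np.ndarray                   # (num_agents, num_obs_features)
    action_mask: np.ndarray                   # (num_agents, num_actions)
    step_count: Optional[np.ndarray] = None

class ObservationGlobalState(NamedTuple):
    agents_view: np.ndarray
    action_mask: np.ndarray
    global_state: np.ndarray                  # (num_agents, num_agents * num_obs_features)
    step_count: Optional[np.ndarray] = None

agents_view = np.zeros((4, 10), dtype=np.float32)
action_mask = np.ones((4, 5), dtype=np.float32)
global_state = np.tile(agents_view.reshape(-1), (4, 1))  # every agent sees the joint view
obs = ObservationGlobalState(agents_view, action_mask, global_state)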
+ +import copy +import time +from typing import Any, Dict, Tuple, List +import threading +import chex +import flax +import hydra +import jax +import jax.debug +import jax.numpy as jnp +import numpy as np +import optax +import queue +from collections import deque +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from omegaconf import DictConfig, OmegaConf +from optax._src.base import OptState +from rich.pretty import pprint + +from mava.evaluator import make_sebulba_eval_fns as make_eval_fns +from mava.networks import FeedForwardActor as Actor +from mava.networks import FeedForwardValueNet as Critic +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this Observation to use the standard obs +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, ObservationGlobalState +from mava.utils import make_env as environments +from mava.utils.checkpointing import Checkpointer +from mava.utils.jax_utils import ( + merge_leading_dims, + unreplicate_batch_dim, + unreplicate_n_dims, +) +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.total_timestep_checker import sebulba_check_total_timesteps +from mava.utils.training import make_learning_rate +from mava.wrappers.episode_metrics import get_final_step_metrics + + +def rollout( + key: chex.PRNGKey, + config: DictConfig, + rollout_queue: queue.Queue, + params_queue: queue.Queue, + apply_fns: Tuple, + learner_devices: List, + actor_device_id : int): + + #setup + env = environments.make_gym_env(config, config.arch.num_envs, add_global_state=True) + current_actor_device = jax.devices()[actor_device_id] + actor_apply_fn, critic_apply_fn = apply_fns + + # Define the util functions: select action function and prepare data to share it with learner. + @jax.jit + def get_action_and_value( + params: FrozenDict, + observation: ObservationGlobalState, + key: chex.PRNGKey, + ) -> Tuple: + """Get action and value.""" + key, subkey = jax.random.split(key) + + actor_policy = actor_apply_fn(params.actor_params, observation) + action = actor_policy.sample(seed=subkey) + log_prob = actor_policy.log_prob(action) + + value = critic_apply_fn(params.critic_params, observation).squeeze() + return action, log_prob, value, key + + # Define queues to track time + params_queue_get_time: deque = deque(maxlen=1) + rollout_time: deque = deque(maxlen=1) + rollout_queue_put_time: deque = deque(maxlen=1) + + next_obs , info = env.reset() + next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) + + move_to_device = lambda x : jax.device_put(x, device = current_actor_device) + + # Loop till the learner has finished training + for update in range(config.system.num_updates): + inference_time: float = 0 + storage_time: float = 0 + env_send_time: float = 0 + + # Get the latest parameters from the learner + params_queue_get_time_start = time.time() + params = params_queue.get() + params_queue_get_time.append(time.time() - params_queue_get_time_start) + + # Rollout + rollout_time_start = time.time() + storage: List = [] + + # Loop over the rollout length + for _ in range(0, config.system.rollout_length): + + # Cached for transition + cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) # (num_envs, num_agents, ...) 
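# Why the stack on axis 1 above: the vectorised gym env is assumed to return observations
# as a tuple with one (num_envs, obs_dim) array per agent, so stacking on axis 1 yields the
# (num_envs, num_agents, obs_dim) layout the networks expect. Shapes are illustrative only.
import numpy as np

num_envs, num_agents, obs_dim = 8, 3, 5
next_obs = tuple(np.zeros((num_envs, obs_dim), dtype=np.float32) for _ in range(num_agents))

batched = np.stack(next_obs, axis=1)
print(batched.shape)  # (8, 3, 5)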
+ cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) + cashed_action_mask = move_to_device(np.stack(info["actions_mask"])) # (num_envs, num_agents, num_actions) + cached_next_global_obs = move_to_device(np.stack(info["global_obs"])) + + + # Get action and value + full_observation = ObservationGlobalState(cached_next_obs, cashed_action_mask, cached_next_global_obs) + inference_time_start = time.time() + ( + action, + log_prob, + value, + key, + ) = get_action_and_value(params, full_observation , key) + + + # Step the environment + inference_time += time.time() - inference_time_start + env_send_time_start = time.time() + cpu_action = jax.device_get(action) + next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) # (num_env, num_agents) --> (num_agents, num_env) + env_send_time += time.time() - env_send_time_start + + # Prepare the data + storage_time_start = time.time() + next_dones = np.logical_or(terminated, truncated) + metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics + + # Append data to storage + storage.append( + PPOTransition( + done=cached_next_dones, + action=action, + value=value, + reward=next_reward, + log_prob=log_prob, + obs=full_observation, + info=metrics, + ) + ) + storage_time += time.time() - storage_time_start + rollout_time.append(time.time() - rollout_time_start) + + parse_timer = time.time() + + # Prepare data to share with learner + #[PPOTransition() * rollout_len] --> PPOTransition[done = (rollout_len, num_envs, num_agents), action = (rollout_len, num_envs, num_agents, num_actions), ...] + stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) + + + # Split the arrays over the different learner_devices on the num_envs axis + shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) + + sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) + + # (num_learner_devices, num_envs, num_agents, ...) 
+ sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) + sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) + sharded_next_global_obs = shard_split_payload(np.stack(info["global_obs"]), 0) + sharded_next_done = shard_split_payload(next_dones, 0) + + # Pack the obs and action mask + payload_obs = ObservationGlobalState(sharded_next_obs, sharded_next_action_mask, sharded_next_global_obs) + + # For debugging + speed_info = { + "rollout_time": np.mean(rollout_time), + "params_queue_get_time": np.mean(params_queue_get_time), + "action_inference": inference_time, + "storage_time": storage_time, + "env_step_time": env_send_time, + "rollout_queue_put_time": np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0, + "parse_time" : time.time() - parse_timer, + } + #print(speed_info) + + payload = ( + sharded_storage, + payload_obs, + sharded_next_done, + ) + + # Put data in the rollout queue to share it with the learner + rollout_queue_put_time_start = time.time() + rollout_queue.put(payload) + rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) + + +def get_learner_fn( + apply_fns: Tuple[ActorApply, CriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> LearnerFn[LearnerState]: + """Get the learner function.""" + + # Get apply and update functions for actor and critic networks. + actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: + """A single update of the network. + + This function steps the environment and records the trajectory batch for + training. It then calculates advantages and targets based on the recorded + trajectory and updates the actor and critic networks based on the calculated + losses. + + Args: + learner_state (NamedTuple): + - params (Params): The current model parameters. + - opt_states (OptStates): The current optimizer states. + - key (PRNGKey): The random number generator state. + - env_state (State): The environment state. + - last_timestep (TimeStep): The last timestep in the current trajectory. + _ (Any): The current metrics info. 
+ """ + + def _calculate_gae( #todo: lake sure this is appropriate + traj_batch: PPOTransition, last_val: chex.Array, last_done: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + def _get_advantages( + carry: Tuple[chex.Array, chex.Array, chex.Array], transition: PPOTransition + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: + gae, next_value, next_done = carry + done, value, reward = transition.done, transition.value, transition.reward + gamma = config.system.gamma + delta = reward + gamma * next_value * (1 - next_done) - value + gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae + return (gae, value, done), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val, last_done), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + # CALCULATE ADVANTAGE + params, opt_states, key, _, _ = learner_state + last_val = critic_apply_fn(params.critic_params, last_obs) + advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # UNPACK TRAIN STATE AND BATCH INFO + params, opt_states, key = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + actor_opt_state: OptState, + traj_batch: PPOTransition, + gae: chex.Array, + key: chex.PRNGKey, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config.system.clip_eps, + 1.0 + config.system.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + # The seed will be used in the TanhTransformedDistribution: + entropy = actor_policy.entropy(seed=key).mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + return total_loss_actor, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: FrozenDict, + critic_opt_state: OptState, + traj_batch: PPOTransition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN NETWORK + value = critic_apply_fn(critic_params, traj_batch.obs) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + critic_total_loss = config.system.vf_coef * value_loss + return critic_total_loss, (value_loss) + + # CALCULATE ACTOR LOSS + key, entropy_key = jax.random.split(key) + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = actor_grad_fn( + params.actor_params, + opt_states.actor_opt_state, + traj_batch, + advantages, + entropy_key, + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, opt_states.critic_opt_state, traj_batch, 
targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" #todo: pmean over learner devices not all + ) + + # pmean over devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = Params(actor_new_params, critic_new_params) + new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) + # PACK LOSS INFO + total_loss = actor_loss_info[0] + critic_loss_info[0] + value_loss = critic_loss_info[1] + actor_loss = actor_loss_info[1][0] + entropy = actor_loss_info[1][1] + loss_info = { + "total_loss": total_loss, + "value_loss": value_loss, + "actor_loss": actor_loss, + "entropy": entropy, + } + return (new_params, new_opt_state, entropy_key), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key, shuffle_key, entropy_key = jax.random.split(key, 3) + # SHUFFLE MINIBATCHES + batch_size = config.system.rollout_length * (config.arch.num_envs // len(config.arch.learner_device_ids)) * len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor + permutation = jax.random.permutation(shuffle_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), + shuffled_batch, + ) + # UPDATE MINIBATCHES + (params, opt_states, entropy_key), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states, entropy_key), minibatches + ) + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = LearnerState(params, opt_states, key, None, None) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> ExperimentOutput[LearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (Params): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. 
+ - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. + """ + + + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_dones) + + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + keys: chex.Array, config: DictConfig, learner_devices: List +) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + + #create temporory envoirnments. + env = environments.make_gym_env(config, 1, add_global_state=True) + # Get number of agents and actions. + action_space = env.single_action_space + config.system.num_agents = len(action_space) + config.system.num_actions = action_space[0].n + + # PRNG keys. + key, actor_net_key, critic_net_key = keys + + # Define network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_action_head = hydra.utils.instantiate( + config.network.action_head, action_dim=config.system.num_actions + ) + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + critic_network = Critic(torso=critic_torso, centralised_critic= True) + + actor_lr = make_learning_rate(config.system.actor_lr, config) + critic_lr = make_learning_rate(config.system.critic_lr, config) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. + obs, info = env.reset() + init_obs = jnp.stack(obs, axis = 1) # (num_envs, num_agents, ...) + init_mask = np.stack(info["actions_mask"]) # (num_envs, num_agents, num_actions) + init_global_obs = np.stack(info["global_obs"]) + init_x = ObservationGlobalState(init_obs, init_mask, init_global_obs) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. + critic_params = critic_network.init(critic_net_key, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Pack params. + params = Params(actor_params, critic_params) + + # Pack apply and update functions. + apply_fns = (actor_network.apply, critic_network.apply) + update_fns = (actor_optim.update, critic_optim.update) + + # Get batched iterated update and replicate it to pmap it over learner cores. + learn = get_learner_fn(apply_fns, update_fns, config) + learn = jax.pmap(learn, axis_name="device", devices = learner_devices) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(input_params=params) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. 
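# Tiny sketch of the replication step that follows: flax's jax_utils.replicate copies a
# pytree onto every learner device, adding a leading device axis so the pmapped learn
# function can consume it. The dummy params dict is an assumption for illustration.
import jax
import jax.numpy as jnp
from flax import jax_utils

learner_devices = jax.local_devices()
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}

replicated = jax_utils.replicate(params, devices=learner_devices)
print(replicated["w"].shape)  # (len(learner_devices), 3, 3)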
+ key, step_keys = jax.random.split(key) + opt_states = OptStates(actor_opt_state, critic_opt_state) + replicate_learner = (params, opt_states, step_keys) + + # Duplicate learner across Learner devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) + + # Initialise learner state. + params, opt_states, step_keys = replicate_learner + init_learner_state = LearnerState(params, opt_states, step_keys, None, None) + env.close() + + return learn, apply_fns, init_learner_state + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + devices = jax.devices() + learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] + + # PRNG keys. + key, key_e, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.system.seed), num=4 + ) + + # Sanity check of config + assert ( + config.arch.num_envs % len(config.arch.learner_device_ids) == 0 + ), "The number of environments must to be divisible by the number of learners " + + assert ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.arch.n_threads_per_executor + % config.system.num_minibatches + == 0 + ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" + + + # Setup learner. + learn, apply_fns , learner_state = learner_setup( + (key ,actor_net_key, critic_net_key), config, learner_devices + ) + + # Setup evaluator. + # One key per device for evaluation. + evaluator, absolute_metric_evaluator = make_eval_fns(environments.make_gym_env, apply_fns[0], config, add_global_state=True) #todo: make this more generic + + # Calculate total timesteps. + config = sebulba_check_total_timesteps(config) + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Calculate number of updates per evaluation. + config.system.num_updates_per_eval, remaining_updates = divmod(config.system.num_updates , config.arch.num_evaluation) + config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation step if the num_updates is not a multiple of num_evaluation + steps_per_rollout = ( + len(config.arch.executor_device_ids) + * config.arch.n_threads_per_executor + * config.system.rollout_length + * config.arch.num_envs + * config.system.num_updates_per_eval + ) + + # Logger setup + logger = MavaLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Executor setup and launch. 
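# Worked example of the divmod schedule configured above: with 43 total updates and 4
# evaluations, each window runs 10 updates and one extra, shorter window of 3 updates is
# appended so no update is dropped. The numbers are made up; the guard on
# remaining_updates keeps the final window full when the division is exact.
num_updates, num_evaluation = 43, 4

num_updates_per_eval, remaining_updates = divmod(num_updates, num_evaluation)
num_evaluation += remaining_updates != 0   # one extra, shorter eval window if needed

windows = []
for eval_step in range(num_evaluation):
    last = eval_step == num_evaluation - 1
    windows.append(remaining_updates if last and remaining_updates else num_updates_per_eval)

print(windows, sum(windows))  # [10, 10, 10, 10, 3] 43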
+ unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) + params_queues: List = [] + rollout_queues: List = [] + for d_idx, d_id in enumerate( # Loop through each executor device + config.arch.executor_device_ids + ): + # Replicate params per executor device + device_params = jax.device_put(unreplicated_params, devices[d_id]) + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + params_queues.append(queue.Queue(maxsize=1)) + rollout_queues.append(queue.Queue(maxsize=1)) + params_queues[-1].put(device_params) + threading.Thread( + target=rollout, + args=( + jax.device_put(key, devices[d_id]), + config, + rollout_queues[-1], + params_queues[-1], + apply_fns, + learner_devices, + d_id, + ), + ).start() #todo : Use a process instead of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) + + + # Run experiment for the total number of updates. + max_episode_return = jnp.float32(0.0) + best_params = None + for eval_step in range(config.arch.num_evaluation): + training_start_time = time.time() + learner_speeds = [] + rollout_times = [] + + episode_metrics = [] + train_metrics = [] + + # Make sure that the + num_updates_in_eval = config.system.num_updates_per_eval if eval_step != config.arch.num_evaluation - 1 else remaining_updates + for update in range(num_updates_in_eval): + sharded_storages = [] + sharded_next_obss = [] + sharded_next_dones = [] + + rollout_start_time = time.time() + # Loop through each executor device + for d_idx, _ in enumerate(config.arch.executor_device_ids): + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + # Get data from rollout queue + ( + sharded_storage, + sharded_next_obs, + sharded_next_done, + ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() + sharded_storages.append(sharded_storage) + sharded_next_obss.append(sharded_next_obs) + sharded_next_dones.append(sharded_next_done) + + rollout_times.append(time.time() - rollout_start_time) + + + # Concatinate the returned trajectories on the n_env axis + sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) + sharded_next_obss = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_obss) + sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) + + + learner_start_time = time.time() + learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones) + learner_speeds.append(time.time() - learner_start_time) + + # Stack the metrics + episode_metrics.append(learner_output.episode_metrics) + train_metrics.append(learner_output.train_metrics) + + # Send updated params to executors + unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) + for d_idx, d_id in enumerate(config.arch.executor_device_ids): + device_params = jax.device_put(unreplicated_params, devices[d_id]) + for thread_id in range(config.arch.n_threads_per_executor): + params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( + device_params + ) + + + + # Log the results of the training. 
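# Sketch of the metric aggregation used in the logging block that follows: the per-update
# metric dicts collected during the eval window are merged leaf-wise into arrays so they
# can be reduced and logged once per window. Values are dummies.
import jax
import numpy as np

episode_metrics = [
    {"episode_return": 1.0, "episode_length": 10},
    {"episode_return": 3.0, "episode_length": 12},
]

stacked = jax.tree_map(lambda *x: np.asarray(x), *episode_metrics)
print(stacked["episode_return"])  # [1. 3.]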
+ elapsed_time = time.time() - training_start_time + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics = jax.tree_map(lambda *x : np.asarray(x), *episode_metrics) + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + speed_info = {"total_time" : elapsed_time, "rollout_time" : np.sum(rollout_times), "learner_time" : np.sum(learner_speeds), "timestep" : t} + logger.log(speed_info , t, eval_step, LogEvent.MISC) + if ep_completed: # only log episode metrics if an episode was completed in the rollout. + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + train_metrics = jax.tree_map(lambda *x : np.asarray(x), *train_metrics) + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + + # Evaluation on the learner + evaluation_start_timer = time.time() + key_e, eval_key = jax.random.split(key_e, 2) + episode_metrics = evaluator(unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1 ), eval_key) + + # Log the results of the evaluation. + elapsed_time = time.time() - evaluation_start_timer + episode_return = jnp.mean(episode_metrics["episode_return"]) + + steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) + episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state, 1), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(learner_output.learner_state.params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Record the performance for the final evaluation run. + eval_performance = float(jnp.mean(episode_metrics[config.env.eval_metric])) + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, eval_key = jax.random.split(key_e, 2) + episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params.actor_params, 1), eval_key) + + elapsed_time = time.time() - start_time + steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) + + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. + logger.stop() + + return eval_performance + + + +@hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. 
+ eval_performance = run_experiment(cfg) + print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() + +#learner_output.episode_metrics.keys() +#dict_keys(['episode_length', 'episode_return']) \ No newline at end of file diff --git a/mava/types.py b/mava/types.py index aa79bf5b4..c6a2cf6aa 100644 --- a/mava/types.py +++ b/mava/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Generic, Tuple, TypeVar +from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Optional import chex from flax.core.frozen_dict import FrozenDict @@ -37,7 +37,7 @@ class Observation(NamedTuple): agents_view: chex.Array # (num_agents, num_obs_features) action_mask: chex.Array # (num_agents, num_actions) - step_count: chex.Array # (num_agents, ) + step_count: Optional[chex.Array] = None # (num_agents, ) class ObservationGlobalState(NamedTuple): @@ -49,7 +49,7 @@ class ObservationGlobalState(NamedTuple): agents_view: chex.Array # (num_agents, num_obs_features) action_mask: chex.Array # (num_agents, num_actions) global_state: chex.Array # (num_agents, num_agents * num_obs_features) - step_count: chex.Array # (num_agents, ) + step_count: Optional[chex.Array] = None # (num_agents, ) RNNObservation: TypeAlias = Tuple[Observation, Done] diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index c23e40820..a9313bf64 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -46,10 +46,9 @@ ConnectorWrapper, GigastepWrapper, GymRecordEpisodeMetrics, - GymRwareWrapper, + GymGenericWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory, - GymLBFWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, @@ -73,7 +72,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"RobotWarehouse": GymRwareWrapper, "LevelBasedForaging" : GymLBFWrapper} +_gym_registry = {"RobotWarehouse": GymGenericWrapper, "LevelBasedForaging" : GymGenericWrapper} def add_extra_wrappers( @@ -238,7 +237,6 @@ def create_gym_env( wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env - num_env = config.arch.num_envs envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names [ lambda: create_gym_env(config, add_global_state) diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 64a5affec..703d85279 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, GymLBFWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory +from mava.wrappers.gym import GymRecordEpisodeMetrics, GymGenericWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 31146e29a..b329241d9 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -27,7 +27,7 @@ warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") -class GymRwareWrapper(gym.Wrapper): +class GymGenericWrapper(gym.Wrapper): """Wrapper for rware gym environments""" def __init__( @@ -35,7 +35,6 @@ def __init__( env: gym.Env, 
use_individual_rewards: bool = False, add_global_state: bool = False, - eval_env: bool = False, ): """Initialize the gym wrapper @@ -44,17 +43,15 @@ def __init__( use_individual_rewards (bool, optional): Use individual or group rewards. Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. - eval_env (bool, optional): Weather the instance is used for training or evaluation. - Defaults to False. """ super().__init__(env) - self._env = env #not having _env leaded tp self.env getting replaced --> circular called + self._env = env self.use_individual_rewards = use_individual_rewards - self.add_global_state = add_global_state # todo : add the global observations + self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[ 0 - ].n # todo: all the agents must have the same num_actions, add assertion? + ].n def reset( self, seed: Optional[int] = None, options: Optional[dict] = None @@ -66,19 +63,24 @@ def reset( agents_view, info = self._env.reset() info = {"actions_mask": self.get_actions_mask(info)} - + if self.add_global_state: + info["global_obs"] = self.get_global_obs(agents_view) + return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple: #Vect auto rest + def step(self, actions: NDArray) -> Tuple: agents_view, reward, terminated, truncated, info = self._env.step(actions) info = {"actions_mask": self.get_actions_mask(info)} + if self.add_global_state: + info["global_obs"] = self.get_global_obs(agents_view) if self.use_individual_rewards: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) + return agents_view, reward, terminated, truncated, info def get_actions_mask(self, info: Dict) -> NDArray: @@ -86,66 +88,14 @@ def get_actions_mask(self, info: Dict) -> NDArray: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + def get_global_obs(self, obs: NDArray): + global_obs = np.concatenate(obs, axis=0) + return np.tile(global_obs, (self.num_agents, 1)) + -class GymLBFWrapper(gym.Wrapper): - """Wrapper for rware gym environments""" - - def __init__( - self, - env: gym.Env, - use_individual_rewards: bool = False, - add_global_state: bool = False, - ): - """Initialize the gym wrapper - - Args: - env (gym.env): gym env instance. - use_individual_rewards (bool, optional): Use individual or group rewards. - Defaults to False. - add_global_state (bool, optional) : Create global observations. Defaults to False. - """ - super().__init__(env) - self._env = env #not having _env leaded tp self.env getting replaced --> circular called - self.use_individual_rewards = use_individual_rewards - self.add_global_state = add_global_state # todo : add the global observations - self.num_agents = len(self._env.action_space) - self.num_actions = self._env.action_space[ - 0 - ].n # todo: all the agents must have the same num_actions, add assertion? 
- - def reset( - self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple: - - if seed is not None: - self.env.seed(seed) - - agents_view, info = self._env.reset() - - info = {"actions_mask": self.get_actions_mask(info)} - - return np.array(agents_view), info - - def step(self, actions: NDArray) -> Tuple: #Vect auto rest - - agents_view, reward, terminated, truncated, info = self._env.step(actions) - - info = {"actions_mask": self.get_actions_mask(info)} - - if self.use_individual_rewards: - reward = np.array(reward) - else: - reward = np.array([np.array(reward).mean()] * self.num_agents) - - truncated = [truncated] * self.num_agents - return agents_view, reward, terminated, truncated, info - def get_actions_mask(self, info: Dict) -> NDArray: - if "action_mask" in info: - return np.array(info["action_mask"]) - return np.ones((self.num_agents, self.num_actions), dtype=np.float32) class GymRecordEpisodeMetrics(gym.Wrapper): """Record the episode returns and lengths.""" From 7044fbef5b423b8a65c554ef746669a8d921c144 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 8 Jul 2024 14:54:51 +0100 Subject: [PATCH 027/125] fix: removed the sebulba spesifique types --- mava/systems/sebulba/ppo/types.py | 101 ------------------------------ 1 file changed, 101 deletions(-) delete mode 100644 mava/systems/sebulba/ppo/types.py diff --git a/mava/systems/sebulba/ppo/types.py b/mava/systems/sebulba/ppo/types.py deleted file mode 100644 index c27dcace5..000000000 --- a/mava/systems/sebulba/ppo/types.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Dict - -import chex -from flax.core.frozen_dict import FrozenDict -from jumanji.types import TimeStep -from optax._src.base import OptState -from typing_extensions import NamedTuple - -from mava.types import Action, Done, HiddenState, State, Value - - -class Params(NamedTuple): - """Parameters of an actor critic network.""" - - actor_params: FrozenDict - critic_params: FrozenDict - - -class OptStates(NamedTuple): - """OptStates of actor critic learner.""" - - actor_opt_state: OptState - critic_opt_state: OptState - - -class HiddenStates(NamedTuple): - """Hidden states for an actor critic learner.""" - - policy_hidden_state: HiddenState - critic_hidden_state: HiddenState - - -class LearnerState(NamedTuple): - """State of the learner.""" - - params: Params - opt_states: OptStates - key: chex.PRNGKey - env_state: State - timestep: TimeStep - - -class RNNLearnerState(NamedTuple): - """State of the `Learner` for recurrent architectures.""" - - params: Params - opt_states: OptStates - key: chex.PRNGKey - env_state: State - timestep: TimeStep - dones: Done - hstates: HiddenStates - - -class PPOTransition(NamedTuple): - """Transition tuple for PPO.""" - - done: Done - action: Action - value: Value - reward: chex.Array - log_prob: chex.Array - obs: chex.Array - info: Dict - - -class RNNPPOTransition(NamedTuple): - """Transition tuple for PPO.""" - - done: Done - action: Action - value: Value - reward: chex.Array - log_prob: chex.Array - obs: chex.Array - hstates: HiddenStates - info: Dict - - -class Observation(NamedTuple): - """The observation that the agent sees. - agents_view: the agent's view of the environment. - action_mask: boolean array specifying, for each agent, which action is legal. - """ - - agents_view: chex.Array # (num_agents, num_obs_features) - action_mask: chex.Array # (num_agents, num_actions) From 9433f2eb0180d97ab0f87fef7ac87327bf5f40cf Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 10 Jul 2024 09:09:05 +0100 Subject: [PATCH 028/125] feat: ff_mappo and rec_ippo in sebulba --- mava/configs/arch/sebulba.yaml | 6 +- mava/configs/default_ff_mappo_seb.yaml | 7 + mava/configs/default_rec_ippo_seb.yaml | 7 + mava/configs/system/ppo/ff_ippo.yaml | 6 +- mava/evaluator.py | 88 ++- mava/systems/sebulba/ppo/ff_ippo.py | 11 +- mava/systems/sebulba/ppo/ff_mappo.py | 4 +- mava/systems/sebulba/ppo/rec_ippo.py | 850 +++++++++++++++++++++++++ 8 files changed, 960 insertions(+), 19 deletions(-) create mode 100644 mava/configs/default_ff_mappo_seb.yaml create mode 100644 mava/configs/default_rec_ippo_seb.yaml create mode 100644 mava/systems/sebulba/ppo/rec_ippo.py diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 617e54134..fd555f71e 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,18 +1,18 @@ # --- Sebulba config --- arch_name: "sebulba" -num_envs: 3 # number of envs per thread +num_envs: 32 # number of envs per thread # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. -num_evaluation: 10 # Number of evenly spaced evaluations to perform during training. +num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. absolute_metric: True # Whether the absolute metric should be computed. 
For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 # --- Sebulba devices config --- -n_threads_per_executor: 2 # num of different threads/env batches per actor +n_threads_per_executor: 1 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices diff --git a/mava/configs/default_ff_mappo_seb.yaml b/mava/configs/default_ff_mappo_seb.yaml new file mode 100644 index 000000000..8d96d3e97 --- /dev/null +++ b/mava/configs/default_ff_mappo_seb.yaml @@ -0,0 +1,7 @@ +defaults: + - logger: ff_mappo + - arch: sebulba + - system: ppo/ff_mappo + - network: mlp + - env: gym + - _self_ diff --git a/mava/configs/default_rec_ippo_seb.yaml b/mava/configs/default_rec_ippo_seb.yaml new file mode 100644 index 000000000..61eaa95f1 --- /dev/null +++ b/mava/configs/default_rec_ippo_seb.yaml @@ -0,0 +1,7 @@ +defaults: + - logger: rec_ippo + - arch: sebulba + - system: ppo/rec_ippo + - network: rnn + - env: gym + - _self_ diff --git a/mava/configs/system/ppo/ff_ippo.yaml b/mava/configs/system/ppo/ff_ippo.yaml index 0c93c2683..c80b43ec8 100644 --- a/mava/configs/system/ppo/ff_ippo.yaml +++ b/mava/configs/system/ppo/ff_ippo.yaml @@ -2,15 +2,15 @@ total_timesteps: ~ # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. -num_updates: 12 # Number of updates +num_updates: 1000 # Number of updates seed: 42 # --- Agent observations --- add_agent_id: True # --- RL hyperparameters --- -actor_lr: 1.0e-3 # Learning rate for actor network -critic_lr: 1.0e-3 # Learning rate for critic network +actor_lr: 0.0005 # Learning rate for actor network +critic_lr: 0.0005 # Learning rate for critic network update_batch_size: 2 # Number of vectorised gradient updates per device. rollout_length: 128 # Number of environment steps per vectorised environment. ppo_epochs: 4 # Number of ppo epochs per training data batch. diff --git a/mava/evaluator.py b/mava/evaluator.py index f44a8d55b..ca0c8c9a7 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -145,7 +145,7 @@ def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOut return evaluator_fn -def get_rnn_evaluator_fn( +def get_anakin_rnn_evaluator_fn( env: Environment, apply_fn: RecActorApply, config: DictConfig, @@ -314,14 +314,14 @@ def make_anakin_eval_fns( # Vmap it over number of agents and create evaluator_fn. if use_recurrent_net: assert scanned_rnn is not None - evaluator = get_rnn_evaluator_fn( + evaluator = get_anakin_rnn_evaluator_fn( eval_env, network_apply_fn, # type: ignore config, scanned_rnn, log_win_rate, ) - absolute_metric_evaluator = get_rnn_evaluator_fn( + absolute_metric_evaluator = get_anakin_rnn_evaluator_fn( eval_env, network_apply_fn, # type: ignore config, @@ -374,9 +374,10 @@ def get_action( #todo explicetly put these on the learner? they should already b return action def eval_episodes(params: FrozenDict, key : chex.PRNGKey) -> Dict: - dones = np.zeros(env.num_envs) # todo: jnp or np? + obs, info = env.reset() + dones = np.zeros(env.num_envs) # todo: jnp or np? 
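# --- Editor's note: the sketch below is illustrative and not part of the diff. ---
# The evaluation loops in this file keep one "done" flag per vectorised environment and
# freeze each environment's metrics the first time that environment finishes an episode.
# A minimal NumPy version of that bookkeeping; the shapes and values are assumptions.
import numpy as np

num_envs, num_agents = 4, 2
dones = np.zeros((num_envs, num_agents), dtype=bool)  # nothing finished yet
eval_metrics = {"episode_return": np.zeros(num_envs)}  # frozen at first completion

next_dones = np.array([[True, True], [False, False], [True, True], [False, True]])
new_metrics = {"episode_return": np.array([1.0, 2.0, 3.0, 4.0])}

# An env counts as newly finished only if all of its agents are done and it was not done before.
per_env_done = np.all(next_dones & ~dones, axis=1)  # -> [True, False, True, False]
eval_metrics = {k: np.where(per_env_done, new_metrics[k], v) for k, v in eval_metrics.items()}
dones = np.logical_or(dones, next_dones)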
eval_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) while not dones.all(): @@ -405,6 +406,81 @@ def eval_episodes(params: FrozenDict, key : chex.PRNGKey) -> Dict: return eval_episodes +def get_sebulba_rnn_evaluator_fn( + env: Environment, + apply_fn: RecActorApply, + config: DictConfig, + scanned_rnn: nn.Module, + log_win_rate: bool = False, +) -> EvalFn: + """Get the evaluator function for feedforward networks. + + Args: + env (Environment): An evironment instance for evaluation. + apply_fn (callable): Network forward pass method. + config (dict): Experiment configuration. + """ + @jax.jit + def get_action( #todo explicetly put these on the learner? they should already be there + params: FrozenDict, + observation: Observation, + hstate : chex.Array, + key: chex.PRNGKey, + ) -> Tuple: + """Get action.""" + + hstate, pi = apply_fn(params, hstate, observation) + + if config.arch.evaluation_greedy: + action = pi.mode() + else: + action = pi.sample(seed=key) + + return action, hstate + def eval_episodes(params: FrozenDict, key : chex.PRNGKey) -> Dict: + + + + obs, info = env.reset() + eval_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) + + hstate = scanned_rnn.initialize_carry( + (env.num_envs, config.system.num_agents), config.network.hidden_state_dim + ) + + dones = jnp.zeros((env.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) + + while not dones.all(): + + key, policy_key = jax.random.split(key) + + obs = jax.device_put(jnp.stack(obs, axis = 1)) + action_mask = jax.device_put(np.stack(info["actions_mask"]) ) + + obs, action_mask, dones = jax.tree_map(lambda x : x[jnp.newaxis, :], (obs, action_mask, dones)) + + + actions, hstate = get_action(params, (Observation(obs, action_mask), dones), hstate, policy_key) + cpu_action = jax.device_get(actions) + + obs, reward, terminated, truncated, info = env.step(cpu_action[0].swapaxes(0,1)) + + next_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) + + next_dones = np.logical_or(terminated, truncated) + + per_env_done = np.all(np.logical_and(next_dones, dones[0] == False),axis = 1) + + update_metric = lambda old_metric, new_metric : np.where(per_env_done, new_metric, old_metric) + eval_metrics = jax.tree_map(update_metric, eval_metrics, next_metrics) + + dones = np.logical_or(dones, next_dones) + eval_metrics.pop("is_terminal_step") + + return eval_metrics + + return eval_episodes + def make_sebulba_eval_fns( eval_env_fn: callable, @@ -438,14 +514,14 @@ def make_sebulba_eval_fns( # Vmap it over number of agents and create evaluator_fn. 
if use_recurrent_net: assert scanned_rnn is not None - evaluator = get_rnn_evaluator_fn( + evaluator = get_sebulba_rnn_evaluator_fn( eval_env, network_apply_fn, # type: ignore config, scanned_rnn, log_win_rate, ) - absolute_metric_evaluator = get_rnn_evaluator_fn( + absolute_metric_evaluator = get_sebulba_rnn_evaluator_fn( absolute_eval_env, network_apply_fn, # type: ignore config, diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 30e5bacbf..153f9e4a9 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -74,7 +74,7 @@ def get_action_and_value( """Get action and value.""" key, subkey = jax.random.split(key) - actor_policy = actor_apply_fn(params.actor_params, observation) + actor_policy = actor_apply_fn(params.actor_params, observation) # TODO: check vmapiing action = actor_policy.sample(seed=subkey) log_prob = actor_policy.log_prob(action) @@ -114,6 +114,7 @@ def get_action_and_value( cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) cashed_action_mask = move_to_device(np.stack(info["actions_mask"])) # (num_envs, num_agents, num_actions) + full_observation = Observation(cached_next_obs, cashed_action_mask) # Get action and value inference_time_start = time.time() ( @@ -121,7 +122,7 @@ def get_action_and_value( log_prob, value, key, - ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), key) + ) = get_action_and_value(params, full_observation, key) # Step the environment @@ -144,7 +145,7 @@ def get_action_and_value( value=value, reward=next_reward, log_prob=log_prob, - obs=Observation(cached_next_obs, cashed_action_mask), + obs=full_observation, info=metrics, ) ) @@ -206,7 +207,7 @@ def get_learner_fn( actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: + def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: Observation, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -421,7 +422,7 @@ def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs - timesteps (TimeStep): The initial timestep in the initial trajectory. """ - + # todo: add update_batch_size learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_dones) return ExperimentOutput( diff --git a/mava/systems/sebulba/ppo/ff_mappo.py b/mava/systems/sebulba/ppo/ff_mappo.py index 5f84fd0d0..66d4174bf 100644 --- a/mava/systems/sebulba/ppo/ff_mappo.py +++ b/mava/systems/sebulba/ppo/ff_mappo.py @@ -210,7 +210,7 @@ def get_learner_fn( actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: + def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: ObservationGlobalState, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: """A single update of the network. 
This function steps the environment and records the trajectory batch for @@ -749,7 +749,7 @@ def run_experiment(_config: DictConfig) -> float: -@hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_ff_mappo_seb.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/mava/systems/sebulba/ppo/rec_ippo.py b/mava/systems/sebulba/ppo/rec_ippo.py new file mode 100644 index 000000000..6e204fb21 --- /dev/null +++ b/mava/systems/sebulba/ppo/rec_ippo.py @@ -0,0 +1,850 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import time +from typing import Any, Dict, Tuple, List +import threading +import chex +import flax +import hydra +import jax +import jax.debug +import jax.numpy as jnp +import numpy as np +import optax +import queue +from collections import deque +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from omegaconf import DictConfig, OmegaConf +from optax._src.base import OptState +from rich.pretty import pprint + +from mava.evaluator import make_sebulba_eval_fns as make_eval_fns +from mava.networks import RecurrentActor as Actor +from mava.networks import RecurrentValueNet as Critic +from mava.networks import ScannedRNN +from mava.systems.anakin.ppo.types import ( + HiddenStates, + OptStates, + Params, + RNNLearnerState, + RNNPPOTransition, +) +from mava.types import ExperimentOutput, LearnerFn, RecActorApply, RecCriticApply, RNNObservation, Observation +from mava.utils import make_env as environments +from mava.utils.checkpointing import Checkpointer +from mava.utils.jax_utils import ( + merge_leading_dims, + unreplicate_batch_dim, + unreplicate_n_dims, +) +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.total_timestep_checker import sebulba_check_total_timesteps +from mava.utils.training import make_learning_rate +from mava.wrappers.episode_metrics import get_final_step_metrics + + +def rollout( + key: chex.PRNGKey, + config: DictConfig, + rollout_queue: queue.Queue, + params_queue: queue.Queue, + apply_fns: Tuple, + learner_devices: List, + actor_device_id : int, + init_hstates : HiddenStates): + + #setup + + env = environments.make_gym_env(config, config.arch.num_envs) + current_actor_device = jax.devices()[actor_device_id] + actor_apply_fn, critic_apply_fn = apply_fns + + # Define the util functions: select action function and prepare data to share it with learner. 
+ @jax.jit + def get_action_and_value( + params: FrozenDict, + observation: RNNObservation, + last_hstates : HiddenStates, + key: chex.PRNGKey, + ) -> Tuple: + """Get action and value.""" + key, subkey = jax.random.split(key) + + policy_hidden_state, actor_policy = actor_apply_fn(params.actor_params, last_hstates.policy_hidden_state, observation) + action = actor_policy.sample(seed=subkey) + log_prob = actor_policy.log_prob(action) + + critic_hidden_state, value = critic_apply_fn(params.critic_params, last_hstates.critic_hidden_state, observation) + hastates = HiddenStates(policy_hidden_state, critic_hidden_state) + return action, log_prob, value, key, hastates + + # Define queues to track time + params_queue_get_time: deque = deque(maxlen=1) + rollout_time: deque = deque(maxlen=1) + rollout_queue_put_time: deque = deque(maxlen=1) + + next_obs , info = env.reset() + next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) + next_hstates = init_hstates + move_to_device = lambda x : jax.device_put(x, device = current_actor_device) + + # Loop till the learner has finished training + for update in range(config.system.num_updates): + inference_time: float = 0 + storage_time: float = 0 + env_send_time: float = 0 + + # Get the latest parameters from the learner + params_queue_get_time_start = time.time() + params = params_queue.get() + params_queue_get_time.append(time.time() - params_queue_get_time_start) + + # Rollout + rollout_time_start = time.time() + storage: List = [] + + # Loop over the rollout length + for _ in range(0, config.system.rollout_length): + + # Cached for transition + cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) # (num_envs, num_agents, ...) + cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) + cashed_action_mask = move_to_device(np.stack(info["actions_mask"])) # (num_envs, num_agents, num_actions) + + # Add the sequence_len dim + cached_next_obs, cached_next_dones, cashed_action_mask = jax.tree_map(lambda x: x[jnp.newaxis, : ], (cached_next_obs, cached_next_dones, cashed_action_mask)) + + full_observation = Observation(cached_next_obs, cashed_action_mask) + full_observation_dones = (full_observation, cached_next_dones) + cashed_next_hstate = move_to_device(next_hstates) + # Get action and value + inference_time_start = time.time() + ( + action, + log_prob, + value, + key, + next_hstates + ) = get_action_and_value(params, full_observation_dones, cashed_next_hstate, key) + + + # Step the environment + inference_time += time.time() - inference_time_start + env_send_time_start = time.time() + cpu_action = jax.device_get(action) + next_obs, next_reward, terminated, truncated, info = env.step(cpu_action[0].swapaxes(0,1)) # (num_env, num_agents) --> (num_agents, num_env) + env_send_time += time.time() - env_send_time_start + + # Prepare the data + storage_time_start = time.time() + next_dones = np.logical_or(terminated, truncated) + metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics + + # Append data to storage + storage.append( + RNNPPOTransition( + done=cached_next_dones[0], + action=action[0], + value=value[0], + reward=next_reward, + log_prob=log_prob[0], + obs=Observation(cached_next_obs[0], cashed_action_mask[0]), + hstates=cashed_next_hstate, + info=metrics, + ) + ) + storage_time += time.time() - storage_time_start + rollout_time.append(time.time() - rollout_time_start) + + parse_timer = time.time() + + # Prepare data to share with learner + 
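# --- Editor's note: the sketch below is illustrative and not part of the diff. ---
# The code that follows turns the Python list of per-step transitions into a single
# pytree of stacked arrays and then shards it across the learner devices. A toy version
# of that pattern; the Step tuple, T and N are assumptions, and N must divide evenly by
# the number of local devices (mirroring the num_envs assertion in run_experiment).
from typing import NamedTuple
import jax
import jax.numpy as jnp

class Step(NamedTuple):
    reward: jax.Array  # (num_envs,)
    done: jax.Array  # (num_envs,)

T, N = 3, 4  # rollout length, parallel envs
storage = [Step(reward=jnp.ones(N) * t, done=jnp.zeros(N)) for t in range(T)]

stacked = jax.tree_map(lambda *xs: jnp.stack(xs), *storage)  # leaves become (T, N, ...)
learner_devices = jax.local_devices()
shard = lambda x, axis: jax.device_put_sharded(
    jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices
)
sharded = jax.tree_map(lambda x: shard(x, 1), stacked)  # split along the env axis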
#[PPOTransition() * rollout_len] --> PPOTransition[done = (rollout_len, num_envs, num_agents), action = (rollout_len, num_envs, num_agents, num_actions), ...] + stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) + + # Split the arrays over the different learner_devices on the num_envs axis + shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) + + sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) + + # (num_learner_devices, num_envs, num_agents, ...) + sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) + sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) + sharded_next_done = shard_split_payload(next_dones, 0) + sharded_next_hstate = jax.tree_map( lambda x: shard_split_payload(x,0), next_hstates) + + # Pack the obs and action mask + payload_obs_dones = (Observation(sharded_next_obs, sharded_next_action_mask), cached_next_dones) + + # For debugging + speed_info = { + "rollout_time": np.mean(rollout_time), + "params_queue_get_time": np.mean(params_queue_get_time), + "action_inference": inference_time, + "storage_time": storage_time, + "env_step_time": env_send_time, + "rollout_queue_put_time": np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0, + "parse_time" : time.time() - parse_timer, + } + #print(speed_info) + + payload = ( + sharded_storage, + payload_obs_dones, + sharded_next_done, + sharded_next_hstate + ) + + # Put data in the rollout queue to share it with the learner + rollout_queue_put_time_start = time.time() + rollout_queue.put(payload) + rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) + + +def get_learner_fn( + apply_fns: Tuple[ RecActorApply, RecCriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> LearnerFn[RNNLearnerState]: + """Get the learner function.""" + + # Get apply and update functions for actor and critic networks. + actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step(learner_state: RNNLearnerState, traj_batch : RNNPPOTransition, last_obs: RNNObservation, last_dones : chex.Array, last_hstate : HiddenStates) -> Tuple[RNNLearnerState, Tuple]: + """A single update of the network. + + This function steps the environment and records the trajectory batch for + training. It then calculates advantages and targets based on the recorded + trajectory and updates the actor and critic networks based on the calculated + losses. + + Args: + learner_state (NamedTuple): + - params (Params): The current model parameters. + - opt_states (OptStates): The current optimizer states. + - key (PRNGKey): The random number generator state. + - env_state (State): The environment state. + - last_timestep (TimeStep): The last timestep in the current trajectory. + _ (Any): The current metrics info. 
+ """ + + def _calculate_gae( #todo: lake sure this is appropriate + traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + def _get_advantages( + carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: + gae, next_value, next_done = carry + done, value, reward = transition.done, transition.value, transition.reward + gamma = config.system.gamma + delta = reward + gamma * next_value * (1 - next_done) - value + gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae + return (gae, value, done), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val, last_done), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + # CALCULATE ADVANTAGE + params, opt_states, key, _, _, _, _ = learner_state + last_obs = jax.tree_map(lambda x: x[jnp.newaxis, : ], last_obs) + last_dones = last_dones[jnp.newaxis, :] + + + _, last_val = critic_apply_fn(params.critic_params, last_hstate.critic_hidden_state, last_obs) + + advantages, targets = _calculate_gae(traj_batch, last_val[0], last_dones[0]) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # UNPACK TRAIN STATE AND BATCH INFO + params, opt_states, key = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + actor_opt_state: OptState, + traj_batch: RNNPPOTransition, + gae: chex.Array, + key: chex.PRNGKey, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + + obs_and_done = (traj_batch.obs, traj_batch.done) + _, actor_policy = actor_apply_fn( + actor_params, traj_batch.hstates.policy_hidden_state[0], obs_and_done + ) + log_prob = actor_policy.log_prob(traj_batch.action) + + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config.system.clip_eps, + 1.0 + config.system.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + # The seed will be used in the TanhTransformedDistribution: + entropy = actor_policy.entropy(seed=key).mean() + + total_loss = loss_actor - config.system.ent_coef * entropy + return total_loss, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: FrozenDict, + critic_opt_state: OptState, + traj_batch: RNNPPOTransition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN NETWORK + obs_and_done = (traj_batch.obs, traj_batch.done) + _, value = critic_apply_fn( + critic_params, traj_batch.hstates.critic_hidden_state[0], obs_and_done + ) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + total_loss = config.system.vf_coef * value_loss + return total_loss, (value_loss) + + # CALCULATE ACTOR LOSS + key, entropy_key = jax.random.split(key) + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = 
actor_grad_fn( + params.actor_params, + opt_states.actor_opt_state, + traj_batch, + advantages, + entropy_key, + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, opt_states.critic_opt_state, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" + ) + # pmean over devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + new_params = Params(actor_new_params, critic_new_params) + new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + total_loss = actor_loss_info[0] + critic_loss_info[0] + value_loss = critic_loss_info[1] + actor_loss = actor_loss_info[1][0] + entropy = actor_loss_info[1][1] + loss_info = { + "total_loss": total_loss, + "value_loss": value_loss, + "actor_loss": actor_loss, + "entropy": entropy, + } + + return (new_params, new_opt_state, entropy_key), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key, shuffle_key, entropy_key = jax.random.split(key, 3) + + # SHUFFLE MINIBATCHES + batch = (traj_batch, advantages, targets) + num_recurrent_chunks = ( + config.system.rollout_length // config.system.recurrent_chunk_size + ) + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + config.system.recurrent_chunk_size, + config.arch.num_envs * num_recurrent_chunks, + *x.shape[2:], + ), + batch, + ) + permutation = jax.random.permutation( + shuffle_key, config.arch.num_envs * num_recurrent_chunks + ) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=1), batch + ) + reshaped_batch = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, (x.shape[0], config.system.num_minibatches, -1, *x.shape[2:]) + ), + shuffled_batch, + ) + minibatches = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) + + # UPDATE MINIBATCHES + (params, opt_states, entropy_key), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states, entropy_key), minibatches + ) + + update_state = ( + params, + opt_states, + traj_batch, + advantages, + targets, + key, + ) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = RNNLearnerState(params, opt_states, key, None, None, None, None) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: RNNLearnerState, traj_batch : RNNPPOTransition, last_obs: chex.Array, last_dones : chex.Array, last_hstate : chex.Array) -> 
ExperimentOutput[RNNLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (Params): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. + """ + + + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_dones, last_hstate) + + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + keys: chex.Array, config: DictConfig, learner_devices: List +) -> Tuple[LearnerFn[RNNLearnerState], Actor, RNNLearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + + #create temporory envoirnments. + env = environments.make_gym_env(config, 1) + # Get number of agents and actions. + action_space = env.single_action_space + config.system.num_agents = len(action_space) + config.system.num_actions = action_space[0].n + + # PRNG keys. + key, actor_net_key, critic_net_key = keys + + # Define network and optimisers. + actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) + actor_action_head = hydra.utils.instantiate( + config.network.action_head, action_dim=config.system.num_actions + ) + critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) + + actor_network = Actor( + pre_torso=actor_pre_torso, + post_torso=actor_post_torso, + action_head=actor_action_head, + hidden_state_dim=config.network.hidden_state_dim, + ) + critic_network = Critic( + pre_torso=critic_pre_torso, + post_torso=critic_post_torso, + hidden_state_dim=config.network.hidden_state_dim, + ) + + actor_lr = make_learning_rate(config.system.actor_lr, config) + critic_lr = make_learning_rate(config.system.critic_lr, config) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. + init_obs = jnp.array([[env.single_observation_space.sample()]]) + init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) + init_dones = jnp.zeros((1, 1, config.system.num_agents), dtype=jax.numpy.bool_) + init_x = (Observation(init_obs, init_action_mask), init_dones) + + # Initialise hidden states. + init_policy_hstate = ScannedRNN.initialize_carry( + (config.arch.num_envs, config.system.num_agents), config.network.hidden_state_dim + ) + init_critic_hstate = ScannedRNN.initialize_carry( + (config.arch.num_envs, config.system.num_agents), config.network.hidden_state_dim + ) + + # initialise params and optimiser state. 
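# --- Editor's note: the sketch below is illustrative and not part of the diff. ---
# Both networks above get the same optimiser recipe: global-norm gradient clipping
# chained into Adam. A minimal optax example of that pattern on a toy parameter pytree;
# the parameter values, gradients and hyperparameters are assumptions for this example.
import jax.numpy as jnp
import optax

max_grad_norm, lr = 0.5, 5e-4
optim = optax.chain(optax.clip_by_global_norm(max_grad_norm), optax.adam(lr, eps=1e-5))

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
opt_state = optim.init(params)

grads = {"w": jnp.full((2, 2), 10.0), "b": jnp.ones(2)}  # deliberately large gradients
updates, opt_state = optim.update(grads, opt_state)  # clipped to norm 0.5, then Adam-scaled
params = optax.apply_updates(params, updates)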
+ actor_params = actor_network.init(actor_net_key, init_policy_hstate, init_x) + actor_opt_state = actor_optim.init(actor_params) + critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Get network apply functions and optimiser updates. + apply_fns = (actor_network.apply, critic_network.apply) + update_fns = (actor_optim.update, critic_optim.update) + + # Get batched iterated update and replicate it to pmap it over learner cores. + learn = get_learner_fn(apply_fns, update_fns, config) + learn = jax.pmap(learn, axis_name="device", devices = learner_devices) + + # Pack params and initial states. + params = Params(actor_params, critic_params) + hstates = HiddenStates(init_policy_hstate, init_critic_hstate) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, restored_hstates = loaded_checkpoint.restore_params( + input_params=params, restore_hstates=True, THiddenState=HiddenStates + ) + # Update the params and hstates + params = restored_params + hstates = restored_hstates if restored_hstates else hstates + + # Define params to be replicated across devices and batches. + key, step_keys = jax.random.split(key) + opt_states = OptStates(actor_opt_state, critic_opt_state) + replicate_learner = (params, opt_states, hstates, step_keys) + + # Duplicate learner across Learner devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) + + # Initialise learner state. + params, opt_states, hstates, step_keys = replicate_learner + init_learner_state = RNNLearnerState(params, opt_states, step_keys, None, None, init_dones, hstates) + env.close() + + return learn, apply_fns, init_learner_state + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + devices = jax.devices() + learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] + + # PRNG keys. + key, key_e, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.system.seed), num=4 + ) + + # Sanity check of config + if config.system.recurrent_chunk_size is None: + config.system.recurrent_chunk_size = config.system.rollout_length + else: + assert ( + config.system.rollout_length % config.system.recurrent_chunk_size == 0 + ), "Rollout length must be divisible by recurrent chunk size." + assert ( + config.arch.num_envs % len(config.arch.learner_device_ids) == 0 + ), "The number of environments must to be divisible by the number of learners " + + assert ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.arch.n_threads_per_executor + % config.system.num_minibatches + == 0 + ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" + + + # Setup learner. + learn, apply_fns , learner_state = learner_setup( + (key ,actor_net_key, critic_net_key), config, learner_devices + ) + + # Setup evaluator. + # One key per device for evaluation. + evaluator, absolute_metric_evaluator = make_eval_fns(environments.make_gym_env, apply_fns[0], config,use_recurrent_net = True, scanned_rnn = ScannedRNN) #todo: make this more generic + + # Calculate total timesteps. 
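# --- Editor's note: the sketch below is illustrative and not part of the diff. ---
# learner_setup replicates (params, opt_states, hstates, keys) onto every learner device,
# and the executor setup further below unreplicates the params again before placing a
# copy on each actor device. A toy round trip of that pattern; the pytree is an assumption.
import jax
import jax.numpy as jnp
from flax import jax_utils

learner_devices = jax.local_devices()  # e.g. a single CPU/GPU when run locally
params = {"actor": jnp.ones(3), "critic": jnp.zeros(3)}

replicated = jax_utils.replicate(params, devices=learner_devices)
assert replicated["actor"].shape == (len(learner_devices), 3)  # leading device axis added

unreplicated = jax_utils.unreplicate(replicated)  # drops the leading device axis again
device_params = jax.device_put(unreplicated, learner_devices[0])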
+ config = sebulba_check_total_timesteps(config) + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Calculate number of updates per evaluation. + config.system.num_updates_per_eval, remaining_updates = divmod(config.system.num_updates , config.arch.num_evaluation) + config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation step if the num_updates is not a multiple of num_evaluation + steps_per_rollout = ( + len(config.arch.executor_device_ids) + * config.arch.n_threads_per_executor + * config.system.rollout_length + * config.arch.num_envs + * config.system.num_updates_per_eval + ) + + # Logger setup + logger = MavaLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Executor setup and launch. + unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) + unreplicated_hstates = flax.jax_utils.unreplicate(learner_state.hstates) + params_queues: List = [] + rollout_queues: List = [] + for d_idx, d_id in enumerate( # Loop through each executor device + config.arch.executor_device_ids + ): + # Replicate params per executor device + device_params = jax.device_put(unreplicated_params, devices[d_id]) + device_hstates = jax.device_put(unreplicated_hstates, devices[d_id]) + # Loop through each executor thread + for thread_id in range(config.arch.n_threads_per_executor): + params_queues.append(queue.Queue(maxsize=1)) + rollout_queues.append(queue.Queue(maxsize=1)) + params_queues[-1].put(device_params) + threading.Thread( + target=rollout, + args=( + jax.device_put(key, devices[d_id]), + config, + rollout_queues[-1], + params_queues[-1], + apply_fns, + learner_devices, + d_id, + device_hstates, + ), + ).start() #todo : Use a process instead of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) + + # Run experiment for the total number of updates. 
+    max_episode_return = jnp.float32(0.0)
+    best_params = None
+    for eval_step in range(config.arch.num_evaluation):
+        training_start_time = time.time()
+        learner_speeds = []
+        rollout_times = []
+
+        episode_metrics = []
+        train_metrics = []
+
+        # Make sure the final eval step only runs the updates that remain when num_updates is not an exact multiple of num_evaluation.
+        num_updates_in_eval = config.system.num_updates_per_eval if eval_step != config.arch.num_evaluation - 1 else remaining_updates
+        for update in range(num_updates_in_eval):
+            sharded_storages = []
+            sharded_next_obss = []
+            sharded_next_dones = []
+            sharded_next_hstates = []
+
+            rollout_start_time = time.time()
+            # Loop through each executor device
+            for d_idx, _ in enumerate(config.arch.executor_device_ids):
+                # Loop through each executor thread
+                for thread_id in range(config.arch.n_threads_per_executor):
+                    # Get data from rollout queue
+                    (
+                        sharded_storage,
+                        sharded_next_obs,
+                        sharded_next_done,
+                        sharded_next_hstate,
+                    ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get()
+                    sharded_storages.append(sharded_storage)
+                    sharded_next_obss.append(sharded_next_obs)
+                    sharded_next_dones.append(sharded_next_done)
+                    sharded_next_hstates.append(sharded_next_hstate)
+
+            rollout_times.append(time.time() - rollout_start_time)
+
+
+            # Concatenate the returned trajectories along the num_envs axis
+            sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages)
+            sharded_next_obss = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_obss)
+            sharded_next_hstates = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_hstates)
+
+            sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1)
+
+            learner_start_time = time.time()
+            learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones, sharded_next_hstates)
+            learner_speeds.append(time.time() - learner_start_time)
+
+            # Stack the metrics
+            episode_metrics.append(learner_output.episode_metrics)
+            train_metrics.append(learner_output.train_metrics)
+
+            # Send updated params to executors
+            unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params)
+            for d_idx, d_id in enumerate(config.arch.executor_device_ids):
+                device_params = jax.device_put(unreplicated_params, devices[d_id])
+                for thread_id in range(config.arch.n_threads_per_executor):
+                    params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put(
+                        device_params
+                    )
+
+
+
+        # Log the results of the training.
+        elapsed_time = time.time() - training_start_time
+        t = int(steps_per_rollout * (eval_step + 1))
+        episode_metrics = jax.tree_map(lambda *x : np.asarray(x), *episode_metrics)
+        episode_metrics, ep_completed = get_final_step_metrics(episode_metrics)
+        episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time
+
+        # Separately log timesteps, acting metrics and training metrics.
+        speed_info = {"total_time" : elapsed_time, "rollout_time" : np.sum(rollout_times), "learner_time" : np.sum(learner_speeds), "timestep" : t}
+        logger.log(speed_info , t, eval_step, LogEvent.MISC)
+        if ep_completed: # only log episode metrics if an episode was completed in the rollout.
+ logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + train_metrics = jax.tree_map(lambda *x : np.asarray(x), *train_metrics) + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + + # Evaluation on the learner + evaluation_start_timer = time.time() + key_e, eval_key = jax.random.split(key_e, 2) + episode_metrics = evaluator(unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1 ), eval_key) + + # Log the results of the evaluation. + elapsed_time = time.time() - evaluation_start_timer + episode_return = jnp.mean(episode_metrics["episode_return"]) + + steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) + episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state, 1), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(learner_output.learner_state.params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Record the performance for the final evaluation run. + eval_performance = float(jnp.mean(episode_metrics[config.env.eval_metric])) + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, eval_key = jax.random.split(key_e, 2) + episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params.actor_params, 1), eval_key) + + elapsed_time = time.time() - start_time + steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) + + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. + logger.stop() + + return eval_performance + + + +@hydra.main(config_path="../../../configs", config_name="default_rec_ippo_seb.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. 
+ eval_performance = run_experiment(cfg) + print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() + +#learner_output.episode_metrics.keys() +#dict_keys(['episode_length', 'episode_return']) \ No newline at end of file From 627215d2943899fc6d8ed58cbbece640a21b1d39 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 10 Jul 2024 09:16:21 +0100 Subject: [PATCH 029/125] fix: removed the lbf import/wrapper --- mava/utils/make_env.py | 4 ++-- mava/wrappers/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index a9313bf64..df769d8c7 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -46,7 +46,7 @@ ConnectorWrapper, GigastepWrapper, GymRecordEpisodeMetrics, - GymGenericWrapper, + GymRwareWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory, LbfWrapper, @@ -72,7 +72,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"RobotWarehouse": GymGenericWrapper, "LevelBasedForaging" : GymGenericWrapper} +_gym_registry = {"RobotWarehouse": GymRwareWrapper} def add_extra_wrappers( diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 703d85279..4a4eb6ed0 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRecordEpisodeMetrics, GymGenericWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory +from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, From c3b405dda78b59c6a5f948d5df1812917aac1edd Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 10 Jul 2024 09:27:46 +0100 Subject: [PATCH 030/125] chore: clean up & updated the code to match the sebulba-ff-ippo branch --- mava/configs/arch/sebulba.yaml | 11 +- mava/configs/env/gym.yaml | 1 + mava/systems/sebulba/ppo/test.py | 50 ------- mava/systems/sebulba/ppo/types.py | 100 -------------- mava/utils/make_env.py | 29 ++-- mava/wrappers/__init__.py | 4 +- mava/wrappers/gym.py | 213 ++++++++++++++++++++++-------- 7 files changed, 177 insertions(+), 231 deletions(-) delete mode 100644 mava/systems/sebulba/ppo/test.py delete mode 100644 mava/systems/sebulba/ppo/types.py diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 98cd4d96d..cbe3f4b52 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,6 +1,6 @@ # --- Sebulba config --- arch_name: "sebulba" -num_envs: 16 # number of envs per thread +num_envs: 32 # number of envs per thread # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select @@ -14,11 +14,4 @@ absolute_metric: True # Whether the absolute metric should be computed. 
For more # --- Sebulba devices config --- n_threads_per_executor: 1 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices -learner_device_ids: [0] # ids of learner devices - -# --- Sebulba rollout and env config --- -concurrency: False # whether actor and learner should run concurrently -async_envs: True # "whether to use async vector or sync vector envs" - -# --- To be defined during training --- -log_frequency: ~ +learner_device_ids: [0] # ids of learner devices \ No newline at end of file diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym.yaml index ad8d16b9a..1e197a45e 100644 --- a/mava/configs/env/gym.yaml +++ b/mava/configs/env/gym.yaml @@ -15,6 +15,7 @@ implicit_agent_id: False # environments have a winrate metric. log_win_rate: False +# Weather or not to average the returned rewards over all of the agents. use_individual_rewards: True kwargs: diff --git a/mava/systems/sebulba/ppo/test.py b/mava/systems/sebulba/ppo/test.py deleted file mode 100644 index b868f69b6..000000000 --- a/mava/systems/sebulba/ppo/test.py +++ /dev/null @@ -1,50 +0,0 @@ - -import copy -import time -from typing import Any, Dict, Tuple, List -import threading -import chex -import flax -import hydra -import jax -import jax.numpy as jnp -import numpy as np -import optax -import queue -from collections import deque -from colorama import Fore, Style -from flax.core.frozen_dict import FrozenDict -from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState -from rich.pretty import pprint - -from mava.evaluator import make_eval_fns -from mava.networks import FeedForwardActor as Actor -from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation -from mava.utils import make_env as environments -from mava.utils.checkpointing import Checkpointer -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) -from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps -from mava.utils.training import make_learning_rate -from mava.wrappers.episode_metrics import get_final_step_metrics - - -@hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") -def hydra_entry_point(cfg: DictConfig) -> float: - """Experiment entry point.""" - # Allow dynamic attributes. - OmegaConf.set_struct(cfg, False) - - env = environments.make_gym_env(cfg) - a = env.reset() - print(a) - -if __name__ == "__main__": - hydra_entry_point() \ No newline at end of file diff --git a/mava/systems/sebulba/ppo/types.py b/mava/systems/sebulba/ppo/types.py deleted file mode 100644 index 6e02aa904..000000000 --- a/mava/systems/sebulba/ppo/types.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict - -import chex -from flax.core.frozen_dict import FrozenDict -from jumanji.types import TimeStep -from optax._src.base import OptState -from typing_extensions import NamedTuple - -from mava.types import Action, Done, HiddenState, State, Value - - -class Params(NamedTuple): - """Parameters of an actor critic network.""" - - actor_params: FrozenDict - critic_params: FrozenDict - - -class OptStates(NamedTuple): - """OptStates of actor critic learner.""" - - actor_opt_state: OptState - critic_opt_state: OptState - - -class HiddenStates(NamedTuple): - """Hidden states for an actor critic learner.""" - - policy_hidden_state: HiddenState - critic_hidden_state: HiddenState - - -class LearnerState(NamedTuple): - """State of the learner.""" - - params: Params - opt_states: OptStates - key: chex.PRNGKey - env_state: State - timestep: TimeStep - - -class RNNLearnerState(NamedTuple): - """State of the `Learner` for recurrent architectures.""" - - params: Params - opt_states: OptStates - key: chex.PRNGKey - env_state: State - timestep: TimeStep - dones: Done - hstates: HiddenStates - - -class PPOTransition(NamedTuple): - """Transition tuple for PPO.""" - - done: Done - action: Action - value: Value - reward: chex.Array - log_prob: chex.Array - obs: chex.Array - info: Dict - - -class RNNPPOTransition(NamedTuple): - """Transition tuple for PPO.""" - - done: Done - action: Action - value: Value - reward: chex.Array - log_prob: chex.Array - obs: chex.Array - hstates: HiddenStates - - -class Observation(NamedTuple): - """The observation that the agent sees. - agents_view: the agent's view of the environment. - action_mask: boolean array specifying, for each agent, which action is legal. - """ - - agents_view: chex.Array # (num_agents, num_obs_features) - action_mask: chex.Array # (num_agents, num_actions) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 69fc54623..a54cafff8 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -22,6 +22,7 @@ import jumanji import matrax from gigastep import ScenarioBuilder +import lbforaging from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment from jumanji.environments.routing.cleaner.generator import ( @@ -46,6 +47,8 @@ GigastepWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, + GymAgentIDWrapper, + _multiagent_worker_shared_memory, LbfWrapper, MabraxWrapper, MatraxWrapper, @@ -69,7 +72,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"rware": GymRwareWrapper} +_gym_registry = {"RobotWarehouse": GymRwareWrapper} def add_extra_wrappers( @@ -208,38 +211,38 @@ def make_gigastep_env( def make_gym_env( - config: DictConfig, add_global_state: bool = False, eval_env: bool = False + config: DictConfig, num_env : int, add_global_state: bool = False, ) -> Environment: # todo : create the appropriate annotation for the sync vector """ Create a Gym environment. Args: - env_name (str): The name of the environment to create. config (Dict): The configuration of the environment. + num_env (int) : The number of parallel envs to create. add_global_state (bool): Whether to add the global state to the observation. Default False. Returns: - A tuple of the environments. + Async environments. 
""" - base_env_name = config.env.scenario.split(":")[0] + base_env_name = config.env.env_name wrapper = _gym_registry[base_env_name] def create_gym_env( - config: DictConfig, add_global_state: bool = False, eval_env: bool = False + config: DictConfig, add_global_state: bool = False ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. env = gym.make(config.env.scenario) - wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state, eval_env) + wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state) if not config.env.implicit_agent_id: - pass # todo : add agent id wrapper for gym . - env = GymRecordEpisodeMetrics(env) + wrapped_env = GymAgentIDWrapper(wrapped_env) # todo : add agent id wrapper for gym . + wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env - num_env = config.arch.num_envs envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names [ - lambda: create_gym_env(config, add_global_state, eval_env=eval_env) + lambda: create_gym_env(config, add_global_state) for _ in range(num_env) - ] + ], + worker=_multiagent_worker_shared_memory ) return envs @@ -267,4 +270,4 @@ def make(config: DictConfig, add_global_state: bool = False) -> Tuple[Environmen elif env_name in _gigastep_registry: return make_gigastep_env(env_name, config, add_global_state) else: - raise ValueError(f"{env_name} is not a supported environment.") + raise ValueError(f"{env_name} is not a supported environment.") \ No newline at end of file diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index e888d9317..151a1c509 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,7 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper +from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, @@ -24,4 +24,4 @@ RwareWrapper, ) from mava.wrappers.matrax import MatraxWrapper -from mava.wrappers.observation import AgentIDWrapper +from mava.wrappers.observation import AgentIDWrapper \ No newline at end of file diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 69632f1bc..041916680 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -13,17 +13,21 @@ # limitations under the License. import warnings -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional import gym import numpy as np from numpy.typing import NDArray +from gym import spaces +from gym.vector.utils import write_to_shared_memory +import sys + # Filter out the warnings warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") -class GymRwareWrapper(gym.Wrapper): +class GymRwareWrapper(gym.Wrapper): """Wrapper for rware gym environments""" def __init__( @@ -31,7 +35,6 @@ def __init__( env: gym.Env, use_individual_rewards: bool = False, add_global_state: bool = False, - eval_env: bool = False, ): """Initialize the gym wrapper @@ -40,109 +43,205 @@ def __init__( use_individual_rewards (bool, optional): Use individual or group rewards. Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. - eval_env (bool, optional): Weather the instance is used for training or evaluation. 
- Defaults to False. """ super().__init__(env) - self._env = gym.wrappers.compatibility.EnvCompatibility(env) + self._env = env self.use_individual_rewards = use_individual_rewards - self.add_global_state = add_global_state # todo : add the global observations - self.eval_env = eval_env + self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[ 0 - ].n # todo: all the agents must have the same num_actions, add assertion? - - def reset(self) -> Tuple: - (agents_view, info), _ = self._env.reset( - seed=np.random.randint(1) - ) # todo: assure reproducibility, this only works for rware - - info = {"actions_mask": self._get_actions_mask(info)} - + ].n + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple: + + if seed is not None: + self.env.seed(seed) + + agents_view, info = self._env.reset() + + info = {"actions_mask": self.get_actions_mask(info)} + if self.add_global_state: + info["global_obs"] = self.get_global_obs(agents_view) + return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple: + def step(self, actions: NDArray) -> Tuple: - agents_view, reward, terminated, truncated, info = self.env.step(actions) + agents_view, reward, terminated, truncated, info = self._env.step(actions) - done = np.logical_or(terminated, truncated).all() - - if ( - done and not self.eval_env - ): # only auto-reset in training envs, same functionality as the AutoResetWrapper. - agents_view, info = self.reset() - reward = np.zeros(self.num_agents) - terminated, truncated = np.zeros(self.num_agents, dtype=bool), np.zeros( - self.num_agents, dtype=bool - ) - return agents_view, reward, terminated, truncated, info - - info = {"actions_mask": self._get_actions_mask(info)} + info = {"actions_mask": self.get_actions_mask(info)} + if self.add_global_state: + info["global_obs"] = self.get_global_obs(agents_view) if self.use_individual_rewards: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) - + return agents_view, reward, terminated, truncated, info - def _get_actions_mask(self, info: Dict) -> NDArray: + def get_actions_mask(self, info: Dict) -> NDArray: if "action_mask" in info: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + def get_global_obs(self, obs: NDArray): + global_obs = np.concatenate(obs, axis=0) + return np.tile(global_obs, (self.num_agents, 1)) + class GymRecordEpisodeMetrics(gym.Wrapper): """Record the episode returns and lengths.""" def __init__(self, env: gym.Env): super().__init__(env) + self._env = env self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0 + self.running_count_episode_length = 0.0 def reset(self) -> Tuple: # Reset the env - agents_view, info = self.env.reset() - - # Reset the metrics - self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0 + agents_view, info = self._env.reset() # Create the metrics dict metrics = { "episode_return": self.running_count_episode_return, - "episode_length": self.self.running_count_episode_length, - "is_terminal_step": False, + "episode_length": self.running_count_episode_length, + "is_terminal_step": True, } + + # Reset the metrics + self.running_count_episode_return = 0.0 + self.running_count_episode_length = 0 + if "won_episode" in info: metrics["won_episode"] = info["won_episode"] + + info["metrics"] = metrics - return agents_view, metrics + return agents_view, info 
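# Illustrative sketch (not part of the patch): one way a consumer could use the
# "metrics" dict this wrapper attaches to `info`. Only entries flagged with
# `is_terminal_step` hold the return/length of a finished episode, since
# reset() builds the dict from the running counts before zeroing them. The
# helper name below is made up for the example.
import numpy as np

def final_episode_metrics(step_infos: list) -> dict:
    finals = [i["metrics"] for i in step_infos if i["metrics"]["is_terminal_step"]]
    if not finals:
        return {}
    return {
        "episode_return": float(np.mean([m["episode_return"] for m in finals])),
        "episode_length": float(np.mean([m["episode_length"] for m in finals])),
    }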
def step(self, actions: NDArray) -> Tuple: # Step the env - agents_view, reward, terminated, truncated, info = self.env.step(actions) - - # Update the metrics - done = np.logical_or(terminated, truncated).all() + agents_view, reward, terminated, truncated, info = self._env.step(actions) - if not done: - self.running_count_episode_return += float(np.mean(reward)) - self.running_count_episode_length += 1 - - else: - self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0 + self.running_count_episode_return += float(np.mean(reward)) + self.running_count_episode_length += 1 metrics = { "episode_return": self.running_count_episode_return, - "episode_length": self.self.running_count_episode_length, - "is_terminal_step": False, + "episode_length": self.running_count_episode_length, + "is_terminal_step": False, # We handle the True case in the reset function since this gets overwritten } if "won_episode" in info: metrics["won_episode"] = info["won_episode"] + + info["metrics"] = metrics + + return agents_view, reward, terminated, truncated, info + +class GymAgentIDWrapper(gym.Wrapper): + """Add onehot agent IDs to observation.""" + + def __init__(self, env: gym.Env): + super().__init__(env) - return agents_view, reward, terminated, truncated, metrics + self.agent_ids = np.eye(self.env.num_agents) + observation_space = self.env.observation_space[0] + _obs_low, _obs_high, _obs_dtype, _obs_shape = ( + observation_space.low[0], + observation_space.high[0], + observation_space.dtype, + observation_space.shape, + ) + _new_obs_shape = (_obs_shape[0] + self.env.num_agents,) + _observation_boxs = [spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype)] * self.env.num_agents + self.observation_space = spaces.Tuple(_observation_boxs) + + def reset(self) -> Tuple[np.ndarray, Dict]: + """Reset the environment.""" + obs, info = self.env.reset() + obs = np.concatenate([self.agent_ids, obs], axis=1) + return obs, info + + def step(self, action: list) -> Tuple[np.ndarray, float, bool, bool, Dict]: + """Step the environment.""" + obs, reward, terminated, truncated, info = self.env.step(action) + obs = np.concatenate([self.agent_ids, obs], axis=1) + return obs, reward, terminated, truncated, info + + +def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): + assert shared_memory is not None + env = env_fn() + observation_space = env.observation_space + parent_pipe.close() + try: + while True: + command, data = pipe.recv() + if command == "reset": + observation, info = env.reset(**data) + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + pipe.send(((None, info), True)) + + elif command == "step": + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + if np.logical_or(terminated, truncated).all(): + old_observation, old_info = observation, info + observation, info = env.reset() + info["final_observation"] = old_observation + info["final_info"] = old_info + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + pipe.send(((None, reward, terminated, truncated, info), True)) + elif command == "seed": + env.seed(data) + pipe.send((None, True)) + elif command == "close": + pipe.send((None, True)) + break + elif command == "_call": + name, args, kwargs = data + if name in ["reset", "step", "seed", "close"]: + raise ValueError( + f"Trying to call function `{name}` with " + f"`_call`. Use `{name}` directly instead." 
+ ) + function = getattr(env, name) + if callable(function): + pipe.send((function(*args, **kwargs), True)) + else: + pipe.send((function, True)) + elif command == "_setattr": + name, value = data + setattr(env, name, value) + pipe.send((None, True)) + elif command == "_check_spaces": + pipe.send( + ((data[0] == observation_space, data[1] == env.action_space), True) + ) + else: + raise RuntimeError( + f"Received unknown command `{command}`. Must " + "be one of {`reset`, `step`, `seed`, `close`, `_call`, " + "`_setattr`, `_check_spaces`}." + ) + except (KeyboardInterrupt, Exception): + error_queue.put((index,) + sys.exc_info()[:2]) + pipe.send((None, False)) + finally: + env.close() \ No newline at end of file From e40c5d4e2fd2ea60104f5b48201856478f8df374 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 10 Jul 2024 10:01:34 +0100 Subject: [PATCH 031/125] chore : pre-commits and some comments --- mava/configs/arch/sebulba.yaml | 2 +- mava/utils/make_env.py | 18 ++++--- mava/wrappers/__init__.py | 9 +++- mava/wrappers/gym.py | 88 +++++++++++++++++----------------- 4 files changed, 61 insertions(+), 56 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index cbe3f4b52..b6a0a9699 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -14,4 +14,4 @@ absolute_metric: True # Whether the absolute metric should be computed. For more # --- Sebulba devices config --- n_threads_per_executor: 1 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices -learner_device_ids: [0] # ids of learner devices \ No newline at end of file +learner_device_ids: [0] # ids of learner devices diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index a54cafff8..5ee4e697c 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -22,7 +22,6 @@ import jumanji import matrax from gigastep import ScenarioBuilder -import lbforaging from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment from jumanji.environments.routing.cleaner.generator import ( @@ -45,16 +44,16 @@ CleanerWrapper, ConnectorWrapper, GigastepWrapper, + GymAgentIDWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, - GymAgentIDWrapper, - _multiagent_worker_shared_memory, LbfWrapper, MabraxWrapper, MatraxWrapper, RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, + _multiagent_worker_shared_memory, ) # Registry mapping environment names to their generator and wrapper classes. @@ -211,7 +210,9 @@ def make_gigastep_env( def make_gym_env( - config: DictConfig, num_env : int, add_global_state: bool = False, + config: DictConfig, + num_env: int, + add_global_state: bool = False, ) -> Environment: # todo : create the appropriate annotation for the sync vector """ Create a Gym environment. 
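# Illustrative usage sketch (not part of the patch), assuming a Hydra config
# composed with `arch: sebulba` and `env: gym`. With the new `num_env` argument
# the caller decides how many parallel copies each AsyncVectorEnv holds, so
# rollout and evaluation environments can be sized independently.
from omegaconf import DictConfig

from mava.utils.make_env import make_gym_env

def build_envs(cfg: DictConfig):
    train_envs = make_gym_env(cfg, num_env=cfg.arch.num_envs)
    eval_envs = make_gym_env(cfg, num_env=cfg.arch.num_eval_episodes)
    return train_envs, eval_envs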
@@ -238,11 +239,8 @@ def create_gym_env( return wrapped_env envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names - [ - lambda: create_gym_env(config, add_global_state) - for _ in range(num_env) - ], - worker=_multiagent_worker_shared_memory + [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], + worker=_multiagent_worker_shared_memory, ) return envs @@ -270,4 +268,4 @@ def make(config: DictConfig, add_global_state: bool = False) -> Tuple[Environmen elif env_name in _gigastep_registry: return make_gigastep_env(env_name, config, add_global_state) else: - raise ValueError(f"{env_name} is not a supported environment.") \ No newline at end of file + raise ValueError(f"{env_name} is not a supported environment.") diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 151a1c509..ee8fdf186 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,12 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory +from mava.wrappers.gym import ( + GymAgentIDWrapper, + GymRecordEpisodeMetrics, + GymRwareWrapper, + _multiagent_worker_shared_memory, +) from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, @@ -24,4 +29,4 @@ RwareWrapper, ) from mava.wrappers.matrax import MatraxWrapper -from mava.wrappers.observation import AgentIDWrapper \ No newline at end of file +from mava.wrappers.observation import AgentIDWrapper diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 041916680..978ad4033 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -12,23 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import warnings -from typing import Dict, Tuple, Optional +from typing import Any, Callable, Dict, Optional, Tuple import gym import numpy as np -from numpy.typing import NDArray - from gym import spaces from gym.vector.utils import write_to_shared_memory -import sys +from numpy.typing import NDArray # Filter out the warnings warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") -class GymRwareWrapper(gym.Wrapper): - """Wrapper for rware gym environments""" +class GymRwareWrapper(gym.Wrapper): + """Wrapper for rware gym environments.""" def __init__( self, @@ -45,30 +44,26 @@ def __init__( add_global_state (bool, optional) : Create global observations. Defaults to False. 
""" super().__init__(env) - self._env = env + self._env = env self.use_individual_rewards = use_individual_rewards - self.add_global_state = add_global_state + self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) - self.num_actions = self._env.action_space[ - 0 - ].n - - def reset( - self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple: - + self.num_actions = self._env.action_space[0].n + + def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: + if seed is not None: self.env.seed(seed) - - agents_view, info = self._env.reset() + + agents_view, info = self._env.reset() info = {"actions_mask": self.get_actions_mask(info)} if self.add_global_state: info["global_obs"] = self.get_global_obs(agents_view) - + return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple: + def step(self, actions: NDArray) -> Tuple: agents_view, reward, terminated, truncated, info = self._env.step(actions) @@ -80,7 +75,7 @@ def step(self, actions: NDArray) -> Tuple: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) - + return agents_view, reward, terminated, truncated, info def get_actions_mask(self, info: Dict) -> NDArray: @@ -88,7 +83,7 @@ def get_actions_mask(self, info: Dict) -> NDArray: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - def get_global_obs(self, obs: NDArray): + def get_global_obs(self, obs: NDArray) -> NDArray: global_obs = np.concatenate(obs, axis=0) return np.tile(global_obs, (self.num_agents, 1)) @@ -113,14 +108,14 @@ def reset(self) -> Tuple: "episode_length": self.running_count_episode_length, "is_terminal_step": True, } - + # Reset the metrics self.running_count_episode_return = 0.0 self.running_count_episode_length = 0 - + if "won_episode" in info: metrics["won_episode"] = info["won_episode"] - + info["metrics"] = metrics return agents_view, info @@ -136,17 +131,18 @@ def step(self, actions: NDArray) -> Tuple: metrics = { "episode_return": self.running_count_episode_return, "episode_length": self.running_count_episode_length, - "is_terminal_step": False, # We handle the True case in the reset function since this gets overwritten + "is_terminal_step": False, } if "won_episode" in info: metrics["won_episode"] = info["won_episode"] - + info["metrics"] = metrics - + return agents_view, reward, terminated, truncated, info - + + class GymAgentIDWrapper(gym.Wrapper): - """Add onehot agent IDs to observation.""" + """Add one hot agent IDs to observation.""" def __init__(self, env: gym.Env): super().__init__(env) @@ -160,7 +156,9 @@ def __init__(self, env: gym.Env): observation_space.shape, ) _new_obs_shape = (_obs_shape[0] + self.env.num_agents,) - _observation_boxs = [spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype)] * self.env.num_agents + _observation_boxs = [ + spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype) + ] * self.env.num_agents self.observation_space = spaces.Tuple(_observation_boxs) def reset(self) -> Tuple[np.ndarray, Dict]: @@ -174,9 +172,18 @@ def step(self, action: list) -> Tuple[np.ndarray, float, bool, bool, Dict]: obs, reward, terminated, truncated, info = self.env.step(action) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, reward, terminated, truncated, info - -def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): + +# Copied form 
https://github.com/openai/gym/blob/master/gym/vector/async_vector_env.py +# Modified to work with multiple agents +def _multiagent_worker_shared_memory( # noqa: CCR001 + index: int, + env_fn: Callable[[], Any], + pipe: Any, + parent_pipe: Any, + shared_memory: Any, + error_queue: Any, +) -> None: assert shared_memory is not None env = env_fn() observation_space = env.observation_space @@ -186,9 +193,7 @@ def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_me command, data = pipe.recv() if command == "reset": observation, info = env.reset(**data) - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) + write_to_shared_memory(observation_space, index, observation, shared_memory) pipe.send(((None, info), True)) elif command == "step": @@ -199,14 +204,13 @@ def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_me truncated, info, ) = env.step(data) + # Handel the dones across all of envs and agents if np.logical_or(terminated, truncated).all(): old_observation, old_info = observation, info observation, info = env.reset() info["final_observation"] = old_observation info["final_info"] = old_info - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) + write_to_shared_memory(observation_space, index, observation, shared_memory) pipe.send(((None, reward, terminated, truncated, info), True)) elif command == "seed": env.seed(data) @@ -231,9 +235,7 @@ def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_me setattr(env, name, value) pipe.send((None, True)) elif command == "_check_spaces": - pipe.send( - ((data[0] == observation_space, data[1] == env.action_space), True) - ) + pipe.send(((data[0] == observation_space, data[1] == env.action_space), True)) else: raise RuntimeError( f"Received unknown command `{command}`. 
Must " @@ -244,4 +246,4 @@ def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_me error_queue.put((index,) + sys.exc_info()[:2]) pipe.send((None, False)) finally: - env.close() \ No newline at end of file + env.close() From 4b17c1539e187ec64b373a6723fb4feb1a226187 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 10 Jul 2024 10:09:30 +0100 Subject: [PATCH 032/125] chore: removed unused config file --- mava/configs/default_ff_ippo_seb.yaml | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 mava/configs/default_ff_ippo_seb.yaml diff --git a/mava/configs/default_ff_ippo_seb.yaml b/mava/configs/default_ff_ippo_seb.yaml deleted file mode 100644 index 1002d90c4..000000000 --- a/mava/configs/default_ff_ippo_seb.yaml +++ /dev/null @@ -1,7 +0,0 @@ -defaults: - - logger: ff_ippo - - arch: sebulba - - system: ppo/ff_ippo - - network: mlp - - env: gym - - _self_ From 9ec6b16db7ced8fe4953961c73ed29322db99760 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 10 Jul 2024 10:58:55 +0100 Subject: [PATCH 033/125] feat: sebulba ff_ippo --- mava/configs/default_ff_mappo_seb.yaml | 7 - mava/configs/default_rec_ippo_seb.yaml | 7 - mava/systems/sebulba/ppo/ff_mappo.py | 768 ---------------------- mava/systems/sebulba/ppo/orig.py | 795 ----------------------- mava/systems/sebulba/ppo/rec_ippo.py | 850 ------------------------- mava/systems/sebulba/ppo/test.py | 86 --- mava/wrappers/gym.py | 91 ++- 7 files changed, 44 insertions(+), 2560 deletions(-) delete mode 100644 mava/configs/default_ff_mappo_seb.yaml delete mode 100644 mava/configs/default_rec_ippo_seb.yaml delete mode 100644 mava/systems/sebulba/ppo/ff_mappo.py delete mode 100644 mava/systems/sebulba/ppo/orig.py delete mode 100644 mava/systems/sebulba/ppo/rec_ippo.py delete mode 100644 mava/systems/sebulba/ppo/test.py diff --git a/mava/configs/default_ff_mappo_seb.yaml b/mava/configs/default_ff_mappo_seb.yaml deleted file mode 100644 index 8d96d3e97..000000000 --- a/mava/configs/default_ff_mappo_seb.yaml +++ /dev/null @@ -1,7 +0,0 @@ -defaults: - - logger: ff_mappo - - arch: sebulba - - system: ppo/ff_mappo - - network: mlp - - env: gym - - _self_ diff --git a/mava/configs/default_rec_ippo_seb.yaml b/mava/configs/default_rec_ippo_seb.yaml deleted file mode 100644 index 61eaa95f1..000000000 --- a/mava/configs/default_rec_ippo_seb.yaml +++ /dev/null @@ -1,7 +0,0 @@ -defaults: - - logger: rec_ippo - - arch: sebulba - - system: ppo/rec_ippo - - network: rnn - - env: gym - - _self_ diff --git a/mava/systems/sebulba/ppo/ff_mappo.py b/mava/systems/sebulba/ppo/ff_mappo.py deleted file mode 100644 index 66d4174bf..000000000 --- a/mava/systems/sebulba/ppo/ff_mappo.py +++ /dev/null @@ -1,768 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
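# Illustrative sketch (not part of the patch) of the actor/learner handoff that
# the Sebulba PPO code in this PR is built around: each actor thread blocks on a
# params queue for fresh parameters, collects a rollout, and pushes it onto a
# rollout queue that the learner drains once per update. All names are generic.
import queue
import threading

def actor_loop(params_q: queue.Queue, rollout_q: queue.Queue, num_updates: int) -> None:
    for _ in range(num_updates):
        params = params_q.get()               # wait for the learner to publish params
        rollout = {"params_version": params}  # placeholder for a collected trajectory batch
        rollout_q.put(rollout)                # hand the batch to the learner

params_q: queue.Queue = queue.Queue(maxsize=1)
rollout_q: queue.Queue = queue.Queue(maxsize=1)
threading.Thread(target=actor_loop, args=(params_q, rollout_q, 2), daemon=True).start()
for update in range(2):
    params_q.put(update)                      # learner publishes its latest params
    batch = rollout_q.get()                   # learner blocks until the rollout arrives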
- -import copy -import time -from typing import Any, Dict, Tuple, List -import threading -import chex -import flax -import hydra -import jax -import jax.debug -import jax.numpy as jnp -import numpy as np -import optax -import queue -from collections import deque -from colorama import Fore, Style -from flax.core.frozen_dict import FrozenDict -from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState -from rich.pretty import pprint - -from mava.evaluator import make_sebulba_eval_fns as make_eval_fns -from mava.networks import FeedForwardActor as Actor -from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this Observation to use the standard obs -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, ObservationGlobalState -from mava.utils import make_env as environments -from mava.utils.checkpointing import Checkpointer -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) -from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import sebulba_check_total_timesteps -from mava.utils.training import make_learning_rate -from mava.wrappers.episode_metrics import get_final_step_metrics - - -def rollout( - key: chex.PRNGKey, - config: DictConfig, - rollout_queue: queue.Queue, - params_queue: queue.Queue, - apply_fns: Tuple, - learner_devices: List, - actor_device_id : int): - - #setup - env = environments.make_gym_env(config, config.arch.num_envs, add_global_state=True) - current_actor_device = jax.devices()[actor_device_id] - actor_apply_fn, critic_apply_fn = apply_fns - - # Define the util functions: select action function and prepare data to share it with learner. - @jax.jit - def get_action_and_value( - params: FrozenDict, - observation: ObservationGlobalState, - key: chex.PRNGKey, - ) -> Tuple: - """Get action and value.""" - key, subkey = jax.random.split(key) - - actor_policy = actor_apply_fn(params.actor_params, observation) - action = actor_policy.sample(seed=subkey) - log_prob = actor_policy.log_prob(action) - - value = critic_apply_fn(params.critic_params, observation).squeeze() - return action, log_prob, value, key - - # Define queues to track time - params_queue_get_time: deque = deque(maxlen=1) - rollout_time: deque = deque(maxlen=1) - rollout_queue_put_time: deque = deque(maxlen=1) - - next_obs , info = env.reset() - next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) - - move_to_device = lambda x : jax.device_put(x, device = current_actor_device) - - # Loop till the learner has finished training - for update in range(config.system.num_updates): - inference_time: float = 0 - storage_time: float = 0 - env_send_time: float = 0 - - # Get the latest parameters from the learner - params_queue_get_time_start = time.time() - params = params_queue.get() - params_queue_get_time.append(time.time() - params_queue_get_time_start) - - # Rollout - rollout_time_start = time.time() - storage: List = [] - - # Loop over the rollout length - for _ in range(0, config.system.rollout_length): - - # Cached for transition - cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) # (num_envs, num_agents, ...) 
- cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) - cashed_action_mask = move_to_device(np.stack(info["actions_mask"])) # (num_envs, num_agents, num_actions) - cached_next_global_obs = move_to_device(np.stack(info["global_obs"])) - - - # Get action and value - full_observation = ObservationGlobalState(cached_next_obs, cashed_action_mask, cached_next_global_obs) - inference_time_start = time.time() - ( - action, - log_prob, - value, - key, - ) = get_action_and_value(params, full_observation , key) - - - # Step the environment - inference_time += time.time() - inference_time_start - env_send_time_start = time.time() - cpu_action = jax.device_get(action) - next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) # (num_env, num_agents) --> (num_agents, num_env) - env_send_time += time.time() - env_send_time_start - - # Prepare the data - storage_time_start = time.time() - next_dones = np.logical_or(terminated, truncated) - metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics - - # Append data to storage - storage.append( - PPOTransition( - done=cached_next_dones, - action=action, - value=value, - reward=next_reward, - log_prob=log_prob, - obs=full_observation, - info=metrics, - ) - ) - storage_time += time.time() - storage_time_start - rollout_time.append(time.time() - rollout_time_start) - - parse_timer = time.time() - - # Prepare data to share with learner - #[PPOTransition() * rollout_len] --> PPOTransition[done = (rollout_len, num_envs, num_agents), action = (rollout_len, num_envs, num_agents, num_actions), ...] - stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) - - - # Split the arrays over the different learner_devices on the num_envs axis - shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) - - sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) - - # (num_learner_devices, num_envs, num_agents, ...) 
- sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) - sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) - sharded_next_global_obs = shard_split_payload(np.stack(info["global_obs"]), 0) - sharded_next_done = shard_split_payload(next_dones, 0) - - # Pack the obs and action mask - payload_obs = ObservationGlobalState(sharded_next_obs, sharded_next_action_mask, sharded_next_global_obs) - - # For debugging - speed_info = { - "rollout_time": np.mean(rollout_time), - "params_queue_get_time": np.mean(params_queue_get_time), - "action_inference": inference_time, - "storage_time": storage_time, - "env_step_time": env_send_time, - "rollout_queue_put_time": np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0, - "parse_time" : time.time() - parse_timer, - } - #print(speed_info) - - payload = ( - sharded_storage, - payload_obs, - sharded_next_done, - ) - - # Put data in the rollout queue to share it with the learner - rollout_queue_put_time_start = time.time() - rollout_queue.put(payload) - rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) - - -def get_learner_fn( - apply_fns: Tuple[ActorApply, CriticApply], - update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], - config: DictConfig, -) -> LearnerFn[LearnerState]: - """Get the learner function.""" - - # Get apply and update functions for actor and critic networks. - actor_apply_fn, critic_apply_fn = apply_fns - actor_update_fn, critic_update_fn = update_fns - - def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: ObservationGlobalState, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: - """A single update of the network. - - This function steps the environment and records the trajectory batch for - training. It then calculates advantages and targets based on the recorded - trajectory and updates the actor and critic networks based on the calculated - losses. - - Args: - learner_state (NamedTuple): - - params (Params): The current model parameters. - - opt_states (OptStates): The current optimizer states. - - key (PRNGKey): The random number generator state. - - env_state (State): The environment state. - - last_timestep (TimeStep): The last timestep in the current trajectory. - _ (Any): The current metrics info. 
- """ - - def _calculate_gae( #todo: lake sure this is appropriate - traj_batch: PPOTransition, last_val: chex.Array, last_done: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: PPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry - done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - # CALCULATE ADVANTAGE - params, opt_states, key, _, _ = learner_state - last_val = critic_apply_fn(params.critic_params, last_obs) - advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) - - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: - """Update the network for a single epoch.""" - - def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: - """Update the network for a single minibatch.""" - - # UNPACK TRAIN STATE AND BATCH INFO - params, opt_states, key = train_state - traj_batch, advantages, targets = batch_info - - def _actor_loss_fn( - actor_params: FrozenDict, - actor_opt_state: OptState, - traj_batch: PPOTransition, - gae: chex.Array, - key: chex.PRNGKey, - ) -> Tuple: - """Calculate the actor loss.""" - # RERUN NETWORK - actor_policy = actor_apply_fn(actor_params, traj_batch.obs) - log_prob = actor_policy.log_prob(traj_batch.action) - - # CALCULATE ACTOR LOSS - ratio = jnp.exp(log_prob - traj_batch.log_prob) - gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( - jnp.clip( - ratio, - 1.0 - config.system.clip_eps, - 1.0 + config.system.clip_eps, - ) - * gae - ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() - # The seed will be used in the TanhTransformedDistribution: - entropy = actor_policy.entropy(seed=key).mean() - - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) - - def _critic_loss_fn( - critic_params: FrozenDict, - critic_opt_state: OptState, - traj_batch: PPOTransition, - targets: chex.Array, - ) -> Tuple: - """Calculate the critic loss.""" - # RERUN NETWORK - value = critic_apply_fn(critic_params, traj_batch.obs) - - # CALCULATE VALUE LOSS - value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( - -config.system.clip_eps, config.system.clip_eps - ) - value_losses = jnp.square(value - targets) - value_losses_clipped = jnp.square(value_pred_clipped - targets) - value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) - - # CALCULATE ACTOR LOSS - key, entropy_key = jax.random.split(key) - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn( - params.actor_params, - opt_states.actor_opt_state, - traj_batch, - advantages, - entropy_key, - ) - - # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, 
targets - ) - - # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x - # pmean over devices. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="device" #todo: pmean over learner devices not all - ) - - # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" - ) - - # UPDATE ACTOR PARAMS AND OPTIMISER STATE - actor_updates, actor_new_opt_state = actor_update_fn( - actor_grads, opt_states.actor_opt_state - ) - actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - - # UPDATE CRITIC PARAMS AND OPTIMISER STATE - critic_updates, critic_new_opt_state = critic_update_fn( - critic_grads, opt_states.critic_opt_state - ) - critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - - # PACK NEW PARAMS AND OPTIMISER STATE - new_params = Params(actor_new_params, critic_new_params) - new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] - loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, - } - return (new_params, new_opt_state, entropy_key), loss_info - - params, opt_states, traj_batch, advantages, targets, key = update_state - key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES - batch_size = config.system.rollout_length * (config.arch.num_envs // len(config.arch.learner_device_ids)) * len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor - permutation = jax.random.permutation(shuffle_key, batch_size) - batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=0), batch - ) - minibatches = jax.tree_util.tree_map( - lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), - shuffled_batch, - ) - # UPDATE MINIBATCHES - (params, opt_states, entropy_key), loss_info = jax.lax.scan( - _update_minibatch, (params, opt_states, entropy_key), minibatches - ) - - update_state = (params, opt_states, traj_batch, advantages, targets, key) - return update_state, loss_info - - update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS - update_state, loss_info = jax.lax.scan( - _update_epoch, update_state, None, config.system.ppo_epochs - ) - - params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, None, None) - metric = traj_batch.info - return learner_state, (metric, loss_info) - - def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> ExperimentOutput[LearnerState]: - """Learner function. - - This function represents the learner, it updates the network parameters - by iteratively applying the `_update_step` function for a fixed number of - updates. The `_update_step` function is vectorized over a batch of inputs. - - Args: - learner_state (NamedTuple): - - params (Params): The initial model parameters. - - opt_states (OptStates): The initial optimizer state. 
- - key (chex.PRNGKey): The random number generator state. - - env_state (LogEnvState): The environment state. - - timesteps (TimeStep): The initial timestep in the initial trajectory. - """ - - - learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_dones) - - return ExperimentOutput( - learner_state=learner_state, - episode_metrics=episode_info, - train_metrics=loss_info, - ) - - return learner_fn - - -def learner_setup( - keys: chex.Array, config: DictConfig, learner_devices: List -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: - """Initialise learner_fn, network, optimiser, environment and states.""" - - #create temporory envoirnments. - env = environments.make_gym_env(config, 1, add_global_state=True) - # Get number of agents and actions. - action_space = env.single_action_space - config.system.num_agents = len(action_space) - config.system.num_actions = action_space[0].n - - # PRNG keys. - key, actor_net_key, critic_net_key = keys - - # Define network and optimiser. - actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_action_head = hydra.utils.instantiate( - config.network.action_head, action_dim=config.system.num_actions - ) - critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) - - actor_network = Actor(torso=actor_torso, action_head=actor_action_head) - critic_network = Critic(torso=critic_torso, centralised_critic= True) - - actor_lr = make_learning_rate(config.system.actor_lr, config) - critic_lr = make_learning_rate(config.system.critic_lr, config) - - actor_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(actor_lr, eps=1e-5), - ) - critic_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(critic_lr, eps=1e-5), - ) - - # Initialise observation: Select only obs for a single agent. - obs, info = env.reset() - init_obs = jnp.stack(obs, axis = 1) # (num_envs, num_agents, ...) - init_mask = np.stack(info["actions_mask"]) # (num_envs, num_agents, num_actions) - init_global_obs = np.stack(info["global_obs"]) - init_x = ObservationGlobalState(init_obs, init_mask, init_global_obs) - - # Initialise actor params and optimiser state. - actor_params = actor_network.init(actor_net_key, init_x) - actor_opt_state = actor_optim.init(actor_params) - - # Initialise critic params and optimiser state. - critic_params = critic_network.init(critic_net_key, init_x) - critic_opt_state = critic_optim.init(critic_params) - - # Pack params. - params = Params(actor_params, critic_params) - - # Pack apply and update functions. - apply_fns = (actor_network.apply, critic_network.apply) - update_fns = (actor_optim.update, critic_optim.update) - - # Get batched iterated update and replicate it to pmap it over learner cores. - learn = get_learner_fn(apply_fns, update_fns, config) - learn = jax.pmap(learn, axis_name="device", devices = learner_devices) - - # Load model from checkpoint if specified. - if config.logger.checkpointing.load_model: - loaded_checkpoint = Checkpointer( - model_name=config.logger.system_name, - **config.logger.checkpointing.load_args, # Other checkpoint args - ) - # Restore the learner state from the checkpoint - restored_params, _ = loaded_checkpoint.restore_params(input_params=params) - # Update the params - params = restored_params - - # Define params to be replicated across devices and batches. 
- key, step_keys = jax.random.split(key) - opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, step_keys) - - # Duplicate learner across Learner devices. - replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) - - # Initialise learner state. - params, opt_states, step_keys = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, None, None) - env.close() - - return learn, apply_fns, init_learner_state - - -def run_experiment(_config: DictConfig) -> float: - """Runs experiment.""" - config = copy.deepcopy(_config) - - devices = jax.devices() - learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] - - # PRNG keys. - key, key_e, actor_net_key, critic_net_key = jax.random.split( - jax.random.PRNGKey(config.system.seed), num=4 - ) - - # Sanity check of config - assert ( - config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "The number of environments must to be divisible by the number of learners " - - assert ( - int(config.arch.num_envs / len(config.arch.learner_device_ids)) - * config.arch.n_threads_per_executor - % config.system.num_minibatches - == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" - - - # Setup learner. - learn, apply_fns , learner_state = learner_setup( - (key ,actor_net_key, critic_net_key), config, learner_devices - ) - - # Setup evaluator. - # One key per device for evaluation. - evaluator, absolute_metric_evaluator = make_eval_fns(environments.make_gym_env, apply_fns[0], config, add_global_state=True) #todo: make this more generic - - # Calculate total timesteps. - config = sebulba_check_total_timesteps(config) - assert ( - config.system.num_updates > config.arch.num_evaluation - ), "Number of updates per evaluation must be less than total number of updates." - - # Calculate number of updates per evaluation. - config.system.num_updates_per_eval, remaining_updates = divmod(config.system.num_updates , config.arch.num_evaluation) - config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation step if the num_updates is not a multiple of num_evaluation - steps_per_rollout = ( - len(config.arch.executor_device_ids) - * config.arch.n_threads_per_executor - * config.system.rollout_length - * config.arch.num_envs - * config.system.num_updates_per_eval - ) - - # Logger setup - logger = MavaLogger(config) - cfg: Dict = OmegaConf.to_container(config, resolve=True) - cfg["arch"]["devices"] = jax.devices() - pprint(cfg) - - # Set up checkpointer - save_checkpoint = config.logger.checkpointing.save_model - if save_checkpoint: - checkpointer = Checkpointer( - metadata=config, # Save all config as metadata in the checkpoint - model_name=config.logger.system_name, - **config.logger.checkpointing.save_args, # Checkpoint args - ) - - # Executor setup and launch. 
- unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) - params_queues: List = [] - rollout_queues: List = [] - for d_idx, d_id in enumerate( # Loop through each executor device - config.arch.executor_device_ids - ): - # Replicate params per executor device - device_params = jax.device_put(unreplicated_params, devices[d_id]) - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - params_queues.append(queue.Queue(maxsize=1)) - rollout_queues.append(queue.Queue(maxsize=1)) - params_queues[-1].put(device_params) - threading.Thread( - target=rollout, - args=( - jax.device_put(key, devices[d_id]), - config, - rollout_queues[-1], - params_queues[-1], - apply_fns, - learner_devices, - d_id, - ), - ).start() #todo : Use a process instead of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) - - - # Run experiment for the total number of updates. - max_episode_return = jnp.float32(0.0) - best_params = None - for eval_step in range(config.arch.num_evaluation): - training_start_time = time.time() - learner_speeds = [] - rollout_times = [] - - episode_metrics = [] - train_metrics = [] - - # Make sure that the - num_updates_in_eval = config.system.num_updates_per_eval if eval_step != config.arch.num_evaluation - 1 else remaining_updates - for update in range(num_updates_in_eval): - sharded_storages = [] - sharded_next_obss = [] - sharded_next_dones = [] - - rollout_start_time = time.time() - # Loop through each executor device - for d_idx, _ in enumerate(config.arch.executor_device_ids): - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - # Get data from rollout queue - ( - sharded_storage, - sharded_next_obs, - sharded_next_done, - ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() - sharded_storages.append(sharded_storage) - sharded_next_obss.append(sharded_next_obs) - sharded_next_dones.append(sharded_next_done) - - rollout_times.append(time.time() - rollout_start_time) - - - # Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) - sharded_next_obss = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_obss) - sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) - - - learner_start_time = time.time() - learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones) - learner_speeds.append(time.time() - learner_start_time) - - # Stack the metrics - episode_metrics.append(learner_output.episode_metrics) - train_metrics.append(learner_output.train_metrics) - - # Send updated params to executors - unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) - for d_idx, d_id in enumerate(config.arch.executor_device_ids): - device_params = jax.device_put(unreplicated_params, devices[d_id]) - for thread_id in range(config.arch.n_threads_per_executor): - params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( - device_params - ) - - - - # Log the results of the training. 
- elapsed_time = time.time() - training_start_time - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics = jax.tree_map(lambda *x : np.asarray(x), *episode_metrics) - episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - - # Separately log timesteps, actoring metrics and training metrics. - speed_info = {"total_time" : elapsed_time, "rollout_time" : np.sum(rollout_times), "learner_time" : np.sum(learner_speeds), "timestep" : t} - logger.log(speed_info , t, eval_step, LogEvent.MISC) - if ep_completed: # only log episode metrics if an episode was completed in the rollout. - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - train_metrics = jax.tree_map(lambda *x : np.asarray(x), *train_metrics) - logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) - - # Evaluation on the learner - evaluation_start_timer = time.time() - key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = evaluator(unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1 ), eval_key) - - # Log the results of the evaluation. - elapsed_time = time.time() - evaluation_start_timer - episode_return = jnp.mean(episode_metrics["episode_return"]) - - steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) - episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) - - if save_checkpoint: - # Save checkpoint of learner state - checkpointer.save( - timestep=steps_per_rollout * (eval_step + 1), - unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state, 1), - episode_return=episode_return, - ) - - if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(learner_output.learner_state.params) - max_episode_return = episode_return - - # Update runner state to continue training. - learner_state = learner_output.learner_state - - # Record the performance for the final evaluation run. - eval_performance = float(jnp.mean(episode_metrics[config.env.eval_metric])) - - # Measure absolute metric. - if config.arch.absolute_metric: - start_time = time.time() - - key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params.actor_params, 1), eval_key) - - elapsed_time = time.time() - start_time - steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) - - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(episode_metrics, t, eval_step, LogEvent.ABSOLUTE) - - # Stop the logger. - logger.stop() - - return eval_performance - - - -@hydra.main(config_path="../../../configs", config_name="default_ff_mappo_seb.yaml", version_base="1.2") -def hydra_entry_point(cfg: DictConfig) -> float: - """Experiment entry point.""" - # Allow dynamic attributes. - OmegaConf.set_struct(cfg, False) - - # Run experiment. 
- eval_performance = run_experiment(cfg) - print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") - return eval_performance - - -if __name__ == "__main__": - hydra_entry_point() - -#learner_output.episode_metrics.keys() -#dict_keys(['episode_length', 'episode_return']) \ No newline at end of file diff --git a/mava/systems/sebulba/ppo/orig.py b/mava/systems/sebulba/ppo/orig.py deleted file mode 100644 index dde0add30..000000000 --- a/mava/systems/sebulba/ppo/orig.py +++ /dev/null @@ -1,795 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mava.utils.sebulba_utils import configure_computation_environment - -configure_computation_environment() # noqa: E402 - -import copy -import queue -import threading -import time -from collections import deque -from typing import Any, Dict, List, Tuple - -import chex -import flax -import hydra -import jax -import jax.numpy as jnp -import numpy as np -import optax -from chex import PRNGKey -from colorama import Fore, Style -from flax.core.frozen_dict import FrozenDict -from omegaconf import DictConfig, OmegaConf -from rich.pretty import pprint - -from mava.evaluator import get_sebulba_ff_evaluator as evaluator_setup -from mava.logger import Logger -from mava.networks import get_networks -from mava.types import ( - ActorApply, - CriticApply, - LearnerState, - OptStates, - Params, -) -from mava.types import PPOTransition as Transition -from mava.types import SebulbaLearnerFn as LearnerFn -from mava.types import SingleDeviceFn -from mava.utils.checkpointing import Checkpointer -from mava.utils.jax import merge_leading_dims -from mava.utils.make_env import make - - -def rollout( # noqa: CCR001 - rng: PRNGKey, - config: DictConfig, - rollout_queue: queue.Queue, - params_queue: queue.Queue, - device_thread_id: int, - apply_fns: Tuple, - logger: Logger, - learner_devices: List, -) -> None: - """Executor rollout loop.""" - # Create envs - envs = make(config)(config.arch.num_envs) # type: ignore - - # Setup - len_executor_device_ids = len(config.arch.executor_device_ids) - t_env = 0 - start_time = time.time() - - # Get the apply functions for the actor and critic networks. - vmap_actor_apply, vmap_critic_apply = apply_fns - - # Define the util functions: select action function and prepare data to share it with learner. 
- @jax.jit - def get_action_and_value( - params: FrozenDict, - observation: Observation, - key: PRNGKey, - ) -> Tuple: - """Get action and value.""" - key, subkey = jax.random.split(key) - - policy = vmap_actor_apply(params.actor_params, observation) - action, logprob = policy.sample_and_log_prob(seed=subkey) - - value = vmap_critic_apply(params.critic_params, observation).squeeze() - return action, logprob, value, key - - @jax.jit - def prepare_data(storage: List[Transition]) -> Transition: - """Prepare data to share with learner.""" - return jax.tree_map( # type: ignore - lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage - ) - - # Define the episode info - env_id = np.arange(config.arch.num_envs) - # Accumulated episode returns - episode_returns = np.zeros((config.arch.num_envs,), dtype=np.float32) - # Final episode returns - returned_episode_returns = np.zeros((config.arch.num_envs,), dtype=np.float32) - # Accumulated episode lengths - episode_lengths = np.zeros((config.arch.num_envs,), dtype=np.float32) - # Final episode lengths - returned_episode_lengths = np.zeros((config.arch.num_envs,), dtype=np.float32) - - # Define the data structure - params_queue_get_time: deque = deque(maxlen=10) - rollout_time: deque = deque(maxlen=10) - rollout_queue_put_time: deque = deque(maxlen=10) - - # Reset envs - next_obs, infos = envs.reset() - next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) - - # Loop till the learner has finished training - for update in range(1, config.system.num_updates + 2): - # Setup - env_recv_time: float = 0 - inference_time: float = 0 - storage_time: float = 0 - env_send_time: float = 0 - - # Get the latest parameters from the learner - params_queue_get_time_start = time.time() - if config.arch.concurrency: - if update != 2: - params = params_queue.get() - params.network_params["params"]["Dense_0"]["kernel"].block_until_ready() - else: - params = params_queue.get() - params_queue_get_time.append(time.time() - params_queue_get_time_start) - - # Rollout - rollout_time_start = time.time() - storage: List = [] - # Loop over the rollout length - for _ in range(0, config.system.rollout_length): - # Get previous step info - cached_next_obs = next_obs - cached_next_dones = next_dones - cashed_action_mask = np.stack(infos["actions_mask"]) - - # Increment current timestep - t_env += ( - config.arch.n_threads_per_executor * len_executor_device_ids * config.arch.num_envs - ) - - # Get action and value - inference_time_start = time.time() - ( - action, - logprob, - value, - rng, - ) = get_action_and_value(params, Observation(cached_next_obs, cashed_action_mask), rng) - inference_time += time.time() - inference_time_start - - # Step the environment - env_send_time_start = time.time() - cpu_action = np.array(action) - next_obs, next_reward, terminated, truncated, infos = envs.step(cpu_action) - next_done = terminated + truncated - next_dones = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - (next_done), - ) - - # Append data to storage - env_send_time += time.time() - env_send_time_start - storage_time_start = time.time() - storage.append( - Transition( - done=cached_next_dones, - action=action, - value=value, - reward=next_reward, - log_prob=logprob, - obs=cached_next_obs, - info=np.stack(infos["actions_mask"]), # Add action mask to info - ) - ) - storage_time += time.time() - storage_time_start - - # Update episode info 
---------------------------------------------------------------------------------------------------------- this is kinda cringe? - episode_returns[env_id] += np.mean(next_reward, axis = 1) - returned_episode_returns[env_id] = np.where( - next_done, - episode_returns[env_id], - returned_episode_returns[env_id], - ) - episode_returns[env_id] *= (1 - next_done) * (1 - truncated) - episode_lengths[env_id] += 1 - returned_episode_lengths[env_id] = np.where( - next_done, - episode_lengths[env_id], - returned_episode_lengths[env_id], - ) - episode_lengths[env_id] *= (1 - next_done) * (1 - truncated) - rollout_time.append(time.time() - rollout_time_start) - - # Prepare data to share with learner - partitioned_storage = prepare_data(storage) - sharded_storage = Transition( - *list( # noqa: C417 - map( - lambda x: jax.device_put_sharded(x, devices=learner_devices), # type: ignore - partitioned_storage, - ) - ) - ) - sharded_next_obs = jax.device_put_sharded( - np.split(next_obs, len(learner_devices)), devices=learner_devices - ) - sharded_next_done = jax.device_put_sharded( - np.split(next_dones, len(learner_devices)), devices=learner_devices - ) - sharded_next_action_mask = jax.device_put_sharded( - np.split(np.stack(infos["actions_mask"]), len(learner_devices)), devices=learner_devices - ) - payload = ( - t_env, - sharded_storage, - sharded_next_obs, - sharded_next_done, - sharded_next_action_mask, - np.mean(params_queue_get_time), - ) - - # Put data in the rollout queue to share it with the learner - rollout_queue_put_time_start = time.time() - rollout_queue.put(payload) - rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) - - if (update % config.arch.log_frequency == 0) or (config.system.num_updates + 1 == update): - # Log info - logger.log_executor_metrics( - t_env=t_env, - metrics={ - "episodes_info": { - "episode_return": returned_episode_returns, - "episode_length": returned_episode_lengths, - "steps_per_second": int(t_env / (time.time() - start_time)), - }, - "speed_info": { - "rollout_time": np.mean(rollout_time), - }, - "queue_info": { - "params_queue_get_time": np.mean(params_queue_get_time), - "env_recv_time": env_recv_time, - "inference_time": inference_time, - "storage_time": storage_time, - "env_send_time": env_send_time, - "rollout_queue_put_time": np.mean(rollout_queue_put_time), - }, - }, - device_thread_id=device_thread_id, - ) - - -def get_learner_fn( - apply_fns: Tuple[ActorApply, CriticApply], - update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], - config: DictConfig, -) -> LearnerFn: - """Get the learner function.""" - # Get apply and update functions for actor and critic networks. 
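[Editor's note] The payload construction above splits every array along the environment axis and places one shard per learner device with jax.device_put_sharded, which stacks the shards along a new leading device axis. A small sketch of just that mechanic, with illustrative shapes rather than the patch's Transition pytree:

import jax
import jax.numpy as jnp

devices = jax.devices()                                   # learner devices; 1 on a CPU-only host
num_envs = 4 * len(devices)                               # keep the split exact
batch = jnp.arange(num_envs * 3).reshape(num_envs, 3)     # (num_envs, feature)

shards = jnp.split(batch, len(devices), axis=0)           # one shard per learner device
sharded_batch = jax.device_put_sharded(shards, devices)   # new leading axis = device axis
print(sharded_batch.shape)                                # (len(devices), num_envs // len(devices), 3)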
- actor_apply_fn, critic_apply_fn = apply_fns - actor_update_fn, critic_update_fn = update_fns - - def single_device_update( - agents_state: LearnerState, - traj_batch: Transition, - last_observation: Observation, - rng: PRNGKey, - ) -> Tuple[LearnerState, chex.PRNGKey, Tuple]: - params, opt_states, _, _, _ = agents_state - - def _calculate_gae( - traj_batch: Transition, last_val: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - """Calculate the GAE.""" - - def _get_advantages(gae_and_next_value: Tuple, transition: Transition) -> Tuple: - """Calculate the GAE for a single transition.""" - gae, next_value = gae_and_next_value - done, value, reward = ( - transition.done, - transition.value, - transition.reward, - ) - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae - return (gae, value), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - # Calculate GAE - last_val = critic_apply_fn(params.critic_params, last_observation) - advantages, targets = _calculate_gae(traj_batch, last_val) - - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: - """Update the network for a single epoch.""" - - def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: - """Update the network for a single minibatch.""" - - # UNPACK TRAIN STATE AND BATCH INFO - params, opt_states = train_state - traj_batch, advantages, targets = batch_info - - def _actor_loss_fn( - actor_params: FrozenDict, - actor_opt_state: OptStates, - traj_batch: Transition, - gae: chex.Array, - ) -> Tuple: - """Calculate the actor loss.""" - # RERUN NETWORK - actor_policy = actor_apply_fn(actor_params, traj_batch.obs) - log_prob = actor_policy.log_prob(traj_batch.action) - - # CALCULATE ACTOR LOSS - ratio = jnp.exp(log_prob - traj_batch.log_prob) - gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( - jnp.clip( - ratio, - 1.0 - config.system.clip_eps, - 1.0 + config.system.clip_eps, - ) - * gae - ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() - entropy = actor_policy.entropy().mean() - - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) - - def _critic_loss_fn( - critic_params: FrozenDict, - critic_opt_state: OptStates, - traj_batch: Transition, - targets: chex.Array, - ) -> Tuple: - """Calculate the critic loss.""" - # RERUN NETWORK - value = critic_apply_fn(critic_params, traj_batch.obs) - - # CALCULATE VALUE LOSS - value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( - -config.system.clip_eps, config.system.clip_eps - ) - value_losses = jnp.square(value - targets) - value_losses_clipped = jnp.square(value_pred_clipped - targets) - value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) - - # CALCULATE ACTOR LOSS - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn( - params.actor_params, opt_states.actor_opt_state, traj_batch, advantages - ) - - # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, 
opt_states.critic_opt_state, traj_batch, targets - ) - - # Compute the parallel mean (pmean) over the learner devices. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="local_devices" - ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="local_devices" - ) - - # UPDATE ACTOR PARAMS AND OPTIMISER STATE - actor_updates, actor_new_opt_state = actor_update_fn( - actor_grads, opt_states.actor_opt_state - ) - actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - - # UPDATE CRITIC PARAMS AND OPTIMISER STATE - critic_updates, critic_new_opt_state = critic_update_fn( - critic_grads, opt_states.critic_opt_state - ) - critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - - # PACK NEW PARAMS AND OPTIMISER STATE - new_params = Params(actor_new_params, critic_new_params) - new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] - loss_info = (total_loss, value_loss, actor_loss, entropy) - - return (new_params, new_opt_state), loss_info - - params, opt_states, traj_batch, advantages, targets, rng = update_state - rng, shuffle_rng = jax.random.split(rng) - - # SHUFFLE MINIBATCHES - batch_size = config.system.rollout_length * config.arch.num_envs - permutation = jax.random.permutation(shuffle_rng, batch_size) - batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=0), batch - ) - minibatches = jax.tree_util.tree_map( - lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), - shuffled_batch, - ) - - # UPDATE MINIBATCHES - (params, opt_states), loss_info = jax.lax.scan( - _update_minibatch, (params, opt_states), minibatches - ) - - update_state = (params, opt_states, traj_batch, advantages, targets, rng) - return update_state, loss_info - - update_state = (params, opt_states, traj_batch, advantages, targets, rng) - - # UPDATE EPOCHS - update_state, loss_info = jax.lax.scan( - _update_epoch, update_state, None, config.system.ppo_epochs - ) - - params, opt_states, traj_batch, advantages, targets, rng = update_state - learner_state = agents_state._replace(params=params, opt_states=opt_states) - return learner_state, rng, loss_info - - def learner_fn( - agents_state: LearnerState, - sharded_storages: List, - sharded_next_obs: List, - sharded_next_done: List, - sharded_next_action_mask: List, - key: chex.PRNGKey, - ) -> Tuple: - """Single device update.""" - # Horizontal stack all the data from different devices - traj_batch = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages) - traj_batch = traj_batch._replace(obs=Observation(traj_batch.obs, traj_batch.info)) - - # Get last observation - last_obs = jnp.concatenate(sharded_next_obs) - last_action_mask = jnp.concatenate(sharded_next_action_mask) - last_observation = Observation(last_obs, last_action_mask) - - # Update learner - agents_state, key, (total_loss, value_loss, actor_loss, entropy) = single_device_update( - agents_state, traj_batch, last_observation, key - ) - - # Pack loss info - loss_info = { - "total_loss": total_loss, - "loss_actor": actor_loss, - "value_loss": value_loss, - "entropy": entropy, - } - return agents_state, key, loss_info - - 
return learner_fn - - -def learner_setup( - rngs: chex.Array, config: DictConfig, learner_devices: List -) -> Tuple[SingleDeviceFn, LearnerState, Tuple[ActorApply, ActorApply]]: - """Initialise learner_fn, network, optimiser, environment and states.""" - # Get number of actions and agents. - dummy_envs = make(config)( # type: ignore - config.arch.num_envs # Create dummy_envs to get observation and action spaces - ) - config.system.num_agents = dummy_envs.single_observation_space.shape[0] - config.system.num_actions = int(dummy_envs.single_action_space.nvec[0]) - - # PRNG keys. - actor_net_key, critic_net_key = rngs - - # Define network and optimiser. - actor_network, critic_network = get_networks( - config=config, network="feedforward", centralised_critic=False - ) - actor_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(config.system.actor_lr, eps=1e-5), - ) - critic_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(config.system.critic_lr, eps=1e-5), - ) - - # Initialise observation: Select only obs for a single agent. - init_obs = np.array([dummy_envs.single_observation_space.sample()[0]]) - init_action_mask = np.ones((1, config.system.num_actions)) - init_x = Observation(init_obs, init_action_mask) - - # Initialise actor params and optimiser state. - actor_params = actor_network.init(actor_net_key, init_x) - actor_opt_state = actor_optim.init(actor_params) - - # Initialise critic params and optimiser state. - critic_params = critic_network.init(critic_net_key, init_x) - critic_opt_state = critic_optim.init(critic_params) - - # Vmap network apply function over number of agents. - vmapped_actor_network_apply_fn = jax.vmap( - actor_network.apply, - in_axes=(None, Observation(1, 1, None)), - out_axes=(1), - ) - vmapped_critic_network_apply_fn = jax.vmap( - critic_network.apply, - in_axes=(None, Observation(1, 1, None)), - out_axes=(1), - ) - - # Pack apply and update functions. - apply_fns = (vmapped_actor_network_apply_fn, vmapped_critic_network_apply_fn) - update_fns = (actor_optim.update, critic_optim.update) - - # Define agents state - agents_state = LearnerState( - params=Params( - actor_params=actor_params, - critic_params=critic_params, - ), - opt_states=OptStates( - actor_opt_state=actor_opt_state, - critic_opt_state=critic_opt_state, - ), - ) - # Replicate agents state per learner device - agents_state = flax.jax_utils.replicate(agents_state, devices=learner_devices) - - # Get Learner function: pmap over learner devices. - single_device_update = get_learner_fn(apply_fns, update_fns, config) - multi_device_update = jax.pmap( - single_device_update, - axis_name="local_devices", - devices=learner_devices, - ) - - # Close dummy envs. - dummy_envs.close() - - return multi_device_update, agents_state, apply_fns - - -def run_experiment(_config: DictConfig) -> None: # noqa: CCR001 - """Runs experiment.""" - config = copy.deepcopy(_config) - - # Setup device distribution. - local_devices = jax.local_devices() #why are we using local devices insted of devices? ------------------------------------------------------------------------------------------------------------------------------------ define a ratio insted of the devices to use? - learner_devices = [local_devices[d_id] for d_id in config.arch.learner_device_ids] - - # PRNG keys. 
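[Editor's note] learner_setup above vmaps the per-agent actor and critic apply functions over the agent axis of the observation, passing None for arguments or pytree leaves that should be broadcast rather than mapped. A toy sketch of that vmap pattern, using a plain array and made-up shapes instead of the patch's Observation pytree:

import jax
import jax.numpy as jnp

def apply_fn(params, obs):                 # params shared across agents, obs is (batch, features)
    return obs @ params                    # -> (batch,)

params = jnp.ones((4,))                    # broadcast to every agent (in_axes=None)
obs = jnp.ones((8, 3, 4))                  # (batch, num_agents, features)

per_agent_apply = jax.vmap(apply_fn, in_axes=(None, 1), out_axes=1)
values = per_agent_apply(params, obs)      # apply_fn sees one (batch, features) slice per agent
print(values.shape)                        # (8, 3)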
- rng, rng_e, actor_net_key, critic_net_key = jax.random.split( - jax.random.PRNGKey(config.system.seed), num=4 - ) - learner_keys = jax.device_put_replicated(rng, learner_devices) - - # Sanity check of config - assert ( - config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "local_num_envs must be divisible by len(learner_device_ids)" - #each thread is going to devide needs to give an equal number of traj to each learning device? shound't each actor Thread have a designated N learneres? If we have less actor T than learners then ech actor will devide based on the num_env and gives to N actors, ig to lessen the managment each actor gives to all of the learners? - #this deviates from the paper? - assert ( - int(config.arch.num_envs / len(config.arch.learner_device_ids)) - * config.arch.n_threads_per_executor - % config.system.num_minibatches - == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" #this one makes sense but the assertion is a bit off? - - # Setup learner. - ( - multi_device_update, - agents_state, - apply_fns, - ) = learner_setup((actor_net_key, critic_net_key), config, learner_devices) - - # Setup evaluator. - eval_envs = make(config)(config.arch.num_eval_episodes) # type: ignore - evaluator = evaluator_setup(eval_envs=eval_envs, apply_fn=apply_fns[0], config=config) - - # Calculate total timesteps. - batch_size = int( - config.arch.num_envs - * config.system.rollout_length - * config.arch.n_threads_per_executor - * len(config.arch.executor_device_ids) - ) - config.system.total_timesteps = config.system.num_updates * batch_size - - # Setup logger. - config.arch.log_frequency = config.system.num_updates // config.arch.num_evaluation - logger = Logger(config) - cfg_dict: Dict = OmegaConf.to_container(config, resolve=True) - pprint(cfg_dict) - - # Set up checkpointer - save_checkpoint = config.logger.checkpointing.save_model - if save_checkpoint: - checkpointer = Checkpointer( - metadata=cfg_dict, # Save all config as metadata in the checkpoint - model_name=config.logger.system_name, - **config.logger.checkpointing.save_args, # Checkpoint args - ) - - if config.logger.checkpointing.load_model: - print( - f"{Fore.RED}{Style.BRIGHT}Loading checkpoint is not supported\ - for sebulba architecture yet{Style.RESET_ALL}" - ) - - # Executor setup and launch. - unreplicated_params = flax.jax_utils.unreplicate(agents_state.params) - params_queues: List = [] - rollout_queues: List = [] - for d_idx, d_id in enumerate( # Loop through each executor device - config.arch.executor_device_ids - ): - # Replicate params per executor device - device_params = jax.device_put(unreplicated_params, local_devices[d_id]) - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - params_queues.append(queue.Queue(maxsize=1)) - rollout_queues.append(queue.Queue(maxsize=1)) - params_queues[-1].put(device_params) - threading.Thread( - target=rollout, - args=( - jax.device_put(rng, local_devices[d_id]), - config, - rollout_queues[-1], - params_queues[-1], - d_idx * config.arch.n_threads_per_executor + thread_id, - apply_fns, - logger, - learner_devices, - ), - ).start() - - # Run experiment for the total number of updates. 
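[Editor's note] The two assertions above encode a simple budget, which the surrounding comments are questioning: each executor thread splits its environments evenly across the learner devices, and the share a learner device receives (across all executor threads) must divide into whole minibatches. A worked sketch with illustrative numbers:

num_envs = 16                  # envs per executor thread
num_learner_devices = 2
n_threads_per_executor = 2
num_minibatches = 4

envs_per_learner = num_envs // num_learner_devices                    # 8 envs per learner, per thread
assert num_envs % num_learner_devices == 0
assert (envs_per_learner * n_threads_per_executor) % num_minibatches == 0
print(envs_per_learner * n_threads_per_executor // num_minibatches)   # 4 envs per minibatch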
- rollout_queue_get_time: deque = deque(maxlen=10) - data_transfer_time: deque = deque(maxlen=10) - trainer_update_number = 0 - max_episode_return = jnp.float32(0.0) - best_params = None - while True: - trainer_update_number += 1 - rollout_queue_get_time_start = time.time() - sharded_storages = [] - sharded_next_obss = [] - sharded_next_dones = [] - sharded_next_action_masks = [] - - # Loop through each executor device - for d_idx, _ in enumerate(config.arch.executor_device_ids): - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - # Get data from rollout queue - ( - t_env, - sharded_storage, - sharded_next_obs, - sharded_next_done, - sharded_next_action_mask, - avg_params_queue_get_time, - ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() - sharded_storages.append(sharded_storage) - sharded_next_obss.append(sharded_next_obs) - sharded_next_dones.append(sharded_next_done) - sharded_next_action_masks.append(sharded_next_action_mask) - - rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) - training_time_start = time.time() - - # Update learner - (agents_state, learner_keys, loss_info) = multi_device_update( # type: ignore - agents_state, - sharded_storages, - sharded_next_obss, - sharded_next_dones, - sharded_next_action_masks, - learner_keys, - ) - - # Send updated params to executors - unreplicated_params = flax.jax_utils.unreplicate(agents_state.params) - for d_idx, d_id in enumerate(config.arch.executor_device_ids): - device_params = jax.device_put(unreplicated_params, local_devices[d_id]) - for thread_id in range(config.arch.n_threads_per_executor): - params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( - device_params - ) - - if trainer_update_number % config.arch.log_frequency == 0: - # Logging training info - logger.log_trainer_metrics( - experiment_output={ - "loss_info": loss_info, - "queue_info": { - "rollout_queue_get_time": np.mean(rollout_queue_get_time), - "data_transfer_time": np.mean(data_transfer_time), - "rollout_params_queue_get_time_diff": np.mean(rollout_queue_get_time) - - avg_params_queue_get_time, - "rollout_queue_size": rollout_queues[0].qsize(), - "params_queue_size": params_queues[0].qsize(), - }, - "speed_info": { - "training_time": time.time() - training_time_start, - "trainer_update_number": trainer_update_number, - }, - }, - t_env=t_env, - ) - - # Evaluation - rng_e, _ = jax.random.split(rng_e) - evaluator_output = evaluator(params=unreplicated_params, rng=rng_e) - # Log the results of the evaluation. - episode_return = logger.log_evaluator_metrics( - t_env=t_env, - metrics=evaluator_output, - eval_step=trainer_update_number, - ) - - if save_checkpoint: - # Save checkpoint of learner state - checkpointer.save( - timestep=t_env, - unreplicated_learner_state=flax.jax_utils.unreplicate(agents_state), - episode_return=episode_return, - ) - - if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(unreplicated_params) - max_episode_return = episode_return - - # Check if training is finished - if trainer_update_number >= config.system.num_updates: - rng_e, _ = jax.random.split(rng_e) - # Measure absolute metric - evaluator_output = evaluator(params=best_params, rng=rng_e, eval_multiplier=10) - # Log the results of the evaluation. 
- logger.log_evaluator_metrics( - t_env=t_env, - metrics=evaluator_output, - eval_step=trainer_update_number + 1, - absolute_metric=True, - ) - break - - -@hydra.main(config_path="../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") -def hydra_entry_point(cfg: DictConfig) -> None: - """Experiment entry point.""" - - # Run experiment. - run_experiment(cfg) - - print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") - - -if __name__ == "__main__": - hydra_entry_point() \ No newline at end of file diff --git a/mava/systems/sebulba/ppo/rec_ippo.py b/mava/systems/sebulba/ppo/rec_ippo.py deleted file mode 100644 index 6e204fb21..000000000 --- a/mava/systems/sebulba/ppo/rec_ippo.py +++ /dev/null @@ -1,850 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import time -from typing import Any, Dict, Tuple, List -import threading -import chex -import flax -import hydra -import jax -import jax.debug -import jax.numpy as jnp -import numpy as np -import optax -import queue -from collections import deque -from colorama import Fore, Style -from flax.core.frozen_dict import FrozenDict -from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState -from rich.pretty import pprint - -from mava.evaluator import make_sebulba_eval_fns as make_eval_fns -from mava.networks import RecurrentActor as Actor -from mava.networks import RecurrentValueNet as Critic -from mava.networks import ScannedRNN -from mava.systems.anakin.ppo.types import ( - HiddenStates, - OptStates, - Params, - RNNLearnerState, - RNNPPOTransition, -) -from mava.types import ExperimentOutput, LearnerFn, RecActorApply, RecCriticApply, RNNObservation, Observation -from mava.utils import make_env as environments -from mava.utils.checkpointing import Checkpointer -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) -from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import sebulba_check_total_timesteps -from mava.utils.training import make_learning_rate -from mava.wrappers.episode_metrics import get_final_step_metrics - - -def rollout( - key: chex.PRNGKey, - config: DictConfig, - rollout_queue: queue.Queue, - params_queue: queue.Queue, - apply_fns: Tuple, - learner_devices: List, - actor_device_id : int, - init_hstates : HiddenStates): - - #setup - - env = environments.make_gym_env(config, config.arch.num_envs) - current_actor_device = jax.devices()[actor_device_id] - actor_apply_fn, critic_apply_fn = apply_fns - - # Define the util functions: select action function and prepare data to share it with learner. 
- @jax.jit - def get_action_and_value( - params: FrozenDict, - observation: RNNObservation, - last_hstates : HiddenStates, - key: chex.PRNGKey, - ) -> Tuple: - """Get action and value.""" - key, subkey = jax.random.split(key) - - policy_hidden_state, actor_policy = actor_apply_fn(params.actor_params, last_hstates.policy_hidden_state, observation) - action = actor_policy.sample(seed=subkey) - log_prob = actor_policy.log_prob(action) - - critic_hidden_state, value = critic_apply_fn(params.critic_params, last_hstates.critic_hidden_state, observation) - hastates = HiddenStates(policy_hidden_state, critic_hidden_state) - return action, log_prob, value, key, hastates - - # Define queues to track time - params_queue_get_time: deque = deque(maxlen=1) - rollout_time: deque = deque(maxlen=1) - rollout_queue_put_time: deque = deque(maxlen=1) - - next_obs , info = env.reset() - next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) - next_hstates = init_hstates - move_to_device = lambda x : jax.device_put(x, device = current_actor_device) - - # Loop till the learner has finished training - for update in range(config.system.num_updates): - inference_time: float = 0 - storage_time: float = 0 - env_send_time: float = 0 - - # Get the latest parameters from the learner - params_queue_get_time_start = time.time() - params = params_queue.get() - params_queue_get_time.append(time.time() - params_queue_get_time_start) - - # Rollout - rollout_time_start = time.time() - storage: List = [] - - # Loop over the rollout length - for _ in range(0, config.system.rollout_length): - - # Cached for transition - cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) # (num_envs, num_agents, ...) - cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) - cashed_action_mask = move_to_device(np.stack(info["actions_mask"])) # (num_envs, num_agents, num_actions) - - # Add the sequence_len dim - cached_next_obs, cached_next_dones, cashed_action_mask = jax.tree_map(lambda x: x[jnp.newaxis, : ], (cached_next_obs, cached_next_dones, cashed_action_mask)) - - full_observation = Observation(cached_next_obs, cashed_action_mask) - full_observation_dones = (full_observation, cached_next_dones) - cashed_next_hstate = move_to_device(next_hstates) - # Get action and value - inference_time_start = time.time() - ( - action, - log_prob, - value, - key, - next_hstates - ) = get_action_and_value(params, full_observation_dones, cashed_next_hstate, key) - - - # Step the environment - inference_time += time.time() - inference_time_start - env_send_time_start = time.time() - cpu_action = jax.device_get(action) - next_obs, next_reward, terminated, truncated, info = env.step(cpu_action[0].swapaxes(0,1)) # (num_env, num_agents) --> (num_agents, num_env) - env_send_time += time.time() - env_send_time_start - - # Prepare the data - storage_time_start = time.time() - next_dones = np.logical_or(terminated, truncated) - metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics - - # Append data to storage - storage.append( - RNNPPOTransition( - done=cached_next_dones[0], - action=action[0], - value=value[0], - reward=next_reward, - log_prob=log_prob[0], - obs=Observation(cached_next_obs[0], cashed_action_mask[0]), - hstates=cashed_next_hstate, - info=metrics, - ) - ) - storage_time += time.time() - storage_time_start - rollout_time.append(time.time() - rollout_time_start) - - parse_timer = time.time() - - # Prepare data to share with learner - 
#[PPOTransition() * rollout_len] --> PPOTransition[done = (rollout_len, num_envs, num_agents), action = (rollout_len, num_envs, num_agents, num_actions), ...] - stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) - - # Split the arrays over the different learner_devices on the num_envs axis - shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) - - sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) - - # (num_learner_devices, num_envs, num_agents, ...) - sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) - sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) - sharded_next_done = shard_split_payload(next_dones, 0) - sharded_next_hstate = jax.tree_map( lambda x: shard_split_payload(x,0), next_hstates) - - # Pack the obs and action mask - payload_obs_dones = (Observation(sharded_next_obs, sharded_next_action_mask), cached_next_dones) - - # For debugging - speed_info = { - "rollout_time": np.mean(rollout_time), - "params_queue_get_time": np.mean(params_queue_get_time), - "action_inference": inference_time, - "storage_time": storage_time, - "env_step_time": env_send_time, - "rollout_queue_put_time": np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0, - "parse_time" : time.time() - parse_timer, - } - #print(speed_info) - - payload = ( - sharded_storage, - payload_obs_dones, - sharded_next_done, - sharded_next_hstate - ) - - # Put data in the rollout queue to share it with the learner - rollout_queue_put_time_start = time.time() - rollout_queue.put(payload) - rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) - - -def get_learner_fn( - apply_fns: Tuple[ RecActorApply, RecCriticApply], - update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], - config: DictConfig, -) -> LearnerFn[RNNLearnerState]: - """Get the learner function.""" - - # Get apply and update functions for actor and critic networks. - actor_apply_fn, critic_apply_fn = apply_fns - actor_update_fn, critic_update_fn = update_fns - - def _update_step(learner_state: RNNLearnerState, traj_batch : RNNPPOTransition, last_obs: RNNObservation, last_dones : chex.Array, last_hstate : HiddenStates) -> Tuple[RNNLearnerState, Tuple]: - """A single update of the network. - - This function steps the environment and records the trajectory batch for - training. It then calculates advantages and targets based on the recorded - trajectory and updates the actor and critic networks based on the calculated - losses. - - Args: - learner_state (NamedTuple): - - params (Params): The current model parameters. - - opt_states (OptStates): The current optimizer states. - - key (PRNGKey): The random number generator state. - - env_state (State): The environment state. - - last_timestep (TimeStep): The last timestep in the current trajectory. - _ (Any): The current metrics info. 
- """ - - def _calculate_gae( #todo: lake sure this is appropriate - traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry - done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - # CALCULATE ADVANTAGE - params, opt_states, key, _, _, _, _ = learner_state - last_obs = jax.tree_map(lambda x: x[jnp.newaxis, : ], last_obs) - last_dones = last_dones[jnp.newaxis, :] - - - _, last_val = critic_apply_fn(params.critic_params, last_hstate.critic_hidden_state, last_obs) - - advantages, targets = _calculate_gae(traj_batch, last_val[0], last_dones[0]) - - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: - """Update the network for a single epoch.""" - - def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: - """Update the network for a single minibatch.""" - - # UNPACK TRAIN STATE AND BATCH INFO - params, opt_states, key = train_state - traj_batch, advantages, targets = batch_info - - def _actor_loss_fn( - actor_params: FrozenDict, - actor_opt_state: OptState, - traj_batch: RNNPPOTransition, - gae: chex.Array, - key: chex.PRNGKey, - ) -> Tuple: - """Calculate the actor loss.""" - # RERUN NETWORK - - obs_and_done = (traj_batch.obs, traj_batch.done) - _, actor_policy = actor_apply_fn( - actor_params, traj_batch.hstates.policy_hidden_state[0], obs_and_done - ) - log_prob = actor_policy.log_prob(traj_batch.action) - - ratio = jnp.exp(log_prob - traj_batch.log_prob) - gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( - jnp.clip( - ratio, - 1.0 - config.system.clip_eps, - 1.0 + config.system.clip_eps, - ) - * gae - ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() - # The seed will be used in the TanhTransformedDistribution: - entropy = actor_policy.entropy(seed=key).mean() - - total_loss = loss_actor - config.system.ent_coef * entropy - return total_loss, (loss_actor, entropy) - - def _critic_loss_fn( - critic_params: FrozenDict, - critic_opt_state: OptState, - traj_batch: RNNPPOTransition, - targets: chex.Array, - ) -> Tuple: - """Calculate the critic loss.""" - # RERUN NETWORK - obs_and_done = (traj_batch.obs, traj_batch.done) - _, value = critic_apply_fn( - critic_params, traj_batch.hstates.critic_hidden_state[0], obs_and_done - ) - - # CALCULATE VALUE LOSS - value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( - -config.system.clip_eps, config.system.clip_eps - ) - value_losses = jnp.square(value - targets) - value_losses_clipped = jnp.square(value_pred_clipped - targets) - value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - - total_loss = config.system.vf_coef * value_loss - return total_loss, (value_loss) - - # CALCULATE ACTOR LOSS - key, entropy_key = jax.random.split(key) - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = 
actor_grad_fn( - params.actor_params, - opt_states.actor_opt_state, - traj_batch, - advantages, - entropy_key, - ) - - # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets - ) - - # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x - # pmean over devices. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="device" - ) - # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" - ) - - # UPDATE ACTOR PARAMS AND OPTIMISER STATE - actor_updates, actor_new_opt_state = actor_update_fn( - actor_grads, opt_states.actor_opt_state - ) - actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - - # UPDATE CRITIC PARAMS AND OPTIMISER STATE - critic_updates, critic_new_opt_state = critic_update_fn( - critic_grads, opt_states.critic_opt_state - ) - critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - - new_params = Params(actor_new_params, critic_new_params) - new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] - loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, - } - - return (new_params, new_opt_state, entropy_key), loss_info - - params, opt_states, traj_batch, advantages, targets, key = update_state - key, shuffle_key, entropy_key = jax.random.split(key, 3) - - # SHUFFLE MINIBATCHES - batch = (traj_batch, advantages, targets) - num_recurrent_chunks = ( - config.system.rollout_length // config.system.recurrent_chunk_size - ) - batch = jax.tree_util.tree_map( - lambda x: x.reshape( - config.system.recurrent_chunk_size, - config.arch.num_envs * num_recurrent_chunks, - *x.shape[2:], - ), - batch, - ) - permutation = jax.random.permutation( - shuffle_key, config.arch.num_envs * num_recurrent_chunks - ) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=1), batch - ) - reshaped_batch = jax.tree_util.tree_map( - lambda x: jnp.reshape( - x, (x.shape[0], config.system.num_minibatches, -1, *x.shape[2:]) - ), - shuffled_batch, - ) - minibatches = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) - - # UPDATE MINIBATCHES - (params, opt_states, entropy_key), loss_info = jax.lax.scan( - _update_minibatch, (params, opt_states, entropy_key), minibatches - ) - - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - ) - return update_state, loss_info - - update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS - update_state, loss_info = jax.lax.scan( - _update_epoch, update_state, None, config.system.ppo_epochs - ) - - params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = RNNLearnerState(params, opt_states, key, None, None, None, None) - metric = traj_batch.info - return learner_state, (metric, loss_info) - - def learner_fn(learner_state: RNNLearnerState, traj_batch : RNNPPOTransition, last_obs: chex.Array, last_dones : chex.Array, last_hstate : chex.Array) -> 
ExperimentOutput[RNNLearnerState]: - """Learner function. - - This function represents the learner, it updates the network parameters - by iteratively applying the `_update_step` function for a fixed number of - updates. The `_update_step` function is vectorized over a batch of inputs. - - Args: - learner_state (NamedTuple): - - params (Params): The initial model parameters. - - opt_states (OptStates): The initial optimizer state. - - key (chex.PRNGKey): The random number generator state. - - env_state (LogEnvState): The environment state. - - timesteps (TimeStep): The initial timestep in the initial trajectory. - """ - - - learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_dones, last_hstate) - - return ExperimentOutput( - learner_state=learner_state, - episode_metrics=episode_info, - train_metrics=loss_info, - ) - - return learner_fn - - -def learner_setup( - keys: chex.Array, config: DictConfig, learner_devices: List -) -> Tuple[LearnerFn[RNNLearnerState], Actor, RNNLearnerState]: - """Initialise learner_fn, network, optimiser, environment and states.""" - - #create temporory envoirnments. - env = environments.make_gym_env(config, 1) - # Get number of agents and actions. - action_space = env.single_action_space - config.system.num_agents = len(action_space) - config.system.num_actions = action_space[0].n - - # PRNG keys. - key, actor_net_key, critic_net_key = keys - - # Define network and optimisers. - actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - actor_action_head = hydra.utils.instantiate( - config.network.action_head, action_dim=config.system.num_actions - ) - critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) - critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) - - actor_network = Actor( - pre_torso=actor_pre_torso, - post_torso=actor_post_torso, - action_head=actor_action_head, - hidden_state_dim=config.network.hidden_state_dim, - ) - critic_network = Critic( - pre_torso=critic_pre_torso, - post_torso=critic_post_torso, - hidden_state_dim=config.network.hidden_state_dim, - ) - - actor_lr = make_learning_rate(config.system.actor_lr, config) - critic_lr = make_learning_rate(config.system.critic_lr, config) - - actor_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(actor_lr, eps=1e-5), - ) - critic_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(critic_lr, eps=1e-5), - ) - - # Initialise observation: Select only obs for a single agent. - init_obs = jnp.array([[env.single_observation_space.sample()]]) - init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) - init_dones = jnp.zeros((1, 1, config.system.num_agents), dtype=jax.numpy.bool_) - init_x = (Observation(init_obs, init_action_mask), init_dones) - - # Initialise hidden states. - init_policy_hstate = ScannedRNN.initialize_carry( - (config.arch.num_envs, config.system.num_agents), config.network.hidden_state_dim - ) - init_critic_hstate = ScannedRNN.initialize_carry( - (config.arch.num_envs, config.system.num_agents), config.network.hidden_state_dim - ) - - # initialise params and optimiser state. 
- actor_params = actor_network.init(actor_net_key, init_policy_hstate, init_x) - actor_opt_state = actor_optim.init(actor_params) - critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x) - critic_opt_state = critic_optim.init(critic_params) - - # Get network apply functions and optimiser updates. - apply_fns = (actor_network.apply, critic_network.apply) - update_fns = (actor_optim.update, critic_optim.update) - - # Get batched iterated update and replicate it to pmap it over learner cores. - learn = get_learner_fn(apply_fns, update_fns, config) - learn = jax.pmap(learn, axis_name="device", devices = learner_devices) - - # Pack params and initial states. - params = Params(actor_params, critic_params) - hstates = HiddenStates(init_policy_hstate, init_critic_hstate) - - # Load model from checkpoint if specified. - if config.logger.checkpointing.load_model: - loaded_checkpoint = Checkpointer( - model_name=config.logger.system_name, - **config.logger.checkpointing.load_args, # Other checkpoint args - ) - # Restore the learner state from the checkpoint - restored_params, restored_hstates = loaded_checkpoint.restore_params( - input_params=params, restore_hstates=True, THiddenState=HiddenStates - ) - # Update the params and hstates - params = restored_params - hstates = restored_hstates if restored_hstates else hstates - - # Define params to be replicated across devices and batches. - key, step_keys = jax.random.split(key) - opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, hstates, step_keys) - - # Duplicate learner across Learner devices. - replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) - - # Initialise learner state. - params, opt_states, hstates, step_keys = replicate_learner - init_learner_state = RNNLearnerState(params, opt_states, step_keys, None, None, init_dones, hstates) - env.close() - - return learn, apply_fns, init_learner_state - - -def run_experiment(_config: DictConfig) -> float: - """Runs experiment.""" - config = copy.deepcopy(_config) - - devices = jax.devices() - learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] - - # PRNG keys. - key, key_e, actor_net_key, critic_net_key = jax.random.split( - jax.random.PRNGKey(config.system.seed), num=4 - ) - - # Sanity check of config - if config.system.recurrent_chunk_size is None: - config.system.recurrent_chunk_size = config.system.rollout_length - else: - assert ( - config.system.rollout_length % config.system.recurrent_chunk_size == 0 - ), "Rollout length must be divisible by recurrent chunk size." - assert ( - config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "The number of environments must to be divisible by the number of learners " - - assert ( - int(config.arch.num_envs / len(config.arch.learner_device_ids)) - * config.arch.n_threads_per_executor - % config.system.num_minibatches - == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" - - - # Setup learner. - learn, apply_fns , learner_state = learner_setup( - (key ,actor_net_key, critic_net_key), config, learner_devices - ) - - # Setup evaluator. - # One key per device for evaluation. - evaluator, absolute_metric_evaluator = make_eval_fns(environments.make_gym_env, apply_fns[0], config,use_recurrent_net = True, scanned_rnn = ScannedRNN) #todo: make this more generic - - # Calculate total timesteps. 
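[Editor's note] Throughout this file the learner state is replicated across the learner devices before pmap and unreplicated again whenever a single copy of the parameters has to be shipped to the executor threads. A short sketch of that round-trip with a toy parameter tree (illustrative only):

import jax
import jax.numpy as jnp
from flax import jax_utils

devices = jax.devices()
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

replicated = jax_utils.replicate(params, devices=devices)   # adds a leading device axis
print(replicated["w"].shape)                                # (len(devices), 3)

single_copy = jax_utils.unreplicate(replicated)             # pull one device's copy back out
print(single_copy["w"].shape)                               # (3,)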
- config = sebulba_check_total_timesteps(config) - assert ( - config.system.num_updates > config.arch.num_evaluation - ), "Number of updates per evaluation must be less than total number of updates." - - # Calculate number of updates per evaluation. - config.system.num_updates_per_eval, remaining_updates = divmod(config.system.num_updates , config.arch.num_evaluation) - config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation step if the num_updates is not a multiple of num_evaluation - steps_per_rollout = ( - len(config.arch.executor_device_ids) - * config.arch.n_threads_per_executor - * config.system.rollout_length - * config.arch.num_envs - * config.system.num_updates_per_eval - ) - - # Logger setup - logger = MavaLogger(config) - cfg: Dict = OmegaConf.to_container(config, resolve=True) - cfg["arch"]["devices"] = jax.devices() - pprint(cfg) - - # Set up checkpointer - save_checkpoint = config.logger.checkpointing.save_model - if save_checkpoint: - checkpointer = Checkpointer( - metadata=config, # Save all config as metadata in the checkpoint - model_name=config.logger.system_name, - **config.logger.checkpointing.save_args, # Checkpoint args - ) - - # Executor setup and launch. - unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) - unreplicated_hstates = flax.jax_utils.unreplicate(learner_state.hstates) - params_queues: List = [] - rollout_queues: List = [] - for d_idx, d_id in enumerate( # Loop through each executor device - config.arch.executor_device_ids - ): - # Replicate params per executor device - device_params = jax.device_put(unreplicated_params, devices[d_id]) - device_hstates = jax.device_put(unreplicated_hstates, devices[d_id]) - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - params_queues.append(queue.Queue(maxsize=1)) - rollout_queues.append(queue.Queue(maxsize=1)) - params_queues[-1].put(device_params) - threading.Thread( - target=rollout, - args=( - jax.device_put(key, devices[d_id]), - config, - rollout_queues[-1], - params_queues[-1], - apply_fns, - learner_devices, - d_id, - device_hstates, - ), - ).start() #todo : Use a process instead of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) - - # Run experiment for the total number of updates. 
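[Editor's note] The evaluation scheduling above divides the total number of updates into evenly sized windows and, when the division is not exact, appends one shorter window that absorbs remaining_updates. A worked sketch of that arithmetic with made-up numbers:

num_updates, num_evaluation = 205, 10
num_updates_per_eval, remaining_updates = divmod(num_updates, num_evaluation)
num_evaluation += (remaining_updates != 0)          # one extra, shorter evaluation window

windows = [num_updates_per_eval] * (num_evaluation - 1)
windows += [remaining_updates if remaining_updates else num_updates_per_eval]
assert sum(windows) == num_updates
print(num_updates_per_eval, remaining_updates, len(windows))   # 20 5 11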
- max_episode_return = jnp.float32(0.0) - best_params = None - for eval_step in range(config.arch.num_evaluation): - training_start_time = time.time() - learner_speeds = [] - rollout_times = [] - - episode_metrics = [] - train_metrics = [] - - # Make sure that the - num_updates_in_eval = config.system.num_updates_per_eval if eval_step != config.arch.num_evaluation - 1 else remaining_updates - for update in range(num_updates_in_eval): - sharded_storages = [] - sharded_next_obss = [] - sharded_next_dones = [] - sharded_next_hstates = [] - - rollout_start_time = time.time() - # Loop through each executor device - for d_idx, _ in enumerate(config.arch.executor_device_ids): - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - # Get data from rollout queue - ( - sharded_storage, - sharded_next_obs, - sharded_next_done, - sharded_next_hstate, - ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() - sharded_storages.append(sharded_storage) - sharded_next_obss.append(sharded_next_obs) - sharded_next_dones.append(sharded_next_done) - sharded_next_hstates.append(sharded_next_hstate) - - rollout_times.append(time.time() - rollout_start_time) - - - # Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) - sharded_next_obss = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_obss) - sharded_next_hstates = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_hstates) - - sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) - - learner_start_time = time.time() - learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones, sharded_next_hstates) - learner_speeds.append(time.time() - learner_start_time) - - # Stack the metrics - episode_metrics.append(learner_output.episode_metrics) - train_metrics.append(learner_output.train_metrics) - - # Send updated params to executors - unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) - for d_idx, d_id in enumerate(config.arch.executor_device_ids): - device_params = jax.device_put(unreplicated_params, devices[d_id]) - for thread_id in range(config.arch.n_threads_per_executor): - params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( - device_params - ) - - - - # Log the results of the training. - elapsed_time = time.time() - training_start_time - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics = jax.tree_map(lambda *x : np.asarray(x), *episode_metrics) - episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - - # Separately log timesteps, actoring metrics and training metrics. - speed_info = {"total_time" : elapsed_time, "rollout_time" : np.sum(rollout_times), "learner_time" : np.sum(learner_speeds), "timestep" : t} - logger.log(speed_info , t, eval_step, LogEvent.MISC) - if ep_completed: # only log episode metrics if an episode was completed in the rollout. 
- logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - train_metrics = jax.tree_map(lambda *x : np.asarray(x), *train_metrics) - logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) - - # Evaluation on the learner - evaluation_start_timer = time.time() - key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = evaluator(unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1 ), eval_key) - - # Log the results of the evaluation. - elapsed_time = time.time() - evaluation_start_timer - episode_return = jnp.mean(episode_metrics["episode_return"]) - - steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) - episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) - - if save_checkpoint: - # Save checkpoint of learner state - checkpointer.save( - timestep=steps_per_rollout * (eval_step + 1), - unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state, 1), - episode_return=episode_return, - ) - - if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(learner_output.learner_state.params) - max_episode_return = episode_return - - # Update runner state to continue training. - learner_state = learner_output.learner_state - - # Record the performance for the final evaluation run. - eval_performance = float(jnp.mean(episode_metrics[config.env.eval_metric])) - - # Measure absolute metric. - if config.arch.absolute_metric: - start_time = time.time() - - key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params.actor_params, 1), eval_key) - - elapsed_time = time.time() - start_time - steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) - - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(episode_metrics, t, eval_step, LogEvent.ABSOLUTE) - - # Stop the logger. - logger.stop() - - return eval_performance - - - -@hydra.main(config_path="../../../configs", config_name="default_rec_ippo_seb.yaml", version_base="1.2") -def hydra_entry_point(cfg: DictConfig) -> float: - """Experiment entry point.""" - # Allow dynamic attributes. - OmegaConf.set_struct(cfg, False) - - # Run experiment. 
- eval_performance = run_experiment(cfg) - print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") - return eval_performance - - -if __name__ == "__main__": - hydra_entry_point() - -#learner_output.episode_metrics.keys() -#dict_keys(['episode_length', 'episode_return']) \ No newline at end of file diff --git a/mava/systems/sebulba/ppo/test.py b/mava/systems/sebulba/ppo/test.py deleted file mode 100644 index d1f34fccf..000000000 --- a/mava/systems/sebulba/ppo/test.py +++ /dev/null @@ -1,86 +0,0 @@ - -import copy -import time -from typing import Any, Dict, Tuple, List -import threading -import chex -import flax -import gym.vector -import gym.vector.async_vector_env -import hydra -import jax -import jax.numpy as jnp -import numpy as np -import optax -import queue -from collections import deque -from colorama import Fore, Style -from flax.core.frozen_dict import FrozenDict -from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState -from rich.pretty import pprint - -#from mava.evaluator import make_eval_fns -from mava.networks import FeedForwardActor as Actor -from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition #todo: change this -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation -from mava.utils import make_env as environments -from mava.utils.checkpointing import Checkpointer -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) -from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps -from mava.utils.training import make_learning_rate -from mava.wrappers.episode_metrics import get_final_step_metrics -from flax import linen as nn -import gym -import rware -import lbforaging -from mava.wrappers import GymRwareWrapper, GymRecordEpisodeMetrics, _multiagent_worker_shared_memory, GymAgentIDWrapper, GymLBFWrapper -@hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") -def hydra_entry_point(cfg: DictConfig) -> float: - """Experiment entry point.""" - # Allow dynamic attributes. - - - OmegaConf.set_struct(cfg, False) - def f(): - base = gym.make(cfg.env.scenario) - base = GymLBFWrapper(base, cfg.env.use_individual_rewards, True) - base = GymAgentIDWrapper(base) - return GymRecordEpisodeMetrics(base) - - base = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names - [ - lambda: f() - for _ in range(3) - ], - worker=_multiagent_worker_shared_memory - ) - base.reset() - n = 0 - done = False - r = [0] * 3 - while not done: - n+= 1 - agents_view, reward, terminated, truncated, info = base.step([r, r]) - print(terminated, truncated) - done = np.logical_or(terminated, truncated).all() - print(n, done) - #metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) - base.close() - print(done) - - - #print(b) - #r = 1+1 - # Create a sample input - #env = gym.make(cfg.env.scenario) - #env.reset() - #a = env.step(jnp.ones((4))) - -hydra_entry_point() diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index b329241d9..dd77105a9 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -12,23 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
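Editor's sketch: the deleted smoke-test script above exercises the vectorised Gym pipeline, where each worker process builds a fully wrapped multi-agent environment and `_multiagent_worker_shared_memory` replaces Gym's default worker so tuple observations and per-agent dones survive the shared-memory round trip. A rough, hedged reconstruction of that construction follows; the scenario id and agent count are assumptions, and the wrapper signatures are taken from this patch series.

import gym
import numpy as np
import rware  # noqa: F401  (assumed dependency; registers the rware-* scenarios with gym)

from mava.wrappers import (
    GymAgentIDWrapper,
    GymRecordEpisodeMetrics,
    GymRwareWrapper,
    _multiagent_worker_shared_memory,
)

NUM_ENVS, NUM_AGENTS = 4, 2

def make_single_env() -> gym.Env:
    env = gym.make("rware-tiny-2ag-v1")  # assumed scenario id; normally config.env.scenario
    env = GymRwareWrapper(env, use_individual_rewards=True, add_global_state=False)
    env = GymAgentIDWrapper(env)
    return GymRecordEpisodeMetrics(env)

envs = gym.vector.AsyncVectorEnv(
    [make_single_env for _ in range(NUM_ENVS)],
    worker=_multiagent_worker_shared_memory,  # per-agent aware replacement worker
)

obs, info = envs.reset()
for _ in range(8):
    # Actions are built as (num_envs, num_agents) and passed as (num_agents, num_envs),
    # mirroring the swapaxes(0, 1) convention used by the Sebulba rollout and evaluator.
    actions = np.zeros((NUM_ENVS, NUM_AGENTS), dtype=np.int64)
    obs, reward, terminated, truncated, info = envs.step(actions.swapaxes(0, 1))
envs.close()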
+import sys import warnings -from typing import Dict, Tuple, Optional +from typing import Any, Callable, Dict, Optional, Tuple import gym import numpy as np -from numpy.typing import NDArray - from gym import spaces from gym.vector.utils import write_to_shared_memory -import sys +from numpy.typing import NDArray # Filter out the warnings warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") -class GymGenericWrapper(gym.Wrapper): - """Wrapper for rware gym environments""" +class GymRwareWrapper(gym.Wrapper): + """Wrapper for rware gym environments.""" def __init__( self, @@ -37,7 +36,6 @@ def __init__( add_global_state: bool = False, ): """Initialize the gym wrapper - Args: env (gym.env): gym env instance. use_individual_rewards (bool, optional): Use individual or group rewards. @@ -45,30 +43,26 @@ def __init__( add_global_state (bool, optional) : Create global observations. Defaults to False. """ super().__init__(env) - self._env = env + self._env = env self.use_individual_rewards = use_individual_rewards - self.add_global_state = add_global_state + self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) - self.num_actions = self._env.action_space[ - 0 - ].n - - def reset( - self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple: - + self.num_actions = self._env.action_space[0].n + + def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: + if seed is not None: self.env.seed(seed) - - agents_view, info = self._env.reset() + + agents_view, info = self._env.reset() info = {"actions_mask": self.get_actions_mask(info)} if self.add_global_state: info["global_obs"] = self.get_global_obs(agents_view) - + return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple: + def step(self, actions: NDArray) -> Tuple: agents_view, reward, terminated, truncated, info = self._env.step(actions) @@ -80,7 +74,7 @@ def step(self, actions: NDArray) -> Tuple: reward = np.array(reward) else: reward = np.array([np.array(reward).mean()] * self.num_agents) - + return agents_view, reward, terminated, truncated, info def get_actions_mask(self, info: Dict) -> NDArray: @@ -88,13 +82,9 @@ def get_actions_mask(self, info: Dict) -> NDArray: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - def get_global_obs(self, obs: NDArray): + def get_global_obs(self, obs: NDArray) -> NDArray: global_obs = np.concatenate(obs, axis=0) return np.tile(global_obs, (self.num_agents, 1)) - - - - class GymRecordEpisodeMetrics(gym.Wrapper): @@ -117,14 +107,14 @@ def reset(self) -> Tuple: "episode_length": self.running_count_episode_length, "is_terminal_step": True, } - + # Reset the metrics self.running_count_episode_return = 0.0 self.running_count_episode_length = 0 - + if "won_episode" in info: metrics["won_episode"] = info["won_episode"] - + info["metrics"] = metrics return agents_view, info @@ -140,17 +130,18 @@ def step(self, actions: NDArray) -> Tuple: metrics = { "episode_return": self.running_count_episode_return, "episode_length": self.running_count_episode_length, - "is_terminal_step": False, # We handle the True case in the reset function since this gets overwritten + "is_terminal_step": False, } if "won_episode" in info: metrics["won_episode"] = info["won_episode"] - + info["metrics"] = metrics - + return agents_view, reward, terminated, truncated, info - + + class GymAgentIDWrapper(gym.Wrapper): - """Add onehot agent IDs to observation.""" + """Add one 
hot agent IDs to observation.""" def __init__(self, env: gym.Env): super().__init__(env) @@ -164,7 +155,9 @@ def __init__(self, env: gym.Env): observation_space.shape, ) _new_obs_shape = (_obs_shape[0] + self.env.num_agents,) - _observation_boxs = [spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype)] * self.env.num_agents + _observation_boxs = [ + spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype) + ] * self.env.num_agents self.observation_space = spaces.Tuple(_observation_boxs) def reset(self) -> Tuple[np.ndarray, Dict]: @@ -178,9 +171,18 @@ def step(self, action: list) -> Tuple[np.ndarray, float, bool, bool, Dict]: obs, reward, terminated, truncated, info = self.env.step(action) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, reward, terminated, truncated, info - -def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): + +# Copied form https://github.com/openai/gym/blob/master/gym/vector/async_vector_env.py +# Modified to work with multiple agents +def _multiagent_worker_shared_memory( # noqa: CCR001 + index: int, + env_fn: Callable[[], Any], + pipe: Any, + parent_pipe: Any, + shared_memory: Any, + error_queue: Any, +) -> None: assert shared_memory is not None env = env_fn() observation_space = env.observation_space @@ -190,9 +192,7 @@ def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_me command, data = pipe.recv() if command == "reset": observation, info = env.reset(**data) - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) + write_to_shared_memory(observation_space, index, observation, shared_memory) pipe.send(((None, info), True)) elif command == "step": @@ -203,14 +203,13 @@ def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_me truncated, info, ) = env.step(data) + # Handel the dones across all of envs and agents if np.logical_or(terminated, truncated).all(): old_observation, old_info = observation, info observation, info = env.reset() info["final_observation"] = old_observation info["final_info"] = old_info - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) + write_to_shared_memory(observation_space, index, observation, shared_memory) pipe.send(((None, reward, terminated, truncated, info), True)) elif command == "seed": env.seed(data) @@ -235,9 +234,7 @@ def _multiagent_worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_me setattr(env, name, value) pipe.send((None, True)) elif command == "_check_spaces": - pipe.send( - ((data[0] == observation_space, data[1] == env.action_space), True) - ) + pipe.send(((data[0] == observation_space, data[1] == env.action_space), True)) else: raise RuntimeError( f"Received unknown command `{command}`. 
Must " From e5dd71bf35c22df29a58e0267597fbf58d254040 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 10 Jul 2024 15:18:57 +0100 Subject: [PATCH 034/125] chore: pre-commits --- mava/configs/arch/sebulba.yaml | 1 - mava/evaluator.py | 167 +++++++------- mava/systems/anakin/ppo/ff_ippo.py | 4 +- mava/systems/anakin/ppo/ff_mappo.py | 4 +- mava/systems/sebulba/ppo/ff_ippo.py | 327 ++++++++++++++++----------- mava/types.py | 7 +- mava/utils/make_env.py | 16 +- mava/utils/total_timestep_checker.py | 4 +- mava/wrappers/__init__.py | 7 +- mava/wrappers/episode_metrics.py | 2 +- mava/wrappers/gym.py | 2 +- 11 files changed, 310 insertions(+), 231 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index fd555f71e..b6a0a9699 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -15,4 +15,3 @@ absolute_metric: True # Whether the absolute metric should be computed. For more n_threads_per_executor: 1 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices - diff --git a/mava/evaluator.py b/mava/evaluator.py index ca0c8c9a7..984a42377 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import chex import flax.linen as nn import jax import jax.numpy as jnp +import numpy as np from flax.core.frozen_dict import FrozenDict from jumanji.env import Environment from omegaconf import DictConfig @@ -27,13 +28,13 @@ EvalFn, EvalState, ExperimentOutput, + Observation, RecActorApply, RNNEvalState, + RNNObservation, + SebulbaEvalFn, ) -from mava.types import Observation - -import numpy as np def get_anakin_ff_evaluator_fn( env: Environment, @@ -348,7 +349,7 @@ def get_sebulba_ff_evaluator_fn( apply_fn: ActorApply, config: DictConfig, log_win_rate: bool = False, -) -> EvalFn: +) -> SebulbaEvalFn: """Get the evaluator function for feedforward networks. Args: @@ -356,63 +357,69 @@ def get_sebulba_ff_evaluator_fn( apply_fn (callable): Network forward pass method. config (dict): Experiment configuration. """ + @jax.jit - def get_action( #todo explicetly put these on the learner? they should already be there + def get_action( # todo explicetly put these on the learner? they should already be there params: FrozenDict, observation: Observation, key: chex.PRNGKey, - ) -> Tuple: + ) -> chex.Array: """Get action.""" - + pi = apply_fn(params, observation) - + if config.arch.evaluation_greedy: action = pi.mode() else: action = pi.sample(seed=key) return action - def eval_episodes(params: FrozenDict, key : chex.PRNGKey) -> Dict: - - - + + def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: + obs, info = env.reset() - dones = np.zeros(env.num_envs) # todo: jnp or np? 
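Editor's sketch: the `get_action` helper above switches between the policy distribution's mode and a sample depending on `evaluation_greedy`. Here is a small standalone illustration of that switch; the use of `distrax.Categorical` is an assumption about the distribution type behind `apply_fn`, made only so the snippet runs on its own.

import distrax
import jax
import jax.numpy as jnp

def select_action(logits, key, evaluation_greedy: bool):
    """Greedy actions for evaluation, stochastic actions otherwise."""
    pi = distrax.Categorical(logits=logits)
    return pi.mode() if evaluation_greedy else pi.sample(seed=key)

key = jax.random.PRNGKey(0)
logits = jnp.array([[2.0, 0.1, -1.0], [0.0, 0.0, 3.0]])    # (num_agents, num_actions)
print(select_action(logits, key, evaluation_greedy=True))   # argmax per agent -> [0 2]
print(select_action(logits, key, evaluation_greedy=False))  # one sample per agent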
- eval_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) - + dones = np.full(env.num_envs, False) + eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) + while not dones.all(): - + key, policy_key = jax.random.split(key) - - obs = jax.device_put(jnp.stack(obs, axis = 1)) - action_mask = jax.device_put(np.stack(info["actions_mask"]) ) - + + obs = jax.device_put(jnp.stack(obs, axis=1)) + action_mask = jax.device_put(np.stack(info["actions_mask"])) + actions = get_action(params, Observation(obs, action_mask), policy_key) cpu_action = jax.device_get(actions) - obs, reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) - - next_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) - + obs, reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0, 1)) + + next_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) + next_dones = next_metrics["is_terminal_step"] - - update_metric = lambda old_metric, new_metric : np.where(np.logical_and(next_dones, dones == False), new_metric, old_metric) - eval_metrics = jax.tree_map(update_metric, eval_metrics, next_metrics) - - dones = np.logical_or(dones, next_dones) + + update_flags = np.logical_and(next_dones, np.invert(dones)) + + update_metrics = lambda new_metric, old_metric, update_flags=update_flags: np.where( + (update_flags), new_metric, old_metric + ) + + eval_metrics = jax.tree_map(update_metrics, next_metrics, eval_metrics) + + dones = np.logical_or(dones, next_dones) eval_metrics.pop("is_terminal_step") return eval_metrics - + return eval_episodes + def get_sebulba_rnn_evaluator_fn( env: Environment, apply_fn: RecActorApply, config: DictConfig, scanned_rnn: nn.Module, log_win_rate: bool = False, -) -> EvalFn: +) -> SebulbaEvalFn: """Get the evaluator function for feedforward networks. Args: @@ -420,76 +427,82 @@ def get_sebulba_rnn_evaluator_fn( apply_fn (callable): Network forward pass method. config (dict): Experiment configuration. """ + @jax.jit - def get_action( #todo explicetly put these on the learner? they should already be there + def get_action( # todo explicetly put these on the learner? 
they should already be there params: FrozenDict, - observation: Observation, - hstate : chex.Array, + observation: RNNObservation, + hstate: chex.Array, key: chex.PRNGKey, - ) -> Tuple: + ) -> Tuple[chex.Array, chex.Array]: """Get action.""" - + hstate, pi = apply_fn(params, hstate, observation) - + if config.arch.evaluation_greedy: action = pi.mode() else: action = pi.sample(seed=key) return action, hstate - def eval_episodes(params: FrozenDict, key : chex.PRNGKey) -> Dict: - - - + + def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: + obs, info = env.reset() - eval_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) - + eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) + hstate = scanned_rnn.initialize_carry( - (env.num_envs, config.system.num_agents), config.network.hidden_state_dim + (env.num_envs, config.system.num_agents), config.network.hidden_state_dim ) - - dones = jnp.zeros((env.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) - + + dones = jnp.full((env.num_envs, config.system.num_agents), False) + while not dones.all(): - + key, policy_key = jax.random.split(key) - - obs = jax.device_put(jnp.stack(obs, axis = 1)) - action_mask = jax.device_put(np.stack(info["actions_mask"]) ) - - obs, action_mask, dones = jax.tree_map(lambda x : x[jnp.newaxis, :], (obs, action_mask, dones)) - - - actions, hstate = get_action(params, (Observation(obs, action_mask), dones), hstate, policy_key) + + obs = jax.device_put(jnp.stack(obs, axis=1)) + action_mask = jax.device_put(np.stack(info["actions_mask"])) + + obs, action_mask, dones = jax.tree_map( + lambda x: x[jnp.newaxis, :], (obs, action_mask, dones) + ) + + actions, hstate = get_action( + params, (Observation(obs, action_mask), dones), hstate, policy_key + ) cpu_action = jax.device_get(actions) - obs, reward, terminated, truncated, info = env.step(cpu_action[0].swapaxes(0,1)) - - next_metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) - + obs, reward, terminated, truncated, info = env.step(cpu_action[0].swapaxes(0, 1)) + + next_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) + next_dones = np.logical_or(terminated, truncated) - - per_env_done = np.all(np.logical_and(next_dones, dones[0] == False),axis = 1) - - update_metric = lambda old_metric, new_metric : np.where(per_env_done, new_metric, old_metric) - eval_metrics = jax.tree_map(update_metric, eval_metrics, next_metrics) - - dones = np.logical_or(dones, next_dones) + + update_flags = np.all(np.logical_and(next_dones, np.invert(dones[0])), axis=1) + + update_metrics = lambda new_metric, old_metric, update_flags=update_flags: np.where( + (update_flags), new_metric, old_metric + ) + + eval_metrics = jax.tree_map(update_metrics, next_metrics, eval_metrics) + + dones = np.logical_or(dones, next_dones) eval_metrics.pop("is_terminal_step") return eval_metrics - + return eval_episodes def make_sebulba_eval_fns( - eval_env_fn: callable, + eval_env_fn: Callable, network_apply_fn: Union[ActorApply, RecActorApply], config: DictConfig, - add_global_state : bool = False, + add_global_state: bool = False, use_recurrent_net: bool = False, scanned_rnn: Optional[nn.Module] = None, -) -> Tuple[EvalFn, EvalFn]: +) -> Tuple[SebulbaEvalFn, SebulbaEvalFn]: """Initialize evaluator functions for reinforcement learning. Args: @@ -501,14 +514,16 @@ def make_sebulba_eval_fns( Required if `use_recurrent_net` is True. Defaults to None. 
Returns: - Tuple[EvalFn, EvalFn]: A tuple of two evaluation functions: + Tuple[SebulbaEvalFn, SebulbaEvalFn]: A tuple of two evaluation functions: one for use during training and one for absolute metrics. Raises: AssertionError: If `use_recurrent_net` is True but `scanned_rnn` is not provided. """ - eval_env, absolute_eval_env = eval_env_fn(config, config.arch.num_eval_episodes, add_global_state = add_global_state), eval_env_fn(config, config.arch.num_eval_episodes * 10, add_global_state = add_global_state) - + eval_env, absolute_eval_env = eval_env_fn( + config, config.arch.num_eval_episodes, add_global_state=add_global_state + ), eval_env_fn(config, config.arch.num_eval_episodes * 10, add_global_state=add_global_state) + # Check if win rate is required for evaluation. log_win_rate = config.env.log_win_rate # Vmap it over number of agents and create evaluator_fn. @@ -536,4 +551,4 @@ def make_sebulba_eval_fns( absolute_eval_env, network_apply_fn, config, log_win_rate # type: ignore ) - return evaluator, absolute_metric_evaluator \ No newline at end of file + return evaluator, absolute_metric_evaluator diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py index f0803de4d..408bdf36d 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/anakin/ppo/ff_ippo.py @@ -462,7 +462,9 @@ def run_experiment(_config: DictConfig) -> float: # Setup evaluator. # One key per device for evaluation. eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor_network.apply, config) + evaluator, absolute_metric_evaluator = make_anakin_eval_fns( + eval_env, actor_network.apply, config + ) # Calculate total timesteps. config = anakin_check_total_timesteps(config) diff --git a/mava/systems/anakin/ppo/ff_mappo.py b/mava/systems/anakin/ppo/ff_mappo.py index 90fad5767..93d3f2c0b 100644 --- a/mava/systems/anakin/ppo/ff_mappo.py +++ b/mava/systems/anakin/ppo/ff_mappo.py @@ -459,7 +459,9 @@ def run_experiment(_config: DictConfig) -> float: # Setup evaluator. # One key per device for evaluation. eval_keys = jax.random.split(key_e, n_devices) - evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor_network.apply, config) + evaluator, absolute_metric_evaluator = make_anakin_eval_fns( + eval_env, actor_network.apply, config + ) # Calculate total timesteps. config = anakin_check_total_timesteps(config) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 153f9e4a9..cf598770f 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -13,9 +13,12 @@ # limitations under the License. 
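Editor's sketch: the Sebulba evaluators defined above run every vectorised evaluation environment until each has finished at least one episode and keep only the metrics from each environment's first completed episode. A small NumPy illustration of that masking rule, with made-up metric values:

import numpy as np

num_envs = 4
dones = np.full(num_envs, False)            # has this env already finished an episode?
eval_returns = np.zeros(num_envs)           # metrics frozen at first termination

# Two fake steps of "is_terminal_step" flags and running episode returns.
steps = [
    (np.array([True, False, False, True]), np.array([1.0, 0.3, 0.2, 5.0])),
    (np.array([True, True, False, False]), np.array([9.9, 2.0, 0.4, 9.9])),
]

for next_dones, next_returns in steps:
    update_flags = np.logical_and(next_dones, np.invert(dones))   # first termination only
    eval_returns = np.where(update_flags, next_returns, eval_returns)
    dones = np.logical_or(dones, next_dones)

print(eval_returns)   # [1. 2. 0. 5.] : env 0 keeps its first return (1.0), not the later 9.9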
import copy -import time -from typing import Any, Dict, Tuple, List +import queue import threading +import time +from collections import deque +from typing import Any, Dict, List, Tuple + import chex import flax import hydra @@ -24,46 +27,47 @@ import jax.numpy as jnp import numpy as np import optax -import queue -from collections import deque from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict from omegaconf import DictConfig, OmegaConf from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_sebulba_eval_fns as make_eval_fns +from mava.evaluator import make_sebulba_eval_fns as make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, Observation +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.types import ( + ActorApply, + CriticApply, + ExperimentOutput, + Observation, + SebulbaLearnerFn, +) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) +from mava.utils.jax_utils import merge_leading_dims, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.total_timestep_checker import sebulba_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics -def rollout( +def rollout( key: chex.PRNGKey, config: DictConfig, rollout_queue: queue.Queue, params_queue: queue.Queue, apply_fns: Tuple, learner_devices: List, - actor_device_id : int): - - #setup + actor_device_id: int, +) -> None: + + # setup env = environments.make_gym_env(config, config.arch.num_envs) current_actor_device = jax.devices()[actor_device_id] actor_apply_fn, critic_apply_fn = apply_fns - + # Define the util functions: select action function and prepare data to share it with learner. 
@jax.jit def get_action_and_value( @@ -73,8 +77,8 @@ def get_action_and_value( ) -> Tuple: """Get action and value.""" key, subkey = jax.random.split(key) - - actor_policy = actor_apply_fn(params.actor_params, observation) # TODO: check vmapiing + + actor_policy = actor_apply_fn(params.actor_params, observation) # TODO: check vmapiing action = actor_policy.sample(seed=subkey) log_prob = actor_policy.log_prob(action) @@ -85,35 +89,43 @@ def get_action_and_value( params_queue_get_time: deque = deque(maxlen=1) rollout_time: deque = deque(maxlen=1) rollout_queue_put_time: deque = deque(maxlen=1) - - next_obs , info = env.reset() + + next_obs, info = env.reset() next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) - - move_to_device = lambda x : jax.device_put(x, device = current_actor_device) + + move_to_device = lambda x: jax.device_put(x, device=current_actor_device) + + shard_split_payload = lambda x, axis: jax.device_put_sharded( + jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices + ) # Loop till the learner has finished training - for update in range(config.system.num_updates): + for _update in range(config.system.num_updates): inference_time: float = 0 storage_time: float = 0 env_send_time: float = 0 - + # Get the latest parameters from the learner params_queue_get_time_start = time.time() params = params_queue.get() params_queue_get_time.append(time.time() - params_queue_get_time_start) - - # Rollout + + # Rollout rollout_time_start = time.time() storage: List = [] # Loop over the rollout length for _ in range(0, config.system.rollout_length): - + # Cached for transition - cached_next_obs = move_to_device(jnp.stack(next_obs, axis = 1)) # (num_envs, num_agents, ...) - cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) - cashed_action_mask = move_to_device(np.stack(info["actions_mask"])) # (num_envs, num_agents, num_actions) - + cached_next_obs = move_to_device( + jnp.stack(next_obs, axis=1) + ) # (num_envs, num_agents, ...) 
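Editor's sketch: `shard_split_payload` above splits an array along a chosen axis into one chunk per learner device and places each chunk on its device in a single transfer. A standalone version of the same idea is below; on a CPU-only host `jax.devices()` typically returns a single device, so the split degenerates to one chunk.

import jax
import jax.numpy as jnp

learner_devices = jax.devices()                     # e.g. the local accelerator cores

def shard_split_payload(x, axis: int):
    """Split x into len(learner_devices) chunks along `axis`, one chunk per device."""
    shards = jnp.split(x, len(learner_devices), axis=axis)
    return jax.device_put_sharded(shards, devices=learner_devices)

# Dummy (num_envs, num_agents) dones; num_envs must divide evenly across devices.
dones = jnp.zeros((8 * len(learner_devices), 3), dtype=bool)
sharded_dones = shard_split_payload(dones, axis=0)
print(sharded_dones.shape)   # (len(learner_devices), 8, 3): one (8, 3) block per device

The stacked leading device axis is what lines up with the pmapped learner later in the file.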
+ cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) + cashed_action_mask = move_to_device( + np.stack(info["actions_mask"]) + ) # (num_envs, num_agents, num_actions) + full_observation = Observation(cached_next_obs, cashed_action_mask) # Get action and value inference_time_start = time.time() @@ -123,20 +135,21 @@ def get_action_and_value( value, key, ) = get_action_and_value(params, full_observation, key) - - + # Step the environment inference_time += time.time() - inference_time_start env_send_time_start = time.time() cpu_action = jax.device_get(action) - next_obs, next_reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0,1)) # (num_env, num_agents) --> (num_agents, num_env) + next_obs, next_reward, terminated, truncated, info = env.step( + cpu_action.swapaxes(0, 1) + ) # (num_env, num_agents) --> (num_agents, num_env) env_send_time += time.time() - env_send_time_start - + # Prepare the data storage_time_start = time.time() - next_dones = np.logical_or(terminated, truncated) - metrics = jax.tree_map(lambda *x : jnp.asarray(x), *info["metrics"]) # Stack the metrics - + next_dones = np.logical_or(terminated, truncated) + metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) # Stack the metrics + # Append data to storage storage.append( PPOTransition( @@ -146,68 +159,75 @@ def get_action_and_value( reward=next_reward, log_prob=log_prob, obs=full_observation, - info=metrics, - ) + info=metrics, + ) ) storage_time += time.time() - storage_time_start - rollout_time.append(time.time() - rollout_time_start) - + rollout_time.append(time.time() - rollout_time_start) + parse_timer = time.time() - - # Prepare data to share with learner - #[PPOTransition() * rollout_len] --> PPOTransition[done = (rollout_len, num_envs, num_agents), action = (rollout_len, num_envs, num_agents, num_actions), ...] - stacked_storage = jax.tree_map( lambda *xs : jnp.stack(xs), *storage) - + + # Prepare data to share with learner + # [PPOTransition() * rollout_len] --> PPOTransition[done=(rollout_len, num_envs, num_agents) + # , action=(rollout_len, num_envs, num_agents, num_actions), ...] + stacked_storage = jax.tree_map(lambda *xs: jnp.stack(xs), *storage) # Split the arrays over the different learner_devices on the num_envs axis - shard_split_payload= lambda x, axis : jax.device_put_sharded(jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices) - sharded_storage = jax.tree_map(lambda x : shard_split_payload(x, 1) , stacked_storage) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) - + sharded_storage = jax.tree_map( + lambda x: shard_split_payload(x, 1), stacked_storage + ) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) + # (num_learner_devices, num_envs, num_agents, ...) 
- sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis = 1), 0) - sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) + sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis=1), 0) + sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) sharded_next_done = shard_split_payload(next_dones, 0) - + # Pack the obs and action mask payload_obs = Observation(sharded_next_obs, sharded_next_action_mask) # For debugging - speed_info = { + speed_info = { # noqa F841 "rollout_time": np.mean(rollout_time), "params_queue_get_time": np.mean(params_queue_get_time), "action_inference": inference_time, "storage_time": storage_time, "env_step_time": env_send_time, - "rollout_queue_put_time": np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0, - "parse_time" : time.time() - parse_timer, - } - #print(speed_info) - + "rollout_queue_put_time": ( + np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0 + ), + "parse_time": time.time() - parse_timer, + } + payload = ( sharded_storage, payload_obs, sharded_next_done, ) - + # Put data in the rollout queue to share it with the learner rollout_queue_put_time_start = time.time() rollout_queue.put(payload) rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) - + def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> SebulbaLearnerFn[LearnerState, PPOTransition]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: Observation, last_dones : chex.Array) -> Tuple[LearnerState, Tuple]: + def _update_step( + learner_state: LearnerState, + traj_batch: PPOTransition, + last_obs: Observation, + last_dones: chex.Array, + ) -> Tuple[LearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -225,7 +245,7 @@ def _update_step(learner_state: LearnerState, traj_batch : PPOTransition, last_o _ (Any): The current metrics info. """ - def _calculate_gae( #todo: lake sure this is appropriate + def _calculate_gae( # todo: lake sure this is appropriate traj_batch: PPOTransition, last_val: chex.Array, last_done: chex.Array ) -> Tuple[chex.Array, chex.Array]: def _get_advantages( @@ -246,7 +266,7 @@ def _get_advantages( unroll=16, ) return advantages, advantages + traj_batch.value - + # CALCULATE ADVANTAGE params, opt_states, key, _, _ = learner_state last_val = critic_apply_fn(params.critic_params, last_obs) @@ -337,7 +357,8 @@ def _critic_loss_fn( # available at https://tinyurl.com/26tdzs5x # pmean over devices. actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="device" #todo: pmean over learner devices not all + (actor_grads, actor_loss_info), + axis_name="device", # todo: pmean over learner devices not all ) # pmean over devices. 
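Editor's sketch: the scan-based `_calculate_gae` above implements the usual Generalised Advantage Estimation recursion. A plain NumPy reference version can be handy for sanity-checking it; the (done, value, reward) convention below is assumed to match the Anakin PPO systems, since the scan body is elided in this hunk.

import numpy as np

def gae_reference(rewards, values, dones, last_value, last_done, gamma=0.99, gae_lambda=0.95):
    """Reverse-loop GAE: delta_t = r_t + gamma * V_{t+1} * (1 - d_{t+1}) - V_t."""
    T = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    gae, next_value, next_done = 0.0, last_value, last_done
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * next_value * (1.0 - next_done) - values[t]
        gae = delta + gamma * gae_lambda * (1.0 - next_done) * gae
        advantages[t] = gae
        next_value, next_done = values[t], dones[t]
    return advantages, advantages + values   # (advantages, value targets)

rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.5, 0.4, 0.6])
dones = np.array([0.0, 0.0, 0.0])
adv, targets = gae_reference(rewards, values, dones, last_value=0.7, last_done=0.0)
print(adv, targets)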
@@ -376,7 +397,12 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) # SHUFFLE MINIBATCHES - batch_size = config.system.rollout_length * (config.arch.num_envs // len(config.arch.learner_device_ids)) * len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor + batch_size = ( + config.system.rollout_length + * (config.arch.num_envs // len(config.arch.learner_device_ids)) + * len(config.arch.executor_device_ids) + * config.arch.n_threads_per_executor + ) permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) @@ -406,7 +432,12 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs: chex.Array, last_dones : chex.Array) -> ExperimentOutput[LearnerState]: + def learner_fn( + learner_state: LearnerState, + traj_batch: PPOTransition, + last_obs: chex.Array, + last_dones: chex.Array, + ) -> ExperimentOutput[LearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -423,7 +454,9 @@ def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs """ # todo: add update_batch_size - learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch , last_obs, last_dones) + learner_state, (episode_info, loss_info) = _update_step( + learner_state, traj_batch, last_obs, last_dones + ) return ExperimentOutput( learner_state=learner_state, @@ -436,15 +469,17 @@ def learner_fn(learner_state: LearnerState, traj_batch : PPOTransition, last_obs def learner_setup( keys: chex.Array, config: DictConfig, learner_devices: List -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[ + SebulbaLearnerFn[LearnerState, PPOTransition], Tuple[ActorApply, CriticApply], LearnerState +]: """Initialise learner_fn, network, optimiser, environment and states.""" - - #create temporory envoirnments. - env = environments.make_gym_env(config, config.arch.num_envs) + + # create temporory envoirnments. + env = environments.make_gym_env(config, config.arch.num_envs) # Get number of agents and actions. action_space = env.single_action_space config.system.num_agents = len(action_space) - config.system.num_actions = action_space[0].n + config.system.num_actions = action_space[0].n # PRNG keys. key, actor_net_key, critic_net_key = keys @@ -493,7 +528,7 @@ def learner_setup( # Get batched iterated update and replicate it to pmap it over learner cores. learn = get_learner_fn(apply_fns, update_fns, config) - learn = jax.pmap(learn, axis_name="device", devices = learner_devices) + learn = jax.pmap(learn, axis_name="device", devices=learner_devices) # Load model from checkpoint if specified. if config.logger.checkpointing.load_model: @@ -522,49 +557,54 @@ def learner_setup( return learn, apply_fns, init_learner_state -def run_experiment(_config: DictConfig) -> float: +def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 """Runs experiment.""" config = copy.deepcopy(_config) - devices = jax.devices() + devices = jax.devices() learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] # PRNG keys. 
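Editor's sketch: the epoch update earlier in this file merges the rollout and environment axes into one batch axis, shuffles it with a fresh permutation, and carves the result into minibatches. A toy version of that flatten-shuffle-split pattern, using a dict batch in place of `PPOTransition`; the exact minibatch reshape is an assumption based on the `num_minibatches` divisibility check elsewhere in the file.

import jax
import jax.numpy as jnp

rollout_len, num_envs, num_agents, num_minibatches = 4, 6, 3, 2
batch = {
    "obs": jnp.arange(rollout_len * num_envs * num_agents * 5.0).reshape(
        rollout_len, num_envs, num_agents, 5
    ),
    "advantages": jnp.ones((rollout_len, num_envs, num_agents)),
}

batch_size = rollout_len * num_envs
permutation = jax.random.permutation(jax.random.PRNGKey(0), batch_size)

# Merge (rollout_len, num_envs) into one leading axis, shuffle it, then split it.
flat = jax.tree_util.tree_map(lambda x: x.reshape(batch_size, *x.shape[2:]), batch)
shuffled = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), flat)
minibatches = jax.tree_util.tree_map(
    lambda x: x.reshape(num_minibatches, -1, *x.shape[1:]), shuffled
)

print(minibatches["obs"].shape)          # (2, 12, 3, 5): minibatches x samples x agents x feats
print(minibatches["advantages"].shape)   # (2, 12, 3)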
key, key_e, actor_net_key, critic_net_key = jax.random.split( jax.random.PRNGKey(config.system.seed), num=4 ) - + # Sanity check of config assert ( config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "The number of environments must to be divisible by the number of learners " - + ), "The number of environments must to be divisible by the number of learners " + assert ( int(config.arch.num_envs / len(config.arch.learner_device_ids)) * config.arch.n_threads_per_executor % config.system.num_minibatches == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" + ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" - # Setup learner. - learn, apply_fns , learner_state = learner_setup( - (key ,actor_net_key, critic_net_key), config, learner_devices + learn, apply_fns, learner_state = learner_setup( + (key, actor_net_key, critic_net_key), config, learner_devices ) # Setup evaluator. # One key per device for evaluation. - evaluator, absolute_metric_evaluator = make_eval_fns(environments.make_gym_env, apply_fns[0], config) #todo: make this more generic + evaluator, absolute_metric_evaluator = make_eval_fns( + environments.make_gym_env, apply_fns[0], config + ) # todo: make this more generic # Calculate total timesteps. - config = sebulba_check_total_timesteps(config) + config = sebulba_check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." # Calculate number of updates per evaluation. - config.system.num_updates_per_eval, remaining_updates = divmod(config.system.num_updates , config.arch.num_evaluation) - config.arch.num_evaluation += (remaining_updates != 0) # Add an evaluation step if the num_updates is not a multiple of num_evaluation + config.system.num_updates_per_eval, remaining_updates = divmod( + config.system.num_updates, config.arch.num_evaluation + ) + config.arch.num_evaluation += ( + remaining_updates != 0 + ) # Add an evaluation step if the num_updates is not a multiple of num_evaluation steps_per_rollout = ( len(config.arch.executor_device_ids) * config.arch.n_threads_per_executor @@ -587,18 +627,18 @@ def run_experiment(_config: DictConfig) -> float: model_name=config.logger.system_name, **config.logger.checkpointing.save_args, # Checkpoint args ) - + # Executor setup and launch. unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) params_queues: List = [] rollout_queues: List = [] - for d_idx, d_id in enumerate( # Loop through each executor device + for _d_idx, d_id in enumerate( # Loop through each executor device config.arch.executor_device_ids ): # Replicate params per executor device device_params = jax.device_put(unreplicated_params, devices[d_id]) # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): + for _thread_id in range(config.arch.n_threads_per_executor): params_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1)) params_queues[-1].put(device_params) @@ -613,27 +653,30 @@ def run_experiment(_config: DictConfig) -> float: learner_devices, d_id, ), - ).start() #todo : Use a process instead of a thread? threads are limited by pything's GIL and they only run on a single core , processes have a bogger overhead (max num_env for optimal performance?) - - + ).start() + # Run experiment for the total number of updates. 
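Editor's sketch: the scheduling above divides `num_updates` into evaluation blocks and appends one extra, shorter block when the division is not exact. A quick worked example of that arithmetic (the numbers are made up):

num_updates, num_evaluation = 205, 10

num_updates_per_eval, remaining_updates = divmod(num_updates, num_evaluation)
num_evaluation += remaining_updates != 0

print(num_updates_per_eval, remaining_updates, num_evaluation)
# -> 20 5 11: ten full blocks of 20 updates, one final partial block of 5,
#    and therefore 11 evaluation points in total.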
max_episode_return = jnp.float32(0.0) best_params = None - for eval_step in range(config.arch.num_evaluation): + for eval_step in range(config.arch.num_evaluation): training_start_time = time.time() learner_speeds = [] rollout_times = [] - + episode_metrics = [] train_metrics = [] - - # Make sure that the - num_updates_in_eval = config.system.num_updates_per_eval if eval_step != config.arch.num_evaluation - 1 else remaining_updates - for update in range(num_updates_in_eval): + + # Make sure that the + num_updates_in_eval = ( + config.system.num_updates_per_eval + if eval_step != config.arch.num_evaluation - 1 + else remaining_updates + ) + for _update in range(num_updates_in_eval): sharded_storages = [] sharded_next_obss = [] sharded_next_dones = [] - + rollout_start_time = time.time() # Loop through each executor device for d_idx, _ in enumerate(config.arch.executor_device_ids): @@ -648,24 +691,28 @@ def run_experiment(_config: DictConfig) -> float: sharded_storages.append(sharded_storage) sharded_next_obss.append(sharded_next_obs) sharded_next_dones.append(sharded_next_done) - + rollout_times.append(time.time() - rollout_start_time) - - - # Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 2), *sharded_storages) - sharded_next_obss = jax.tree_map(lambda *x : jnp.concatenate(x, axis = 1), *sharded_next_obss) - sharded_next_dones = jnp.concatenate(sharded_next_dones, axis = 1) + # Concatinate the returned trajectories on the n_env axis + sharded_storages = jax.tree_map( + lambda *x: jnp.concatenate(x, axis=2), *sharded_storages + ) + sharded_next_obss = jax.tree_map( + lambda *x: jnp.concatenate(x, axis=1), *sharded_next_obss + ) + sharded_next_dones = jnp.concatenate(sharded_next_dones, axis=1) learner_start_time = time.time() - learner_output = learn(learner_state, sharded_storages, sharded_next_obss, sharded_next_dones) + learner_output = learn( + learner_state, sharded_storages, sharded_next_obss, sharded_next_dones + ) learner_speeds.append(time.time() - learner_start_time) - + # Stack the metrics episode_metrics.append(learner_output.episode_metrics) train_metrics.append(learner_output.train_metrics) - + # Send updated params to executors unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) for d_idx, d_id in enumerate(config.arch.executor_device_ids): @@ -675,28 +722,33 @@ def run_experiment(_config: DictConfig) -> float: device_params ) - - # Log the results of the training. elapsed_time = time.time() - training_start_time t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics = jax.tree_map(lambda *x : np.asarray(x), *episode_metrics) - episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - + episode_metrics = jax.tree_map(lambda *x: np.asarray(x), *episode_metrics) + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + # Separately log timesteps, actoring metrics and training metrics. 
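Editor's sketch: the learner loop above merges the payloads from all executor threads by concatenating leaf-by-leaf along the environment axis. A tiny illustration of what `jax.tree_map` does when it is given several trees at once; the real payloads also carry a leading learner-device axis, which is why the diff concatenates on axis 2, while this toy drops that axis and uses axis 1.

import jax
import jax.numpy as jnp

# Two per-thread payloads with the same structure: (rollout_len, num_envs, num_agents).
payload_a = {"done": jnp.zeros((5, 2, 3)), "reward": jnp.ones((5, 2, 3))}
payload_b = {"done": jnp.ones((5, 2, 3)), "reward": jnp.zeros((5, 2, 3))}

merged = jax.tree_map(lambda *leaves: jnp.concatenate(leaves, axis=1), payload_a, payload_b)
print(merged["done"].shape)   # (5, 4, 3): the env axis grows, rollout and agent axes are untouched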
- speed_info = {"total_time" : elapsed_time, "rollout_time" : np.sum(rollout_times), "learner_time" : np.sum(learner_speeds), "timestep" : t} - logger.log(speed_info , t, eval_step, LogEvent.MISC) + speed_info = { + "total_time": elapsed_time, + "rollout_time": np.sum(rollout_times), + "learner_time": np.sum(learner_speeds), + "timestep": t, + } + logger.log(speed_info, t, eval_step, LogEvent.MISC) if ep_completed: # only log episode metrics if an episode was completed in the rollout. - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - train_metrics = jax.tree_map(lambda *x : np.asarray(x), *train_metrics) + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + train_metrics = jax.tree_map(lambda *x: np.asarray(x), *train_metrics) logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) - # Evaluation on the learner + # Evaluation on the learner evaluation_start_timer = time.time() key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = evaluator(unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1 ), eval_key) - + episode_metrics = evaluator( + unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1), eval_key + ) + # Log the results of the evaluation. elapsed_time = time.time() - evaluation_start_timer episode_return = jnp.mean(episode_metrics["episode_return"]) @@ -704,7 +756,7 @@ def run_experiment(_config: DictConfig) -> float: steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) - + if save_checkpoint: # Save checkpoint of learner state checkpointer.save( @@ -712,15 +764,15 @@ def run_experiment(_config: DictConfig) -> float: unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state, 1), episode_return=episode_return, ) - + if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(learner_output.learner_state.params) + best_params = copy.deepcopy(learner_output.learner_state.params.actor_params) max_episode_return = episode_return - + # Update runner state to continue training. learner_state = learner_output.learner_state - - # Record the performance for the final evaluation run. + + # Record the performance for the final evaluation run. eval_performance = float(jnp.mean(episode_metrics[config.env.eval_metric])) # Measure absolute metric. @@ -728,11 +780,11 @@ def run_experiment(_config: DictConfig) -> float: start_time = time.time() key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params.actor_params, 1), eval_key) + episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params, 1), eval_key) elapsed_time = time.time() - start_time steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) - + t = int(steps_per_rollout * (eval_step + 1)) episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time logger.log(episode_metrics, t, eval_step, LogEvent.ABSOLUTE) @@ -743,8 +795,9 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance - -@hydra.main(config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2") +@hydra.main( + config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2" +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
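Editor's sketch: the block above snapshots the actor parameters whenever the mean evaluation return improves and, if the absolute metric is enabled, re-evaluates that snapshot at the end with ten times the evaluation episodes. A stripped-down version of that bookkeeping with a fake evaluator (all names and numbers here are placeholders):

import copy

import numpy as np

rng = np.random.default_rng(0)

def fake_evaluator(params: dict, num_episodes: int) -> float:
    """Stand-in for the real evaluator: returns a mean episode return."""
    return float(rng.normal(loc=params["quality"], scale=0.1, size=num_episodes).mean())

max_episode_return = -float("inf")
best_params = None
params = {"quality": 0.0}

for eval_step in range(5):
    params = {"quality": params["quality"] + rng.uniform(-0.2, 0.5)}   # pretend training progress
    episode_return = fake_evaluator(params, num_episodes=32)
    if max_episode_return <= episode_return:
        best_params = copy.deepcopy(params)        # snapshot the best actor params so far
        max_episode_return = episode_return

# "Absolute metric": re-evaluate the best snapshot with 10x the evaluation episodes.
absolute_metric = fake_evaluator(best_params, num_episodes=320)
print(max_episode_return, absolute_metric)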
@@ -759,5 +812,5 @@ def hydra_entry_point(cfg: DictConfig) -> float: if __name__ == "__main__": hydra_entry_point() -#learner_output.episode_metrics.keys() -#dict_keys(['episode_length', 'episode_return']) \ No newline at end of file +# learner_output.episode_metrics.keys() +# dict_keys(['episode_length', 'episode_return']) diff --git a/mava/types.py b/mava/types.py index c6a2cf6aa..02d2bae90 100644 --- a/mava/types.py +++ b/mava/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Optional +from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar import chex from flax.core.frozen_dict import FrozenDict @@ -81,6 +81,7 @@ class RNNEvalState(NamedTuple): # `MavaState` is the main type passed around in our systems. It is often used as a scan carry. # Types like: `EvalState` | `LearnerState` (mava/systems//types.py) are `MavaState`s. MavaState = TypeVar("MavaState") +MavaTransition = TypeVar("MavaTransition") class ExperimentOutput(NamedTuple, Generic[MavaState]): @@ -92,7 +93,11 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]] +SebulbaLearnerFn = Callable[ + [MavaState, MavaTransition, chex.Array, chex.Array], ExperimentOutput[MavaState] +] EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[MavaState]] +SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict] ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index df769d8c7..2330674f0 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -22,7 +22,6 @@ import jumanji import matrax from gigastep import ScenarioBuilder -import lbforaging from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment from jumanji.environments.routing.cleaner.generator import ( @@ -45,16 +44,16 @@ CleanerWrapper, ConnectorWrapper, GigastepWrapper, + GymAgentIDWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, - GymAgentIDWrapper, - _multiagent_worker_shared_memory, LbfWrapper, MabraxWrapper, MatraxWrapper, RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, + _multiagent_worker_shared_memory, ) # Registry mapping environment names to their generator and wrapper classes. @@ -211,7 +210,9 @@ def make_gigastep_env( def make_gym_env( - config: DictConfig, num_env : int, add_global_state: bool = False, + config: DictConfig, + num_env: int, + add_global_state: bool = False, ) -> Environment: # todo : create the appropriate annotation for the sync vector """ Create a Gym environment. 
@@ -238,11 +239,8 @@ def create_gym_env( return wrapped_env envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names - [ - lambda: create_gym_env(config, add_global_state) - for _ in range(num_env) - ], - worker=_multiagent_worker_shared_memory + [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], + worker=_multiagent_worker_shared_memory, ) return envs diff --git a/mava/utils/total_timestep_checker.py b/mava/utils/total_timestep_checker.py index fd90b7436..744451d1b 100644 --- a/mava/utils/total_timestep_checker.py +++ b/mava/utils/total_timestep_checker.py @@ -68,7 +68,7 @@ def sebulba_check_total_timesteps(config: DictConfig) -> DictConfig: // config.system.rollout_length // config.arch.num_envs // config.arch.n_threads_per_executor - // len(config.arch.executor_device_ids) + // len(config.arch.executor_device_ids) ) print( f"{Fore.RED}{Style.BRIGHT} Changing the number of updates " @@ -76,4 +76,4 @@ def sebulba_check_total_timesteps(config: DictConfig) -> DictConfig: + " for a specific number of updates, please set total_timesteps to None!" + f"{Style.RESET_ALL}" ) - return config \ No newline at end of file + return config diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 4a4eb6ed0..ee8fdf186 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,7 +15,12 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.gym import GymRecordEpisodeMetrics, GymRwareWrapper, GymAgentIDWrapper, _multiagent_worker_shared_memory +from mava.wrappers.gym import ( + GymAgentIDWrapper, + GymRecordEpisodeMetrics, + GymRwareWrapper, + _multiagent_worker_shared_memory, +) from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, diff --git a/mava/wrappers/episode_metrics.py b/mava/wrappers/episode_metrics.py index a46dc1b91..a2b0fdb37 100644 --- a/mava/wrappers/episode_metrics.py +++ b/mava/wrappers/episode_metrics.py @@ -75,7 +75,7 @@ def step( # Previous episode return/length until done and then the next episode return. 
episode_return_info = state.episode_return * not_done + new_episode_return * done episode_length_info = state.episode_length * not_done + new_episode_length * done - + timestep.extras["episode_metrics"] = { "episode_return": episode_return_info, "episode_length": episode_length_info, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index dd77105a9..b5f89b45f 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -245,4 +245,4 @@ def _multiagent_worker_shared_memory( # noqa: CCR001 error_queue.put((index,) + sys.exc_info()[:2]) pipe.send((None, False)) finally: - env.close() \ No newline at end of file + env.close() From af24082ab3ccd4ac878edd9de9e3e3ed7fa4b9f1 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sat, 13 Jul 2024 23:38:03 +0100 Subject: [PATCH 035/125] fix: fix the num_updates_in_eval in the last eval --- mava/systems/sebulba/ppo/ff_ippo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index cf598770f..d8893ded8 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -666,11 +666,11 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 episode_metrics = [] train_metrics = [] - # Make sure that the + # Full or partial last eval step. num_updates_in_eval = ( - config.system.num_updates_per_eval - if eval_step != config.arch.num_evaluation - 1 - else remaining_updates + remaining_updates + if eval_step == config.arch.num_evaluation - 1 and remaining_updates + else config.system.num_updates_per_eval ) for _update in range(num_updates_in_eval): sharded_storages = [] From 32ac3890603fc0040bf4bfacc6efb88ba2e2f7f0 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 10:58:05 +0100 Subject: [PATCH 036/125] fix: fixed the num evals cacls --- mava/systems/sebulba/ppo/ff_ippo.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index d8893ded8..71e4e31d3 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -597,11 +597,9 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." - # Calculate number of updates per evaluation. 
- config.system.num_updates_per_eval, remaining_updates = divmod( - config.system.num_updates, config.arch.num_evaluation - ) + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + config.arch.num_evaluation, remaining_updates = divmod(config.system.num_updates , config.system.num_updates_per_eval) config.arch.num_evaluation += ( remaining_updates != 0 ) # Add an evaluation step if the num_updates is not a multiple of num_evaluation From 45ca5875db7b05e34013bf485636311c9fcec2d4 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 11:04:59 +0100 Subject: [PATCH 037/125] chore : pre commit --- mava/systems/sebulba/ppo/ff_ippo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 71e4e31d3..a184414d9 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -599,7 +599,9 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 ), "Number of updates per evaluation must be less than total number of updates." # Calculate number of updates per evaluation. config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation - config.arch.num_evaluation, remaining_updates = divmod(config.system.num_updates , config.system.num_updates_per_eval) + config.arch.num_evaluation, remaining_updates = divmod( + config.system.num_updates, config.system.num_updates_per_eval + ) config.arch.num_evaluation += ( remaining_updates != 0 ) # Add an evaluation step if the num_updates is not a multiple of num_evaluation From d6944984146fd1975924453efb28307af09c6836 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 11:12:34 +0100 Subject: [PATCH 038/125] chore: created the anakin and sebulba folders --- mava/systems/{ => anakin}/ppo/__init__.py | 0 mava/systems/{ => anakin}/ppo/ff_ippo.py | 2 +- mava/systems/{ => anakin}/ppo/ff_mappo.py | 2 +- mava/systems/{ => anakin}/ppo/rec_ippo.py | 2 +- mava/systems/{ => anakin}/ppo/rec_mappo.py | 2 +- mava/systems/{ => anakin}/ppo/types.py | 0 mava/systems/{ => anakin}/q_learning/__init__.py | 0 mava/systems/{ => anakin}/q_learning/rec_iql.py | 0 mava/systems/{ => anakin}/q_learning/types.py | 0 mava/systems/{ => anakin}/sac/__init__.py | 0 mava/systems/{ => anakin}/sac/ff_isac.py | 0 mava/systems/{ => anakin}/sac/ff_masac.py | 0 mava/systems/{ => anakin}/sac/types.py | 0 mava/systems/sebulba/ppo/ff_ippo.py | 0 14 files changed, 4 insertions(+), 4 deletions(-) rename mava/systems/{ => anakin}/ppo/__init__.py (100%) rename mava/systems/{ => anakin}/ppo/ff_ippo.py (99%) rename mava/systems/{ => anakin}/ppo/ff_mappo.py (99%) rename mava/systems/{ => anakin}/ppo/rec_ippo.py (99%) rename mava/systems/{ => anakin}/ppo/rec_mappo.py (99%) rename mava/systems/{ => anakin}/ppo/types.py (100%) rename mava/systems/{ => anakin}/q_learning/__init__.py (100%) rename mava/systems/{ => anakin}/q_learning/rec_iql.py (100%) rename mava/systems/{ => anakin}/q_learning/types.py (100%) rename mava/systems/{ => anakin}/sac/__init__.py (100%) rename mava/systems/{ => anakin}/sac/ff_isac.py (100%) rename mava/systems/{ => anakin}/sac/ff_masac.py (100%) rename mava/systems/{ => anakin}/sac/types.py (100%) create mode 100644 mava/systems/sebulba/ppo/ff_ippo.py diff --git a/mava/systems/ppo/__init__.py b/mava/systems/anakin/ppo/__init__.py similarity index 100% rename from mava/systems/ppo/__init__.py rename to mava/systems/anakin/ppo/__init__.py diff --git 
a/mava/systems/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py similarity index 99% rename from mava/systems/ppo/ff_ippo.py rename to mava/systems/anakin/ppo/ff_ippo.py index 7b45fb45f..f37407dd2 100644 --- a/mava/systems/ppo/ff_ippo.py +++ b/mava/systems/anakin/ppo/ff_ippo.py @@ -32,7 +32,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/ppo/ff_mappo.py b/mava/systems/anakin/ppo/ff_mappo.py similarity index 99% rename from mava/systems/ppo/ff_mappo.py rename to mava/systems/anakin/ppo/ff_mappo.py index 519fa4f39..127216069 100644 --- a/mava/systems/ppo/ff_mappo.py +++ b/mava/systems/anakin/ppo/ff_mappo.py @@ -31,7 +31,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/ppo/rec_ippo.py b/mava/systems/anakin/ppo/rec_ippo.py similarity index 99% rename from mava/systems/ppo/rec_ippo.py rename to mava/systems/anakin/ppo/rec_ippo.py index e70a59f07..e4b6740b1 100644 --- a/mava/systems/ppo/rec_ippo.py +++ b/mava/systems/anakin/ppo/rec_ippo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.ppo.types import ( +from mava.systems.anakin.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/ppo/rec_mappo.py b/mava/systems/anakin/ppo/rec_mappo.py similarity index 99% rename from mava/systems/ppo/rec_mappo.py rename to mava/systems/anakin/ppo/rec_mappo.py index 14284cedb..c351ba576 100644 --- a/mava/systems/ppo/rec_mappo.py +++ b/mava/systems/anakin/ppo/rec_mappo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.ppo.types import ( +from mava.systems.anakin.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/ppo/types.py b/mava/systems/anakin/ppo/types.py similarity index 100% rename from mava/systems/ppo/types.py rename to mava/systems/anakin/ppo/types.py diff --git a/mava/systems/q_learning/__init__.py b/mava/systems/anakin/q_learning/__init__.py similarity index 100% rename from mava/systems/q_learning/__init__.py rename to mava/systems/anakin/q_learning/__init__.py diff --git a/mava/systems/q_learning/rec_iql.py b/mava/systems/anakin/q_learning/rec_iql.py similarity index 100% rename from mava/systems/q_learning/rec_iql.py rename to mava/systems/anakin/q_learning/rec_iql.py diff --git a/mava/systems/q_learning/types.py b/mava/systems/anakin/q_learning/types.py similarity index 100% rename from mava/systems/q_learning/types.py rename to 
mava/systems/anakin/q_learning/types.py diff --git a/mava/systems/sac/__init__.py b/mava/systems/anakin/sac/__init__.py similarity index 100% rename from mava/systems/sac/__init__.py rename to mava/systems/anakin/sac/__init__.py diff --git a/mava/systems/sac/ff_isac.py b/mava/systems/anakin/sac/ff_isac.py similarity index 100% rename from mava/systems/sac/ff_isac.py rename to mava/systems/anakin/sac/ff_isac.py diff --git a/mava/systems/sac/ff_masac.py b/mava/systems/anakin/sac/ff_masac.py similarity index 100% rename from mava/systems/sac/ff_masac.py rename to mava/systems/anakin/sac/ff_masac.py diff --git a/mava/systems/sac/types.py b/mava/systems/anakin/sac/types.py similarity index 100% rename from mava/systems/sac/types.py rename to mava/systems/anakin/sac/types.py diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py new file mode 100644 index 000000000..e69de29bb From cb8111fe0c87c616913d165e2f19788533af152d Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 11:18:21 +0100 Subject: [PATCH 039/125] fix: imports and config paths in systems --- mava/systems/anakin/ppo/ff_ippo.py | 2 +- mava/systems/anakin/ppo/ff_mappo.py | 2 +- mava/systems/anakin/ppo/rec_ippo.py | 2 +- mava/systems/anakin/ppo/rec_mappo.py | 2 +- mava/systems/sebulba/ppo/ff_ippo.py | 13 +++++++++++++ mava/utils/checkpointing.py | 2 +- 6 files changed, 18 insertions(+), 5 deletions(-) diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/anakin/ppo/ff_ippo.py index f37407dd2..51efd10e7 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/anakin/ppo/ff_ippo.py @@ -578,7 +578,7 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_ff_ippo.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/mava/systems/anakin/ppo/ff_mappo.py b/mava/systems/anakin/ppo/ff_mappo.py index 127216069..a9364fdfc 100644 --- a/mava/systems/anakin/ppo/ff_mappo.py +++ b/mava/systems/anakin/ppo/ff_mappo.py @@ -575,7 +575,7 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_mappo.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_ff_mappo.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/mava/systems/anakin/ppo/rec_ippo.py b/mava/systems/anakin/ppo/rec_ippo.py index e4b6740b1..a4d3df428 100644 --- a/mava/systems/anakin/ppo/rec_ippo.py +++ b/mava/systems/anakin/ppo/rec_ippo.py @@ -735,7 +735,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 return eval_performance -@hydra.main(config_path="../../configs", config_name="default_rec_ippo.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_rec_ippo.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
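The repeated config_path edits in this patch follow from how Hydra resolves paths: config_path is interpreted relative to the Python file that declares the decorated entry point, so moving a system from mava/systems/ppo/ down into mava/systems/anakin/ppo/ adds one directory level and therefore one extra "../". A minimal sketch of the pattern, assuming the file lives next to the real systems under mava/systems/anakin/ppo/ (file name and printed field are illustrative):

import hydra
from omegaconf import DictConfig


# Assumed location: mava/systems/anakin/ppo/example_entry.py, so "../../../configs"
# resolves to mava/configs exactly as it did before the directories were reshuffled.
@hydra.main(config_path="../../../configs", config_name="default_ff_ippo.yaml", version_base="1.2")
def entry_point(cfg: DictConfig) -> None:
    print(cfg.arch.num_envs)  # config composition is unchanged by the relocation


if __name__ == "__main__":
    entry_point()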
diff --git a/mava/systems/anakin/ppo/rec_mappo.py b/mava/systems/anakin/ppo/rec_mappo.py index c351ba576..c2f9dc678 100644 --- a/mava/systems/anakin/ppo/rec_mappo.py +++ b/mava/systems/anakin/ppo/rec_mappo.py @@ -726,7 +726,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 return eval_performance -@hydra.main(config_path="../../configs", config_name="default_rec_mappo.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_rec_mappo.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index e69de29bb..21db9ec1c 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mava/utils/checkpointing.py b/mava/utils/checkpointing.py index 8955f76ce..230c4938d 100644 --- a/mava/utils/checkpointing.py +++ b/mava/utils/checkpointing.py @@ -24,7 +24,7 @@ from jax.tree_util import tree_map from omegaconf import DictConfig, OmegaConf -from mava.systems.ppo.types import HiddenStates, Params +from mava.systems.anakin.ppo.types import HiddenStates, Params from mava.types import MavaState # Keep track of the version of the checkpointer From d842375c8e89bc25e73f3ea97b063cc63083c045 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 15:27:15 +0100 Subject: [PATCH 040/125] fix: allow for reproducibility --- mava/evaluator.py | 17 ++++++++++++----- mava/systems/sebulba/ppo/ff_ippo.py | 15 ++++++++++----- mava/wrappers/gym.py | 16 +++++++++++----- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/mava/evaluator.py b/mava/evaluator.py index 984a42377..8412b2d81 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -348,6 +348,7 @@ def get_sebulba_ff_evaluator_fn( env: Environment, apply_fn: ActorApply, config: DictConfig, + np_rng : np.random.Generator, log_win_rate: bool = False, ) -> SebulbaEvalFn: """Get the evaluator function for feedforward networks. @@ -376,8 +377,9 @@ def get_action( # todo explicetly put these on the learner? they should already return action def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: - - obs, info = env.reset() + + seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs) + obs, info = env.reset(seed = seeds) dones = np.full(env.num_envs, False) eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) @@ -417,6 +419,7 @@ def get_sebulba_rnn_evaluator_fn( env: Environment, apply_fn: RecActorApply, config: DictConfig, + np_rng : np.random.Generator, scanned_rnn: nn.Module, log_win_rate: bool = False, ) -> SebulbaEvalFn: @@ -448,7 +451,8 @@ def get_action( # todo explicetly put these on the learner? 
they should already def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: - obs, info = env.reset() + seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs) + obs, info = env.reset(seed = seeds) eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) hstate = scanned_rnn.initialize_carry( @@ -499,6 +503,7 @@ def make_sebulba_eval_fns( eval_env_fn: Callable, network_apply_fn: Union[ActorApply, RecActorApply], config: DictConfig, + np_rng : np.random.Generator, add_global_state: bool = False, use_recurrent_net: bool = False, scanned_rnn: Optional[nn.Module] = None, @@ -533,6 +538,7 @@ def make_sebulba_eval_fns( eval_env, network_apply_fn, # type: ignore config, + np_rng, scanned_rnn, log_win_rate, ) @@ -540,15 +546,16 @@ def make_sebulba_eval_fns( absolute_eval_env, network_apply_fn, # type: ignore config, + np_rng, scanned_rnn, log_win_rate, ) else: evaluator = get_sebulba_ff_evaluator_fn( - eval_env, network_apply_fn, config, log_win_rate # type: ignore + eval_env, network_apply_fn, config, np_rng, log_win_rate # type: ignore ) absolute_metric_evaluator = get_sebulba_ff_evaluator_fn( - absolute_eval_env, network_apply_fn, config, log_win_rate # type: ignore + absolute_eval_env, network_apply_fn, config, np_rng, log_win_rate # type: ignore ) return evaluator, absolute_metric_evaluator diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index a184414d9..ce7fb224c 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -61,6 +61,7 @@ def rollout( apply_fns: Tuple, learner_devices: List, actor_device_id: int, + seeds: List[int], ) -> None: # setup @@ -89,8 +90,7 @@ def get_action_and_value( params_queue_get_time: deque = deque(maxlen=1) rollout_time: deque = deque(maxlen=1) rollout_queue_put_time: deque = deque(maxlen=1) - - next_obs, info = env.reset() + next_obs, info = env.reset(seed=seeds) next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) move_to_device = lambda x: jax.device_put(x, device=current_actor_device) @@ -586,11 +586,13 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 (key, actor_net_key, critic_net_key), config, learner_devices ) + # Generate Numpy RNG for reproducibility + np_rng = np.random.default_rng(config.system.seed) + # Setup evaluator. - # One key per device for evaluation. evaluator, absolute_metric_evaluator = make_eval_fns( - environments.make_gym_env, apply_fns[0], config - ) # todo: make this more generic + environments.make_gym_env, apply_fns[0], config, np_rng + ) # Calculate total timesteps. 
config = sebulba_check_total_timesteps(config) @@ -632,6 +634,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) params_queues: List = [] rollout_queues: List = [] + for _d_idx, d_id in enumerate( # Loop through each executor device config.arch.executor_device_ids ): @@ -639,6 +642,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 device_params = jax.device_put(unreplicated_params, devices[d_id]) # Loop through each executor thread for _thread_id in range(config.arch.n_threads_per_executor): + seeds = np_rng.integers(np.iinfo(np.int64).max, size=config.arch.num_envs) params_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1)) params_queues[-1].put(device_params) @@ -652,6 +656,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 apply_fns, learner_devices, d_id, + seeds, ), ).start() diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index b5f89b45f..d1c36cd54 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -49,7 +49,9 @@ def __init__( self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[0].n - def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[np.ndarray, Dict]: if seed is not None: self.env.seed(seed) @@ -96,10 +98,12 @@ def __init__(self, env: gym.Env): self.running_count_episode_return = 0.0 self.running_count_episode_length = 0.0 - def reset(self) -> Tuple: + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[np.ndarray, Dict]: # Reset the env - agents_view, info = self._env.reset() + agents_view, info = self._env.reset(seed, options) # Create the metrics dict metrics = { @@ -160,9 +164,11 @@ def __init__(self, env: gym.Env): ] * self.env.num_agents self.observation_space = spaces.Tuple(_observation_boxs) - def reset(self) -> Tuple[np.ndarray, Dict]: + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[np.ndarray, Dict]: """Reset the environment.""" - obs, info = self.env.reset() + obs, info = self.env.reset(seed, options) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, info From 0a1ffd0314a87bd799c84bcc0c8578212699e236 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 15:28:25 +0100 Subject: [PATCH 041/125] chore: pre-commits --- mava/evaluator.py | 12 ++++++------ mava/systems/sebulba/ppo/ff_ippo.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mava/evaluator.py b/mava/evaluator.py index 8412b2d81..bacbb050e 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -348,7 +348,7 @@ def get_sebulba_ff_evaluator_fn( env: Environment, apply_fn: ActorApply, config: DictConfig, - np_rng : np.random.Generator, + np_rng: np.random.Generator, log_win_rate: bool = False, ) -> SebulbaEvalFn: """Get the evaluator function for feedforward networks. @@ -377,9 +377,9 @@ def get_action( # todo explicetly put these on the learner? 
they should already return action def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: - + seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs) - obs, info = env.reset(seed = seeds) + obs, info = env.reset(seed=seeds) dones = np.full(env.num_envs, False) eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) @@ -419,7 +419,7 @@ def get_sebulba_rnn_evaluator_fn( env: Environment, apply_fn: RecActorApply, config: DictConfig, - np_rng : np.random.Generator, + np_rng: np.random.Generator, scanned_rnn: nn.Module, log_win_rate: bool = False, ) -> SebulbaEvalFn: @@ -452,7 +452,7 @@ def get_action( # todo explicetly put these on the learner? they should already def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs) - obs, info = env.reset(seed = seeds) + obs, info = env.reset(seed=seeds) eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) hstate = scanned_rnn.initialize_carry( @@ -503,7 +503,7 @@ def make_sebulba_eval_fns( eval_env_fn: Callable, network_apply_fn: Union[ActorApply, RecActorApply], config: DictConfig, - np_rng : np.random.Generator, + np_rng: np.random.Generator, add_global_state: bool = False, use_recurrent_net: bool = False, scanned_rnn: Optional[nn.Module] = None, diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index ce7fb224c..0f1abb206 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -588,11 +588,11 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 # Generate Numpy RNG for reproducibility np_rng = np.random.default_rng(config.system.seed) - + # Setup evaluator. evaluator, absolute_metric_evaluator = make_eval_fns( environments.make_gym_env, apply_fns[0], config, np_rng - ) + ) # Calculate total timesteps. config = sebulba_check_total_timesteps(config) @@ -634,7 +634,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) params_queues: List = [] rollout_queues: List = [] - + for _d_idx, d_id in enumerate( # Loop through each executor device config.arch.executor_device_ids ): From f1adc3109009f86ccd965e794e7dc9f01f45f375 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 15:30:59 +0100 Subject: [PATCH 042/125] chore: pre-commits --- mava/systems/anakin/ppo/rec_mappo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mava/systems/anakin/ppo/rec_mappo.py b/mava/systems/anakin/ppo/rec_mappo.py index c2f9dc678..93736cf10 100644 --- a/mava/systems/anakin/ppo/rec_mappo.py +++ b/mava/systems/anakin/ppo/rec_mappo.py @@ -726,7 +726,9 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 return eval_performance -@hydra.main(config_path="../../../configs", config_name="default_rec_mappo.yaml", version_base="1.2") +@hydra.main( + config_path="../../../configs", config_name="default_rec_mappo.yaml", version_base="1.2" +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
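The reproducibility changes above all follow one pattern: a single numpy Generator is built from the experiment seed, per-environment integer seeds are drawn from it, and those seeds are handed to the vectorised environment's reset. A standalone sketch, with the seed value and environment count chosen purely for illustration:

import numpy as np

# Built once from config.system.seed in run_experiment.
np_rng = np.random.default_rng(seed=42)

# One fresh integer seed per parallel environment.
seeds = np_rng.integers(np.iinfo(np.int64).max, size=4)

# The vector env is then reset deterministically, e.g.
#   obs, info = envs.reset(seed=seeds.tolist())
# gym expects plain Python ints here, which is why a later patch in this series
# adds the .tolist() conversion.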
From 3850591b05af82569329dc4cf0eb358df11a8d7e Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 15:41:13 +0100 Subject: [PATCH 043/125] feat: LBF and reproducibility --- mava/utils/make_env.py | 3 +- mava/wrappers/__init__.py | 1 + mava/wrappers/gym.py | 75 ++++++++++++++++++++++++++++++++--- requirements/requirements.txt | 1 + 4 files changed, 73 insertions(+), 7 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 5ee4e697c..9828573e0 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -45,6 +45,7 @@ ConnectorWrapper, GigastepWrapper, GymAgentIDWrapper, + GymLBFWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, LbfWrapper, @@ -71,7 +72,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"RobotWarehouse": GymRwareWrapper} +_gym_registry = {"RobotWarehouse": GymRwareWrapper, "LevelBasedForaging": GymLBFWrapper} def add_extra_wrappers( diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index ee8fdf186..869e78053 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -17,6 +17,7 @@ from mava.wrappers.gigastep import GigastepWrapper from mava.wrappers.gym import ( GymAgentIDWrapper, + GymLBFWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, _multiagent_worker_shared_memory, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 978ad4033..a9bc5af8e 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -36,7 +36,6 @@ def __init__( add_global_state: bool = False, ): """Initialize the gym wrapper - Args: env (gym.env): gym env instance. use_individual_rewards (bool, optional): Use individual or group rewards. @@ -50,7 +49,9 @@ def __init__( self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[0].n - def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[np.ndarray, Dict]: if seed is not None: self.env.seed(seed) @@ -88,6 +89,64 @@ def get_global_obs(self, obs: NDArray) -> NDArray: return np.tile(global_obs, (self.num_agents, 1)) +class GymLBFWrapper(gym.Wrapper): + """Wrapper for rware gym environments""" + + def __init__( + self, + env: gym.Env, + use_individual_rewards: bool = False, + add_global_state: bool = False, + ): + """Initialize the gym wrapper + Args: + env (gym.env): gym env instance. + use_individual_rewards (bool, optional): Use individual or group rewards. + Defaults to False. + add_global_state (bool, optional) : Create global observations. Defaults to False. + """ + super().__init__(env) + self._env = env # not having _env leaded tp self.env getting replaced --> circular called + self.use_individual_rewards = use_individual_rewards + self.add_global_state = add_global_state # todo : add the global observations + self.num_agents = len(self._env.action_space) + self.num_actions = self._env.action_space[ + 0 + ].n # todo: all the agents must have the same num_actions, add assertion? 
+ + def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: + + if seed is not None: + self.env.seed(seed) + + agents_view, info = self._env.reset() + + info = {"actions_mask": self.get_actions_mask(info)} + + return np.array(agents_view), info + + def step(self, actions: NDArray) -> Tuple: # Vect auto rest + + agents_view, reward, terminated, truncated, info = self._env.step(actions) + + info = {"actions_mask": self.get_actions_mask(info)} + + if self.use_individual_rewards: + reward = np.array(reward) + else: + reward = np.array([np.array(reward).sum()] * self.num_agents) + + truncated = [truncated] * self.num_agents + terminated = [terminated] * self.num_agents + + return agents_view, reward, terminated, truncated, info + + def get_actions_mask(self, info: Dict) -> NDArray: + if "action_mask" in info: + return np.array(info["action_mask"]) + return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + + class GymRecordEpisodeMetrics(gym.Wrapper): """Record the episode returns and lengths.""" @@ -97,10 +156,12 @@ def __init__(self, env: gym.Env): self.running_count_episode_return = 0.0 self.running_count_episode_length = 0.0 - def reset(self) -> Tuple: + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[np.ndarray, Dict]: # Reset the env - agents_view, info = self._env.reset() + agents_view, info = self._env.reset(seed, options) # Create the metrics dict metrics = { @@ -161,9 +222,11 @@ def __init__(self, env: gym.Env): ] * self.env.num_agents self.observation_space = spaces.Tuple(_observation_boxs) - def reset(self) -> Tuple[np.ndarray, Dict]: + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[np.ndarray, Dict]: """Reset the environment.""" - obs, info = self.env.reset() + obs, info = self.env.reset(seed, options) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, info diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 3b3bc4c58..3a7b96aef 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -9,6 +9,7 @@ jax jaxlib jaxmarl jumanji @ git+https://github.com/sash-a/jumanji +lbforaging @ git+https://github.com/Louay-Ben-nessir/lb-foraging.git matrax @ git+https://github.com/instadeepai/matrax mujoco==3.1.3 mujoco-mjx==3.1.3 From 0a2ee084bfb5b46f7035d48f05a3fb8297b42be8 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 16 Jul 2024 15:45:51 +0100 Subject: [PATCH 044/125] feat : lbf --- mava/utils/make_env.py | 7 +++-- mava/wrappers/__init__.py | 1 + mava/wrappers/gym.py | 58 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 2330674f0..9828573e0 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -45,6 +45,7 @@ ConnectorWrapper, GigastepWrapper, GymAgentIDWrapper, + GymLBFWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, LbfWrapper, @@ -71,7 +72,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"RobotWarehouse": GymRwareWrapper} +_gym_registry = {"RobotWarehouse": GymRwareWrapper, "LevelBasedForaging": GymLBFWrapper} def add_extra_wrappers( @@ -218,12 +219,12 @@ def make_gym_env( Create a Gym environment. Args: - env_name (str): The name of the environment to create. config (Dict): The configuration of the environment. + num_env (int) : The number of parallel envs to create. 
add_global_state (bool): Whether to add the global state to the observation. Default False. Returns: - A tuple of the environments. + Async environments. """ base_env_name = config.env.env_name wrapper = _gym_registry[base_env_name] diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index ee8fdf186..869e78053 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -17,6 +17,7 @@ from mava.wrappers.gigastep import GigastepWrapper from mava.wrappers.gym import ( GymAgentIDWrapper, + GymLBFWrapper, GymRecordEpisodeMetrics, GymRwareWrapper, _multiagent_worker_shared_memory, diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index d1c36cd54..a9bc5af8e 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -89,6 +89,64 @@ def get_global_obs(self, obs: NDArray) -> NDArray: return np.tile(global_obs, (self.num_agents, 1)) +class GymLBFWrapper(gym.Wrapper): + """Wrapper for rware gym environments""" + + def __init__( + self, + env: gym.Env, + use_individual_rewards: bool = False, + add_global_state: bool = False, + ): + """Initialize the gym wrapper + Args: + env (gym.env): gym env instance. + use_individual_rewards (bool, optional): Use individual or group rewards. + Defaults to False. + add_global_state (bool, optional) : Create global observations. Defaults to False. + """ + super().__init__(env) + self._env = env # not having _env leaded tp self.env getting replaced --> circular called + self.use_individual_rewards = use_individual_rewards + self.add_global_state = add_global_state # todo : add the global observations + self.num_agents = len(self._env.action_space) + self.num_actions = self._env.action_space[ + 0 + ].n # todo: all the agents must have the same num_actions, add assertion? + + def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: + + if seed is not None: + self.env.seed(seed) + + agents_view, info = self._env.reset() + + info = {"actions_mask": self.get_actions_mask(info)} + + return np.array(agents_view), info + + def step(self, actions: NDArray) -> Tuple: # Vect auto rest + + agents_view, reward, terminated, truncated, info = self._env.step(actions) + + info = {"actions_mask": self.get_actions_mask(info)} + + if self.use_individual_rewards: + reward = np.array(reward) + else: + reward = np.array([np.array(reward).sum()] * self.num_agents) + + truncated = [truncated] * self.num_agents + terminated = [terminated] * self.num_agents + + return agents_view, reward, terminated, truncated, info + + def get_actions_mask(self, info: Dict) -> NDArray: + if "action_mask" in info: + return np.array(info["action_mask"]) + return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + + class GymRecordEpisodeMetrics(gym.Wrapper): """Record the episode returns and lengths.""" From dc9206564c5b4b4c155b1e956abfc872be617ca6 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 09:35:25 +0100 Subject: [PATCH 045/125] fix: sync neptune logging for sebulba to avoid stalling --- mava/configs/arch/anakin.yaml | 2 +- mava/configs/arch/sebulba.yaml | 4 ++-- mava/utils/logger.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index 86e75898b..d58d85286 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,5 +1,5 @@ # --- Anakin config --- - +arch_name: "Anakin" # --- Training --- num_envs: 16 # Number of vectorised environments per device. 
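The arch_name field added here is what the logger change later in this patch keys on: neptune.init_run accepts a mode argument, and Sebulba runs are switched to synchronous logging to avoid the stalling mentioned in the commit subject. A small sketch of the selection, with the project name and tag as placeholders rather than real values:

import neptune  # assumes the neptune client used by mava/utils/logger.py is installed

arch_name = "Sebulba"  # cfg.arch.arch_name in the real logger
mode = "sync" if arch_name == "Sebulba" else "async"  # async stays the default for Anakin
run = neptune.init_run(project="my-workspace/my-project", tags=["example"], mode=mode)
run["config/arch_name"] = arch_name
run.stop()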
diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index b6a0a9699..e0305e2dc 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,5 +1,5 @@ # --- Sebulba config --- -arch_name: "sebulba" +arch_name: "Sebulba" num_envs: 32 # number of envs per thread # --- Evaluation --- @@ -12,6 +12,6 @@ absolute_metric: True # Whether the absolute metric should be computed. For more # on the absolute metric please see: https://arxiv.org/abs/2209.10485 # --- Sebulba devices config --- -n_threads_per_executor: 1 # num of different threads/env batches per actor +n_threads_per_executor: 2 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices diff --git a/mava/utils/logger.py b/mava/utils/logger.py index 8273e44a2..dc217f263 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -150,8 +150,9 @@ class NeptuneLogger(BaseLogger): def __init__(self, cfg: DictConfig, unique_token: str) -> None: tags = list(cfg.logger.kwargs.neptune_tag) project = cfg.logger.kwargs.neptune_project + mode = "sync" if cfg.arch.arch_name == "Sebulba" else "async" - self.logger = neptune.init_run(project=project, tags=tags) + self.logger = neptune.init_run(project=project, tags=tags, mode=mode) self.logger["config"] = stringify_unsupported(cfg) self.detailed_logging = cfg.logger.kwargs.detailed_neptune_logging From 133a25060151ccd99b0f0fe1a73af48310dbbbff Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 09:54:45 +0100 Subject: [PATCH 046/125] fix: added missing lbf import --- mava/utils/make_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 9828573e0..eeebed9d0 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -20,6 +20,7 @@ import gym.wrappers.compatibility import jaxmarl import jumanji +import lbforaging # noqa: F401 used implicitly import matrax from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario From b938c831b7c5217f6e9f898d3c564ac45510c10a Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 10:11:09 +0100 Subject: [PATCH 047/125] fix: seeds need to python arrays not np arrays --- mava/evaluator.py | 4 ++-- mava/systems/sebulba/ppo/ff_ippo.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mava/evaluator.py b/mava/evaluator.py index bacbb050e..fb611d1b3 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -378,7 +378,7 @@ def get_action( # todo explicetly put these on the learner? they should already def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: - seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs) + seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs).tolist() obs, info = env.reset(seed=seeds) dones = np.full(env.num_envs, False) eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) @@ -451,7 +451,7 @@ def get_action( # todo explicetly put these on the learner? 
they should already def eval_episodes(params: FrozenDict, key: chex.PRNGKey) -> Any: - seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs) + seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs).tolist() obs, info = env.reset(seed=seeds) eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/sebulba/ppo/ff_ippo.py index 0f1abb206..42d2732ae 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/sebulba/ppo/ff_ippo.py @@ -642,7 +642,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 device_params = jax.device_put(unreplicated_params, devices[d_id]) # Loop through each executor thread for _thread_id in range(config.arch.n_threads_per_executor): - seeds = np_rng.integers(np.iinfo(np.int64).max, size=config.arch.num_envs) + seeds = np_rng.integers(np.iinfo(np.int64).max, size=config.arch.num_envs).tolist() params_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1)) params_queues[-1].put(device_params) From a36847680413642c634d214095fb4eab0ad5dcae Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 12:40:51 +0100 Subject: [PATCH 048/125] fix: config and imports for anakin q_learning and sac --- mava/systems/anakin/q_learning/rec_iql.py | 4 ++-- mava/systems/anakin/sac/ff_isac.py | 4 ++-- mava/systems/anakin/sac/ff_masac.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mava/systems/anakin/q_learning/rec_iql.py b/mava/systems/anakin/q_learning/rec_iql.py index 6be8e61a4..89139277a 100644 --- a/mava/systems/anakin/q_learning/rec_iql.py +++ b/mava/systems/anakin/q_learning/rec_iql.py @@ -34,7 +34,7 @@ from mava.evaluator import make_eval_fns from mava.networks import RecQNetwork, ScannedRNN -from mava.systems.q_learning.types import ( +from mava.systems.anakin.q_learning.types import ( ActionSelectionState, ActionState, LearnerState, @@ -645,7 +645,7 @@ def run_experiment(cfg: DictConfig) -> float: return float(eval_performance) -@hydra.main(config_path="../../configs", config_name="default_rec_iql.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_rec_iql.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/mava/systems/anakin/sac/ff_isac.py b/mava/systems/anakin/sac/ff_isac.py index 2c33028d1..1642176f3 100644 --- a/mava/systems/anakin/sac/ff_isac.py +++ b/mava/systems/anakin/sac/ff_isac.py @@ -34,7 +34,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardQNet as QNetwork -from mava.systems.sac.types import ( +from mava.systems.anakin.sac.types import ( BufferState, LearnerState, Metrics, @@ -607,7 +607,7 @@ def run_experiment(cfg: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_isac.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_ff_isac.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
diff --git a/mava/systems/anakin/sac/ff_masac.py b/mava/systems/anakin/sac/ff_masac.py index 4401906ee..2367a67a4 100644 --- a/mava/systems/anakin/sac/ff_masac.py +++ b/mava/systems/anakin/sac/ff_masac.py @@ -34,7 +34,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardQNet as QNetwork -from mava.systems.sac.types import ( +from mava.systems.anakin.sac.types import ( BufferState, LearnerState, Metrics, @@ -626,7 +626,7 @@ def run_experiment(cfg: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_masac.yaml", version_base="1.2") +@hydra.main(config_path="../../../configs", config_name="default_ff_masac.yaml", version_base="1.2") def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. From 32433ff2d93aee917f9a9504ff8d19d94be33fb1 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 14:17:38 +0100 Subject: [PATCH 049/125] chore: arch_name for anakin --- mava/configs/arch/anakin.yaml | 1 + mava/configs/arch/sebulba.yaml | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index 86e75898b..6e15238dc 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,4 +1,5 @@ # --- Anakin config --- +arch_name: "Anakin" # --- Training --- num_envs: 16 # Number of vectorised environments per device. diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index b6a0a9699..f38324e86 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,6 +1,8 @@ # --- Sebulba config --- -arch_name: "sebulba" -num_envs: 32 # number of envs per thread +arch_name: "Sebulba" + +# --- Training --- +num_envs: 32 # number of environments per thread. # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. 
If True the policy will select From a68c8e944c9e118eba10acbd3655332d0d935c24 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 14:18:56 +0100 Subject: [PATCH 050/125] fix: sum the rewards when using a shared reward --- mava/wrappers/gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index a9bc5af8e..83c523702 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -75,7 +75,7 @@ def step(self, actions: NDArray) -> Tuple: if self.use_individual_rewards: reward = np.array(reward) else: - reward = np.array([np.array(reward).mean()] * self.num_agents) + reward = np.array([np.array(reward).sum()] * self.num_agents) return agents_view, reward, terminated, truncated, info From 8cee7ac0dc5c9b3d927062f0951a8b3e100173e6 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 15:50:11 +0100 Subject: [PATCH 051/125] fix: configs revamp --- mava/configs/env/gym.yaml | 24 ++++++++++--------- .../configs/env/scenario/gym-10x10-3p-3f.yaml | 15 ++++++++++++ .../configs/env/scenario/gym-15x15-3p-5f.yaml | 15 ++++++++++++ .../configs/env/scenario/gym-15x15-4p-3f.yaml | 15 ++++++++++++ .../configs/env/scenario/gym-15x15-4p-5f.yaml | 15 ++++++++++++ .../env/scenario/gym-2s-10x10-3p-3f.yaml | 15 ++++++++++++ .../env/scenario/gym-2s-8x8-2p-2f-coop.yaml | 15 ++++++++++++ .../env/scenario/gym-8x8-2p-2f-coop.yaml | 15 ++++++++++++ mava/configs/env/scenario/gym-small-4ag.yaml | 14 +++++++++++ mava/configs/env/scenario/gym-tiny-2ag.yaml | 14 +++++++++++ .../env/scenario/gym-tiny-4ag-easy.yaml | 14 +++++++++++ mava/configs/env/scenario/gym-tiny-4ag.yaml | 14 +++++++++++ mava/utils/make_env.py | 23 +++++++++--------- mava/wrappers/gym.py | 24 +++++++++---------- 14 files changed, 198 insertions(+), 34 deletions(-) create mode 100644 mava/configs/env/scenario/gym-10x10-3p-3f.yaml create mode 100644 mava/configs/env/scenario/gym-15x15-3p-5f.yaml create mode 100644 mava/configs/env/scenario/gym-15x15-4p-3f.yaml create mode 100644 mava/configs/env/scenario/gym-15x15-4p-5f.yaml create mode 100644 mava/configs/env/scenario/gym-2s-10x10-3p-3f.yaml create mode 100644 mava/configs/env/scenario/gym-2s-8x8-2p-2f-coop.yaml create mode 100644 mava/configs/env/scenario/gym-8x8-2p-2f-coop.yaml create mode 100644 mava/configs/env/scenario/gym-small-4ag.yaml create mode 100644 mava/configs/env/scenario/gym-tiny-2ag.yaml create mode 100644 mava/configs/env/scenario/gym-tiny-4ag-easy.yaml create mode 100644 mava/configs/env/scenario/gym-tiny-4ag.yaml diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym.yaml index 1e197a45e..295b9974e 100644 --- a/mava/configs/env/gym.yaml +++ b/mava/configs/env/gym.yaml @@ -1,22 +1,24 @@ # ---Environment Configs--- +scenario: gym-2s-8x8-2p-2f-coop copy -scenario: rware:rware-tiny-2ag-v1 # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] - -env_name: RobotWarehouse # Used for logging purposes. +env_name: Gym # Used for logging purposes, will get changed to the scenario name at runtime. # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. eval_metric: episode_return -# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. -# This should not be changed. -implicit_agent_id: False -# Whether or not to log the winrate of this environment. This should not be changed as not all -# environments have a winrate metric. 
+# Whether the add agents IDs to the observations returned by the environment. +add_agent_id : False + +# Whether or not to log the winrate of this environment. log_win_rate: False -# Weather or not to average the returned rewards over all of the agents. -use_individual_rewards: True +# Weather or not to sum the returned rewards over all of the agents. +use_shared_rewards: True kwargs: - time_limit: 500 + {} + +# Possible scenarios: +# RobotWarehouse : [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] +# LevelBasedForaging : [2s-8x8-2p-2f-coop, 8x8-2p-2f-coop, 2s-10x10-3p-3f, 10x10-3p-3f, 15x15-3p-5f, 15x15-4p-3f, 15x15-4p-5f] \ No newline at end of file diff --git a/mava/configs/env/scenario/gym-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-10x10-3p-3f.yaml new file mode 100644 index 000000000..386431be4 --- /dev/null +++ b/mava/configs/env/scenario/gym-10x10-3p-3f.yaml @@ -0,0 +1,15 @@ +# The config of the 10x10-3p-3f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 10x10-3p-3f + +task_config: + field_size: [10,10] + sight: 10 + num_agents: 3 + max_food: 3 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-15x15-3p-5f.yaml new file mode 100644 index 000000000..1a8380511 --- /dev/null +++ b/mava/configs/env/scenario/gym-15x15-3p-5f.yaml @@ -0,0 +1,15 @@ +# The config of the 15x15-3p-5f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 15x15-3p-5f + +task_config: + field_size: [15, 15] + sight: 15 + num_agents: 3 + max_food: 5 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-15x15-4p-3f.yaml new file mode 100644 index 000000000..fa22f737b --- /dev/null +++ b/mava/configs/env/scenario/gym-15x15-4p-3f.yaml @@ -0,0 +1,15 @@ +# The config of the 15x15-4p-3f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 15x15-4p-3f + +task_config: + field_size: [15, 15] + sight: 15 + num_agents: 4 + max_food: 3 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-15x15-4p-5f.yaml b/mava/configs/env/scenario/gym-15x15-4p-5f.yaml new file mode 100644 index 000000000..28937215c --- /dev/null +++ b/mava/configs/env/scenario/gym-15x15-4p-5f.yaml @@ -0,0 +1,15 @@ +# The config of the 15x15-4p-5f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 15x15-4p-5f + +task_config: + field_size: [15, 15] + sight: 15 + num_agents: 4 + max_food: 5 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-2s-10x10-3p-3f.yaml new file mode 100644 index 000000000..f0262eb8d --- /dev/null +++ b/mava/configs/env/scenario/gym-2s-10x10-3p-3f.yaml @@ -0,0 +1,15 @@ +# The config of the 2s10x10-3p-3f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 2s-10x10-3p-3f + +task_config: + field_size: [10, 10] + sight: 2 + num_agents: 3 + max_food: 3 + max_player_level: 2 + force_coop: False + 
max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-2s-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-2s-8x8-2p-2f-coop.yaml new file mode 100644 index 000000000..ffdc5be0e --- /dev/null +++ b/mava/configs/env/scenario/gym-2s-8x8-2p-2f-coop.yaml @@ -0,0 +1,15 @@ +# The config of the 2s-8x8-2p-2f-coop scenario with the VectorObserver set as default. +name: LevelBasedForaging +task_name: 2s-8x8-2p-2f-coop + +task_config: + field_size: [8, 8] # size of the grid to generate. + sight: 2 # field of view of an agent. + num_agents: 2 # number of agents on the grid. + max_food: 2 # number of food in the environment. + max_player_level: 2 # maximum level of the agents (inclusive). + force_coop: True # force cooperation between agents. + max_episode_steps: 50 # max number of steps per episode. + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-8x8-2p-2f-coop.yaml new file mode 100644 index 000000000..52519fecb --- /dev/null +++ b/mava/configs/env/scenario/gym-8x8-2p-2f-coop.yaml @@ -0,0 +1,15 @@ +# The config of the 8x8-2p-2f-coop scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 8x8-2p-2f-coop + +task_config: + field_size: [8, 8] + sight: 8 + num_agents: 2 + max_food: 2 + max_player_level: 2 + force_coop: True + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-small-4ag.yaml b/mava/configs/env/scenario/gym-small-4ag.yaml new file mode 100644 index 000000000..af3eb830b --- /dev/null +++ b/mava/configs/env/scenario/gym-small-4ag.yaml @@ -0,0 +1,14 @@ +# The config of the small-4ag environment +name: RobotWarehouse +task_name: small-4ag + +task_config: + column_height: 8 + shelf_rows: 2 + shelf_columns: 3 + n_agents: 4 + sensor_range: 1 + request_queue_size: 4 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-tiny-2ag.yaml b/mava/configs/env/scenario/gym-tiny-2ag.yaml new file mode 100644 index 000000000..e648887a0 --- /dev/null +++ b/mava/configs/env/scenario/gym-tiny-2ag.yaml @@ -0,0 +1,14 @@ +# The config of the tiny-2ag environment +name: RobotWarehouse +task_name: tiny-2ag + +task_config: + column_height: 8 + shelf_rows: 1 + shelf_columns: 3 + n_agents: 2 + sensor_range: 1 + request_queue_size: 2 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-tiny-4ag-easy.yaml b/mava/configs/env/scenario/gym-tiny-4ag-easy.yaml new file mode 100644 index 000000000..7d8840882 --- /dev/null +++ b/mava/configs/env/scenario/gym-tiny-4ag-easy.yaml @@ -0,0 +1,14 @@ +# The config of the tiny-4ag-easy environment +name: RobotWarehouse +task_name: tiny-4ag-easy + +task_config: + column_height: 8 + shelf_rows: 1 + shelf_columns: 3 + n_agents: 4 + sensor_range: 1 + request_queue_size: 8 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-tiny-4ag.yaml b/mava/configs/env/scenario/gym-tiny-4ag.yaml new file mode 100644 index 000000000..dbfe55bd4 --- /dev/null +++ b/mava/configs/env/scenario/gym-tiny-4ag.yaml @@ -0,0 +1,14 @@ +# The config of the tiny_4ag environment +name: RobotWarehouse +task_name: tiny-4ag + +task_config: + column_height: 8 + shelf_rows: 1 + shelf_columns: 3 + 
n_agents: 4 + sensor_range: 1 + request_queue_size: 4 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index eeebed9d0..3f851fa76 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -20,7 +20,8 @@ import gym.wrappers.compatibility import jaxmarl import jumanji -import lbforaging # noqa: F401 used implicitly +from lbforaging.foraging import environment as GymLBF +import rware.warehouse as GymRware import matrax from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario @@ -73,7 +74,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"RobotWarehouse": GymRwareWrapper, "LevelBasedForaging": GymLBFWrapper} +_gym_registry = {"RobotWarehouse": (GymRware, GymRwareWrapper), "LevelBasedForaging": (GymLBF ,GymLBFWrapper)} def add_extra_wrappers( @@ -215,7 +216,7 @@ def make_gym_env( config: DictConfig, num_env: int, add_global_state: bool = False, -) -> Environment: # todo : create the appropriate annotation for the sync vector +) -> gym.vector.AsyncVectorEnv: """ Create a Gym environment. @@ -227,20 +228,20 @@ def make_gym_env( Returns: Async environments. """ - base_env_name = config.env.env_name - wrapper = _gym_registry[base_env_name] + base_env_name = config.env.scenario.name + env_maker, wrapper = _gym_registry[base_env_name] def create_gym_env( config: DictConfig, add_global_state: bool = False - ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. - env = gym.make(config.env.scenario) - wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state) - if not config.env.implicit_agent_id: - wrapped_env = GymAgentIDWrapper(wrapped_env) # todo : add agent id wrapper for gym . + ) -> Environment: + env = env_maker(**config.env.scenario.task_config) + wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) + if config.env.add_agent_id: + wrapped_env = GymAgentIDWrapper(wrapped_env) wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env - envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names + envs = gym.vector.AsyncVectorEnv( [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], worker=_multiagent_worker_shared_memory, ) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 83c523702..8112a087e 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -32,19 +32,19 @@ class GymRwareWrapper(gym.Wrapper): def __init__( self, env: gym.Env, - use_individual_rewards: bool = False, + use_shared_rewards: bool = False, add_global_state: bool = False, ): """Initialize the gym wrapper Args: env (gym.env): gym env instance. - use_individual_rewards (bool, optional): Use individual or group rewards. + use_shared_rewards (bool, optional): Use individual or shared rewards. Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. 
""" super().__init__(env) self._env = env - self.use_individual_rewards = use_individual_rewards + self.use_shared_rewards = use_shared_rewards self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[0].n @@ -72,10 +72,10 @@ def step(self, actions: NDArray) -> Tuple: if self.add_global_state: info["global_obs"] = self.get_global_obs(agents_view) - if self.use_individual_rewards: - reward = np.array(reward) - else: + if self.use_shared_rewards: reward = np.array([np.array(reward).sum()] * self.num_agents) + else: + reward = np.array(reward) return agents_view, reward, terminated, truncated, info @@ -95,19 +95,19 @@ class GymLBFWrapper(gym.Wrapper): def __init__( self, env: gym.Env, - use_individual_rewards: bool = False, + use_shared_rewards: bool = False, add_global_state: bool = False, ): """Initialize the gym wrapper Args: env (gym.env): gym env instance. - use_individual_rewards (bool, optional): Use individual or group rewards. + use_shared_rewards (bool, optional): Use individual or shared rewards. Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. """ super().__init__(env) self._env = env # not having _env leaded tp self.env getting replaced --> circular called - self.use_individual_rewards = use_individual_rewards + self.use_shared_rewards = use_shared_rewards self.add_global_state = add_global_state # todo : add the global observations self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[ @@ -131,10 +131,10 @@ def step(self, actions: NDArray) -> Tuple: # Vect auto rest info = {"actions_mask": self.get_actions_mask(info)} - if self.use_individual_rewards: - reward = np.array(reward) - else: + if self.use_shared_rewards: reward = np.array([np.array(reward).sum()] * self.num_agents) + else: + reward = np.array(reward) truncated = [truncated] * self.num_agents terminated = [terminated] * self.num_agents From e199f3a19b50990735f9740388639fb0ec5d36f5 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 15:52:50 +0100 Subject: [PATCH 052/125] chore: pre-commits --- mava/configs/env/gym.yaml | 2 +- mava/utils/make_env.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym.yaml index 295b9974e..2ee6f9256 100644 --- a/mava/configs/env/gym.yaml +++ b/mava/configs/env/gym.yaml @@ -21,4 +21,4 @@ kwargs: # Possible scenarios: # RobotWarehouse : [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] -# LevelBasedForaging : [2s-8x8-2p-2f-coop, 8x8-2p-2f-coop, 2s-10x10-3p-3f, 10x10-3p-3f, 15x15-3p-5f, 15x15-4p-3f, 15x15-4p-5f] \ No newline at end of file +# LevelBasedForaging : [2s-8x8-2p-2f-coop, 8x8-2p-2f-coop, 2s-10x10-3p-3f, 10x10-3p-3f, 15x15-3p-5f, 15x15-4p-3f, 15x15-4p-5f] diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 3f851fa76..9d89ab581 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -20,9 +20,8 @@ import gym.wrappers.compatibility import jaxmarl import jumanji -from lbforaging.foraging import environment as GymLBF -import rware.warehouse as GymRware import matrax +import rware.warehouse as gym_rware from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment @@ -38,6 +37,7 @@ from jumanji.environments.routing.robot_warehouse.generator import ( RandomGenerator as RwareRandomGenerator, ) +from lbforaging.foraging import 
environment as gym_lbf from omegaconf import DictConfig from mava.wrappers import ( @@ -74,7 +74,10 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"RobotWarehouse": (GymRware, GymRwareWrapper), "LevelBasedForaging": (GymLBF ,GymLBFWrapper)} +_gym_registry = { + "RobotWarehouse": (gym_rware, GymRwareWrapper), + "LevelBasedForaging": (gym_lbf, GymLBFWrapper), +} def add_extra_wrappers( @@ -216,7 +219,7 @@ def make_gym_env( config: DictConfig, num_env: int, add_global_state: bool = False, -) -> gym.vector.AsyncVectorEnv: +) -> gym.vector.AsyncVectorEnv: """ Create a Gym environment. @@ -231,17 +234,15 @@ def make_gym_env( base_env_name = config.env.scenario.name env_maker, wrapper = _gym_registry[base_env_name] - def create_gym_env( - config: DictConfig, add_global_state: bool = False - ) -> Environment: + def create_gym_env(config: DictConfig, add_global_state: bool = False) -> Environment: env = env_maker(**config.env.scenario.task_config) wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) if config.env.add_agent_id: - wrapped_env = GymAgentIDWrapper(wrapped_env) + wrapped_env = GymAgentIDWrapper(wrapped_env) wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env - envs = gym.vector.AsyncVectorEnv( + envs = gym.vector.AsyncVectorEnv( [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], worker=_multiagent_worker_shared_memory, ) From 2b71d3b32652c34c6666b10266a184ba6dac17c2 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 16:18:55 +0100 Subject: [PATCH 053/125] fix: more config changes --- mava/configs/arch/anakin.yaml | 2 +- mava/configs/arch/sebulba.yaml | 2 +- mava/configs/default_ff_ippo.yaml | 2 +- mava/configs/env/{gym.yaml => gym_lbf.yaml} | 8 ++----- mava/configs/env/rware_gym.yaml | 20 ++++++++++++++++++ mava/wrappers/gym.py | 23 +++++++++++++-------- 6 files changed, 39 insertions(+), 18 deletions(-) rename mava/configs/env/{gym.yaml => gym_lbf.yaml} (60%) create mode 100644 mava/configs/env/rware_gym.yaml diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index 6e15238dc..d6414f5ac 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,5 +1,5 @@ # --- Anakin config --- -arch_name: "Anakin" +arch_name: anakin # --- Training --- num_envs: 16 # Number of vectorised environments per device. diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index f38324e86..0ff3707cd 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,5 +1,5 @@ # --- Sebulba config --- -arch_name: "Sebulba" +arch_name: sebulba # --- Training --- num_envs: 32 # number of environments per thread. 
diff --git a/mava/configs/default_ff_ippo.yaml b/mava/configs/default_ff_ippo.yaml index d942584ce..c4aa6ea49 100644 --- a/mava/configs/default_ff_ippo.yaml +++ b/mava/configs/default_ff_ippo.yaml @@ -3,5 +3,5 @@ defaults: - arch: anakin - system: ppo/ff_ippo - network: mlp - - env: rware + - env: rware_gym - _self_ diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym_lbf.yaml similarity index 60% rename from mava/configs/env/gym.yaml rename to mava/configs/env/gym_lbf.yaml index 2ee6f9256..dfabeb888 100644 --- a/mava/configs/env/gym.yaml +++ b/mava/configs/env/gym_lbf.yaml @@ -1,7 +1,7 @@ # ---Environment Configs--- -scenario: gym-2s-8x8-2p-2f-coop copy +scenario: gym-2s-8x8-2p-2f-coop copy # [gym-2s-8x8-2p-2f-coop, gym-8x8-2p-2f-coop, gym-2s-10x10-3p-3f, gym-10x10-3p-3f, gym-15x15-3p-5f, gym-15x15-4p-3f, gym-15x15-4p-5f] -env_name: Gym # Used for logging purposes, will get changed to the scenario name at runtime. +env_name: LevelBasedForaging # Used for logging purposes. # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. @@ -18,7 +18,3 @@ use_shared_rewards: True kwargs: {} - -# Possible scenarios: -# RobotWarehouse : [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] -# LevelBasedForaging : [2s-8x8-2p-2f-coop, 8x8-2p-2f-coop, 2s-10x10-3p-3f, 10x10-3p-3f, 15x15-3p-5f, 15x15-4p-3f, 15x15-4p-5f] diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml new file mode 100644 index 000000000..a61bc734e --- /dev/null +++ b/mava/configs/env/rware_gym.yaml @@ -0,0 +1,20 @@ +# ---Environment Configs--- +scenario: gym-2s-8x8-2p-2f-coop # [gym-tiny-2ag, gym-tiny-4ag, gym-tiny-4ag-easy, gym-small-4ag] + +env_name: RobotWarehouse # Used for logging purposes. + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the add agents IDs to the observations returned by the environment. +add_agent_id : False + +# Whether or not to log the winrate of this environment. +log_win_rate: False + +# Weather or not to sum the returned rewards over all of the agents. +use_shared_rewards: True + +kwargs: + {} diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 8112a087e..396f78ef4 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -32,10 +32,10 @@ class GymRwareWrapper(gym.Wrapper): def __init__( self, env: gym.Env, - use_shared_rewards: bool = False, + use_shared_rewards: bool = True, add_global_state: bool = False, ): - """Initialize the gym wrapper + """Initialise the gym wrapper Args: env (gym.env): gym env instance. use_shared_rewards (bool, optional): Use individual or shared rewards. @@ -95,10 +95,10 @@ class GymLBFWrapper(gym.Wrapper): def __init__( self, env: gym.Env, - use_shared_rewards: bool = False, + use_shared_rewards: bool = True, add_global_state: bool = False, ): - """Initialize the gym wrapper + """Initialise the gym wrapper Args: env (gym.env): gym env instance. use_shared_rewards (bool, optional): Use individual or shared rewards. @@ -106,13 +106,13 @@ def __init__( add_global_state (bool, optional) : Create global observations. Defaults to False. 
""" super().__init__(env) - self._env = env # not having _env leaded tp self.env getting replaced --> circular called + self._env = env self.use_shared_rewards = use_shared_rewards - self.add_global_state = add_global_state # todo : add the global observations + self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[ 0 - ].n # todo: all the agents must have the same num_actions, add assertion? + ].n def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: @@ -130,7 +130,9 @@ def step(self, actions: NDArray) -> Tuple: # Vect auto rest agents_view, reward, terminated, truncated, info = self._env.step(actions) info = {"actions_mask": self.get_actions_mask(info)} - + if self.add_global_state: + info["global_obs"] = self.get_global_obs(agents_view) + if self.use_shared_rewards: reward = np.array([np.array(reward).sum()] * self.num_agents) else: @@ -145,7 +147,10 @@ def get_actions_mask(self, info: Dict) -> NDArray: if "action_mask" in info: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - + + def get_global_obs(self, obs: NDArray) -> NDArray: + global_obs = np.concatenate(obs, axis=0) + return np.tile(global_obs, (self.num_agents, 1)) class GymRecordEpisodeMetrics(gym.Wrapper): """Record the episode returns and lengths.""" From e87ad286cb87fde7c40fde4f5c83ca5692e714d7 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 17 Jul 2024 16:20:37 +0100 Subject: [PATCH 054/125] chore: pre-commits --- mava/wrappers/gym.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 396f78ef4..13975a9a5 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -106,13 +106,11 @@ def __init__( add_global_state (bool, optional) : Create global observations. Defaults to False. 
""" super().__init__(env) - self._env = env + self._env = env self.use_shared_rewards = use_shared_rewards - self.add_global_state = add_global_state + self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) - self.num_actions = self._env.action_space[ - 0 - ].n + self.num_actions = self._env.action_space[0].n def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: @@ -132,7 +130,7 @@ def step(self, actions: NDArray) -> Tuple: # Vect auto rest info = {"actions_mask": self.get_actions_mask(info)} if self.add_global_state: info["global_obs"] = self.get_global_obs(agents_view) - + if self.use_shared_rewards: reward = np.array([np.array(reward).sum()] * self.num_agents) else: @@ -147,11 +145,12 @@ def get_actions_mask(self, info: Dict) -> NDArray: if "action_mask" in info: return np.array(info["action_mask"]) return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - + def get_global_obs(self, obs: NDArray) -> NDArray: global_obs = np.concatenate(obs, axis=0) return np.tile(global_obs, (self.num_agents, 1)) + class GymRecordEpisodeMetrics(gym.Wrapper): """Record the episode returns and lengths.""" From 2b587c05626bf469dbf499d2c86b6b414152ba0c Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 10:24:20 +0100 Subject: [PATCH 055/125] chore: renamed arch_name to architecture_name --- mava/configs/arch/anakin.yaml | 2 +- mava/configs/arch/sebulba.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index d6414f5ac..eb948b7a1 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,5 +1,5 @@ # --- Anakin config --- -arch_name: anakin +architecture_name: anakin # --- Training --- num_envs: 16 # Number of vectorised environments per device. diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 0ff3707cd..0b539059b 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,5 +1,5 @@ # --- Sebulba config --- -arch_name: sebulba +architecture_name: sebulba # --- Training --- num_envs: 32 # number of environments per thread. 
From 5ad4d2fa5e6962826a70e7da24f2ad9db515a09d Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 10:30:39 +0100 Subject: [PATCH 056/125] chore: config files rename --- mava/configs/env/{gym_lbf.yaml => lbf_gym.yaml} | 7 ++----- mava/configs/env/rware_gym.yaml | 7 ++----- .../{gym-10x10-3p-3f.yaml => gym-lbf-10x10-3p-3f.yaml} | 0 .../{gym-15x15-3p-5f.yaml => gym-lbf-15x15-3p-5f.yaml} | 0 .../{gym-15x15-4p-3f.yaml => gym-lbf-15x15-4p-3f.yaml} | 0 .../{gym-15x15-4p-5f.yaml => gym-lbf-15x15-4p-5f.yaml} | 0 ...gym-2s-10x10-3p-3f.yaml => gym-lbf-2s-10x10-3p-3f.yaml} | 0 ...-8x8-2p-2f-coop.yaml => gym-lbf-2s-8x8-2p-2f-coop.yaml} | 0 ...gym-8x8-2p-2f-coop.yaml => gym-lbf-8x8-2p-2f-coop.yaml} | 0 .../{gym-small-4ag.yaml => gym-rware-small-4ag.yaml} | 0 .../{gym-tiny-2ag.yaml => gym-rware-tiny-2ag.yaml} | 0 ...gym-tiny-4ag-easy.yaml => gym-rware-tiny-4ag-easy.yaml} | 0 .../{gym-tiny-4ag.yaml => gym-rware-tiny-4ag.yaml} | 0 13 files changed, 4 insertions(+), 10 deletions(-) rename mava/configs/env/{gym_lbf.yaml => lbf_gym.yaml} (70%) rename mava/configs/env/scenario/{gym-10x10-3p-3f.yaml => gym-lbf-10x10-3p-3f.yaml} (100%) rename mava/configs/env/scenario/{gym-15x15-3p-5f.yaml => gym-lbf-15x15-3p-5f.yaml} (100%) rename mava/configs/env/scenario/{gym-15x15-4p-3f.yaml => gym-lbf-15x15-4p-3f.yaml} (100%) rename mava/configs/env/scenario/{gym-15x15-4p-5f.yaml => gym-lbf-15x15-4p-5f.yaml} (100%) rename mava/configs/env/scenario/{gym-2s-10x10-3p-3f.yaml => gym-lbf-2s-10x10-3p-3f.yaml} (100%) rename mava/configs/env/scenario/{gym-2s-8x8-2p-2f-coop.yaml => gym-lbf-2s-8x8-2p-2f-coop.yaml} (100%) rename mava/configs/env/scenario/{gym-8x8-2p-2f-coop.yaml => gym-lbf-8x8-2p-2f-coop.yaml} (100%) rename mava/configs/env/scenario/{gym-small-4ag.yaml => gym-rware-small-4ag.yaml} (100%) rename mava/configs/env/scenario/{gym-tiny-2ag.yaml => gym-rware-tiny-2ag.yaml} (100%) rename mava/configs/env/scenario/{gym-tiny-4ag-easy.yaml => gym-rware-tiny-4ag-easy.yaml} (100%) rename mava/configs/env/scenario/{gym-tiny-4ag.yaml => gym-rware-tiny-4ag.yaml} (100%) diff --git a/mava/configs/env/gym_lbf.yaml b/mava/configs/env/lbf_gym.yaml similarity index 70% rename from mava/configs/env/gym_lbf.yaml rename to mava/configs/env/lbf_gym.yaml index dfabeb888..3fca4d62d 100644 --- a/mava/configs/env/gym_lbf.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -1,5 +1,5 @@ # ---Environment Configs--- -scenario: gym-2s-8x8-2p-2f-coop copy # [gym-2s-8x8-2p-2f-coop, gym-8x8-2p-2f-coop, gym-2s-10x10-3p-3f, gym-10x10-3p-3f, gym-15x15-3p-5f, gym-15x15-4p-3f, gym-15x15-4p-5f] +scenario: gym-2s-8x8-2p-2f-coop copy # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] env_name: LevelBasedForaging # Used for logging purposes. @@ -14,7 +14,4 @@ add_agent_id : False log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. 
-use_shared_rewards: True - -kwargs: - {} +use_shared_rewards: True \ No newline at end of file diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml index a61bc734e..576bf0d2b 100644 --- a/mava/configs/env/rware_gym.yaml +++ b/mava/configs/env/rware_gym.yaml @@ -1,5 +1,5 @@ # ---Environment Configs--- -scenario: gym-2s-8x8-2p-2f-coop # [gym-tiny-2ag, gym-tiny-4ag, gym-tiny-4ag-easy, gym-small-4ag] +scenario: gym-2s-8x8-2p-2f-coop # [gym-rware-tiny-2ag, gym-rware-tiny-4ag, gym-rware-tiny-4ag-easy, gym-rware-small-4ag] env_name: RobotWarehouse # Used for logging purposes. @@ -14,7 +14,4 @@ add_agent_id : False log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. -use_shared_rewards: True - -kwargs: - {} +use_shared_rewards: True \ No newline at end of file diff --git a/mava/configs/env/scenario/gym-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml similarity index 100% rename from mava/configs/env/scenario/gym-10x10-3p-3f.yaml rename to mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml diff --git a/mava/configs/env/scenario/gym-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml similarity index 100% rename from mava/configs/env/scenario/gym-15x15-3p-5f.yaml rename to mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml diff --git a/mava/configs/env/scenario/gym-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml similarity index 100% rename from mava/configs/env/scenario/gym-15x15-4p-3f.yaml rename to mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml diff --git a/mava/configs/env/scenario/gym-15x15-4p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml similarity index 100% rename from mava/configs/env/scenario/gym-15x15-4p-5f.yaml rename to mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml diff --git a/mava/configs/env/scenario/gym-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml similarity index 100% rename from mava/configs/env/scenario/gym-2s-10x10-3p-3f.yaml rename to mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml diff --git a/mava/configs/env/scenario/gym-2s-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml similarity index 100% rename from mava/configs/env/scenario/gym-2s-8x8-2p-2f-coop.yaml rename to mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml diff --git a/mava/configs/env/scenario/gym-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml similarity index 100% rename from mava/configs/env/scenario/gym-8x8-2p-2f-coop.yaml rename to mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml diff --git a/mava/configs/env/scenario/gym-small-4ag.yaml b/mava/configs/env/scenario/gym-rware-small-4ag.yaml similarity index 100% rename from mava/configs/env/scenario/gym-small-4ag.yaml rename to mava/configs/env/scenario/gym-rware-small-4ag.yaml diff --git a/mava/configs/env/scenario/gym-tiny-2ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml similarity index 100% rename from mava/configs/env/scenario/gym-tiny-2ag.yaml rename to mava/configs/env/scenario/gym-rware-tiny-2ag.yaml diff --git a/mava/configs/env/scenario/gym-tiny-4ag-easy.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml similarity index 100% rename from mava/configs/env/scenario/gym-tiny-4ag-easy.yaml rename to mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml diff --git a/mava/configs/env/scenario/gym-tiny-4ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml similarity index 
100% rename from mava/configs/env/scenario/gym-tiny-4ag.yaml rename to mava/configs/env/scenario/gym-rware-tiny-4ag.yaml From 432071e9476aadf2342ea0f571fd0d4b30edc7cd Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 10:40:55 +0100 Subject: [PATCH 057/125] fix; moved from gym to gymnasium --- mava/configs/env/lbf_gym.yaml | 2 +- mava/configs/env/rware_gym.yaml | 2 +- mava/utils/make_env.py | 14 +++++++------- mava/wrappers/gym.py | 28 ++++++++++++++-------------- requirements/requirements.txt | 3 ++- 5 files changed, 25 insertions(+), 24 deletions(-) diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index 3fca4d62d..0c6016dd4 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -14,4 +14,4 @@ add_agent_id : False log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. -use_shared_rewards: True \ No newline at end of file +use_shared_rewards: True diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml index 576bf0d2b..4d5e0c7f3 100644 --- a/mava/configs/env/rware_gym.yaml +++ b/mava/configs/env/rware_gym.yaml @@ -14,4 +14,4 @@ add_agent_id : False log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. -use_shared_rewards: True \ No newline at end of file +use_shared_rewards: True diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 9d89ab581..dcab4216a 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -14,10 +14,10 @@ from typing import Tuple -import gym -import gym.vector -import gym.wrappers -import gym.wrappers.compatibility +import gymnasium +import gymnasium.vector +import gymnasium.wrappers +import gymnasium.wrappers.compatibility import jaxmarl import jumanji import matrax @@ -219,9 +219,9 @@ def make_gym_env( config: DictConfig, num_env: int, add_global_state: bool = False, -) -> gym.vector.AsyncVectorEnv: +) -> gymnasium.vector.AsyncVectorEnv: """ - Create a Gym environment. + Create a gymnasium environment. Args: config (Dict): The configuration of the environment. @@ -242,7 +242,7 @@ def create_gym_env(config: DictConfig, add_global_state: bool = False) -> Enviro wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env - envs = gym.vector.AsyncVectorEnv( + envs = gymnasium.vector.AsyncVectorEnv( [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], worker=_multiagent_worker_shared_memory, ) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 13975a9a5..5b8f9cd74 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -16,28 +16,28 @@ import warnings from typing import Any, Callable, Dict, Optional, Tuple -import gym +import gymnasium import numpy as np -from gym import spaces -from gym.vector.utils import write_to_shared_memory +from gymnasium import spaces +from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray # Filter out the warnings -warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") +warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") -class GymRwareWrapper(gym.Wrapper): +class GymRwareWrapper(gymnasium.Wrapper): """Wrapper for rware gym environments.""" def __init__( self, - env: gym.Env, + env: gymnasium.Env, use_shared_rewards: bool = True, add_global_state: bool = False, ): """Initialise the gym wrapper Args: - env (gym.env): gym env instance. + env (gymnasium.env): gymnasium env instance. 
use_shared_rewards (bool, optional): Use individual or shared rewards. Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. @@ -89,18 +89,18 @@ def get_global_obs(self, obs: NDArray) -> NDArray: return np.tile(global_obs, (self.num_agents, 1)) -class GymLBFWrapper(gym.Wrapper): +class GymLBFWrapper(gymnasium.Wrapper): """Wrapper for rware gym environments""" def __init__( self, - env: gym.Env, + env: gymnasium.Env, use_shared_rewards: bool = True, add_global_state: bool = False, ): """Initialise the gym wrapper Args: - env (gym.env): gym env instance. + env (gymnasium.env): gymnasium env instance. use_shared_rewards (bool, optional): Use individual or shared rewards. Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. @@ -151,10 +151,10 @@ def get_global_obs(self, obs: NDArray) -> NDArray: return np.tile(global_obs, (self.num_agents, 1)) -class GymRecordEpisodeMetrics(gym.Wrapper): +class GymRecordEpisodeMetrics(gymnasium.Wrapper): """Record the episode returns and lengths.""" - def __init__(self, env: gym.Env): + def __init__(self, env: gymnasium.Env): super().__init__(env) self._env = env self.running_count_episode_return = 0.0 @@ -206,10 +206,10 @@ def step(self, actions: NDArray) -> Tuple: return agents_view, reward, terminated, truncated, info -class GymAgentIDWrapper(gym.Wrapper): +class GymAgentIDWrapper(gymnasium.Wrapper): """Add one hot agent IDs to observation.""" - def __init__(self, env: gym.Env): + def __init__(self, env: gymnasium.Env): super().__init__(env) self.agent_ids = np.eye(self.env.num_agents) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 3a7b96aef..74b07af25 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,13 +3,14 @@ distrax @ git+https://github.com/google-deepmind/distrax # distrax release does flashbax~=0.1.0 flax gigastep @ git+https://github.com/mlech26l/gigastep +gymnasium hydra-core==1.3.2 id-marl-eval @ git+https://github.com/instadeepai/marl-eval jax jaxlib jaxmarl jumanji @ git+https://github.com/sash-a/jumanji -lbforaging @ git+https://github.com/Louay-Ben-nessir/lb-foraging.git +lbforaging @ git+https://github.com/LukasSchaefer/lb-foraging.git@gymnasium_integration matrax @ git+https://github.com/instadeepai/matrax mujoco==3.1.3 mujoco-mjx==3.1.3 From 77e6e126e73e02ce5ad62105b08372a28edda699 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 10:44:37 +0100 Subject: [PATCH 058/125] feat: generic gym wrapper --- mava/utils/make_env.py | 4 +-- mava/wrappers/__init__.py | 2 +- mava/wrappers/gym.py | 51 +++++++-------------------------------- 3 files changed, 12 insertions(+), 45 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index dcab4216a..a2dd6ef54 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -49,7 +49,7 @@ GymAgentIDWrapper, GymLBFWrapper, GymRecordEpisodeMetrics, - GymRwareWrapper, + GymWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, @@ -75,7 +75,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} _gym_registry = { - "RobotWarehouse": (gym_rware, GymRwareWrapper), + "RobotWarehouse": (gym_rware, GymWrapper), "LevelBasedForaging": (gym_lbf, GymLBFWrapper), } diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 869e78053..03e2223dc 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -19,7 +19,7 @@ GymAgentIDWrapper, GymLBFWrapper, GymRecordEpisodeMetrics, - 
GymRwareWrapper, + GymWrapper, _multiagent_worker_shared_memory, ) from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 5b8f9cd74..49dbafd1f 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -26,8 +26,8 @@ warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") -class GymRwareWrapper(gymnasium.Wrapper): - """Wrapper for rware gym environments.""" +class GymWrapper(gymnasium.Wrapper): + """Wrapper for gym environments.""" def __init__( self, @@ -89,7 +89,7 @@ def get_global_obs(self, obs: NDArray) -> NDArray: return np.tile(global_obs, (self.num_agents, 1)) -class GymLBFWrapper(gymnasium.Wrapper): +class GymLBFWrapper(GymWrapper): """Wrapper for rware gym environments""" def __init__( @@ -105,50 +105,17 @@ def __init__( Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. """ - super().__init__(env) - self._env = env - self.use_shared_rewards = use_shared_rewards - self.add_global_state = add_global_state - self.num_agents = len(self._env.action_space) - self.num_actions = self._env.action_space[0].n - - def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: - - if seed is not None: - self.env.seed(seed) + super().__init__(env, use_shared_rewards, add_global_state) - agents_view, info = self._env.reset() + def step(self, actions: NDArray) -> Tuple: - info = {"actions_mask": self.get_actions_mask(info)} - - return np.array(agents_view), info - - def step(self, actions: NDArray) -> Tuple: # Vect auto rest - - agents_view, reward, terminated, truncated, info = self._env.step(actions) - - info = {"actions_mask": self.get_actions_mask(info)} - if self.add_global_state: - info["global_obs"] = self.get_global_obs(agents_view) - - if self.use_shared_rewards: - reward = np.array([np.array(reward).sum()] * self.num_agents) - else: - reward = np.array(reward) - - truncated = [truncated] * self.num_agents - terminated = [terminated] * self.num_agents + agents_view, reward, terminated, truncated, info = super().step(actions) + truncated = np.repeat(truncated, self.num_agents) + terminated = np.repeat(terminated, self.num_agents) + return agents_view, reward, terminated, truncated, info - def get_actions_mask(self, info: Dict) -> NDArray: - if "action_mask" in info: - return np.array(info["action_mask"]) - return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - - def get_global_obs(self, obs: NDArray) -> NDArray: - global_obs = np.concatenate(obs, axis=0) - return np.tile(global_obs, (self.num_agents, 1)) class GymRecordEpisodeMetrics(gymnasium.Wrapper): From 43511fd31ec2e39f9f304493cd8f4c6710c97078 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 10:50:15 +0100 Subject: [PATCH 059/125] feat: using gymnasium async worker --- mava/utils/make_env.py | 4 +- mava/wrappers/__init__.py | 2 +- mava/wrappers/gym.py | 109 +++++++++++++++++++++++--------------- 3 files changed, 69 insertions(+), 46 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index a2dd6ef54..26197a289 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -56,7 +56,7 @@ RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, - _multiagent_worker_shared_memory, + async_multiagent_worker, ) # Registry mapping environment names to their generator and wrapper classes. 
@@ -244,7 +244,7 @@ def create_gym_env(config: DictConfig, add_global_state: bool = False) -> Enviro envs = gymnasium.vector.AsyncVectorEnv( [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], - worker=_multiagent_worker_shared_memory, + worker=async_multiagent_worker, ) return envs diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 03e2223dc..80cbccc52 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -20,7 +20,7 @@ GymLBFWrapper, GymRecordEpisodeMetrics, GymWrapper, - _multiagent_worker_shared_memory, + async_multiagent_worker, ) from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 49dbafd1f..3fec9f47e 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -22,6 +22,13 @@ from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray +import multiprocessing +import sys +import traceback +from copy import deepcopy +from multiprocessing import Queue +from multiprocessing.connection import Connection + # Filter out the warnings warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") @@ -208,76 +215,92 @@ def step(self, action: list) -> Tuple[np.ndarray, float, bool, bool, Dict]: return obs, reward, terminated, truncated, info -# Copied form https://github.com/openai/gym/blob/master/gym/vector/async_vector_env.py +# Copied form https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents -def _multiagent_worker_shared_memory( # noqa: CCR001 +def async_multiagent_worker( index: int, - env_fn: Callable[[], Any], - pipe: Any, - parent_pipe: Any, - shared_memory: Any, - error_queue: Any, -) -> None: - assert shared_memory is not None + env_fn: callable, + pipe: Connection, + parent_pipe: Connection, + shared_memory: multiprocessing.Array | dict[str, Any] | tuple[Any, ...], + error_queue: Queue, +): env = env_fn() observation_space = env.observation_space + action_space = env.action_space + autoreset = False + parent_pipe.close() + try: while True: command, data = pipe.recv() + if command == "reset": observation, info = env.reset(**data) - write_to_shared_memory(observation_space, index, observation, shared_memory) - pipe.send(((None, info), True)) - + if shared_memory: + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + observation = None + autoreset = False + pipe.send(((observation, info), True)) elif command == "step": - ( - observation, - reward, - terminated, - truncated, - info, - ) = env.step(data) - # Handel the dones across all of envs and agents - if np.logical_or(terminated, truncated).all(): - old_observation, old_info = observation, info + if autoreset: observation, info = env.reset() - info["final_observation"] = old_observation - info["final_info"] = old_info - write_to_shared_memory(observation_space, index, observation, shared_memory) - pipe.send(((None, reward, terminated, truncated, info), True)) - elif command == "seed": - env.seed(data) - pipe.send((None, True)) + reward, terminated, truncated = 0, False, False + else: + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + autoreset = np.logical_or(terminated, truncated).all() + + if shared_memory: + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + observation = None + + pipe.send(((observation, reward, terminated, 
truncated, info), True)) elif command == "close": pipe.send((None, True)) break elif command == "_call": name, args, kwargs = data - if name in ["reset", "step", "seed", "close"]: + if name in ["reset", "step", "close", "_setattr", "_check_spaces"]: raise ValueError( - f"Trying to call function `{name}` with " - f"`_call`. Use `{name}` directly instead." + f"Trying to call function `{name}` with `call`, use `{name}` directly instead." ) - function = getattr(env, name) - if callable(function): - pipe.send((function(*args, **kwargs), True)) + + attr = env.get_wrapper_attr(name) + if callable(attr): + pipe.send((attr(*args, **kwargs), True)) else: - pipe.send((function, True)) + pipe.send((attr, True)) elif command == "_setattr": name, value = data - setattr(env, name, value) + env.set_wrapper_attr(name, value) pipe.send((None, True)) elif command == "_check_spaces": - pipe.send(((data[0] == observation_space, data[1] == env.action_space), True)) + pipe.send( + ( + (data[0] == observation_space, data[1] == action_space), + True, + ) + ) else: raise RuntimeError( - f"Received unknown command `{command}`. Must " - "be one of {`reset`, `step`, `seed`, `close`, `_call`, " - "`_setattr`, `_check_spaces`}." + f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." ) except (KeyboardInterrupt, Exception): - error_queue.put((index,) + sys.exc_info()[:2]) + error_type, error_message, _ = sys.exc_info() + trace = traceback.format_exc() + + error_queue.put((index, error_type, error_message, trace)) pipe.send((None, False)) finally: - env.close() + env.close() \ No newline at end of file From eaf9a1cb380abb807fc39796ab03f83bc304637b Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 10:58:58 +0100 Subject: [PATCH 060/125] chore: pre-commits and annotaions --- mava/wrappers/gym.py | 55 +++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 3fec9f47e..556fba094 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -13,8 +13,11 @@ # limitations under the License. 
import sys +import traceback import warnings -from typing import Any, Callable, Dict, Optional, Tuple +from multiprocessing import Queue +from multiprocessing.connection import Connection +from typing import Any, Callable, Dict, Optional, Tuple, Union import gymnasium import numpy as np @@ -22,13 +25,6 @@ from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray -import multiprocessing -import sys -import traceback -from copy import deepcopy -from multiprocessing import Queue -from multiprocessing.connection import Connection - # Filter out the warnings warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") @@ -58,7 +54,7 @@ def __init__( def reset( self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[np.ndarray, Dict]: + ) -> Tuple[NDArray, Dict]: if seed is not None: self.env.seed(seed) @@ -71,7 +67,7 @@ def reset( return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple: + def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: agents_view, reward, terminated, truncated, info = self._env.step(actions) @@ -97,7 +93,7 @@ def get_global_obs(self, obs: NDArray) -> NDArray: class GymLBFWrapper(GymWrapper): - """Wrapper for rware gym environments""" + """Wrapper for LBF gym environments""" def __init__( self, @@ -114,15 +110,14 @@ def __init__( """ super().__init__(env, use_shared_rewards, add_global_state) - def step(self, actions: NDArray) -> Tuple: + def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: agents_view, reward, terminated, truncated, info = super().step(actions) truncated = np.repeat(truncated, self.num_agents) terminated = np.repeat(terminated, self.num_agents) - - return agents_view, reward, terminated, truncated, info + return agents_view, reward, terminated, truncated, info class GymRecordEpisodeMetrics(gymnasium.Wrapper): @@ -136,7 +131,7 @@ def __init__(self, env: gymnasium.Env): def reset( self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[np.ndarray, Dict]: + ) -> Tuple[NDArray, Dict]: # Reset the env agents_view, info = self._env.reset(seed, options) @@ -202,29 +197,29 @@ def __init__(self, env: gymnasium.Env): def reset( self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[np.ndarray, Dict]: + ) -> Tuple[NDArray, Dict]: """Reset the environment.""" obs, info = self.env.reset(seed, options) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, info - def step(self, action: list) -> Tuple[np.ndarray, float, bool, bool, Dict]: + def step(self, action: list) -> Tuple[NDArray, float, bool, bool, Dict]: """Step the environment.""" obs, reward, terminated, truncated, info = self.env.step(action) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, reward, terminated, truncated, info -# Copied form https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/vector/async_vector_env.py +# Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents -def async_multiagent_worker( +def async_multiagent_worker( # noqa CCR001 index: int, - env_fn: callable, + env_fn: Callable, pipe: Connection, parent_pipe: Connection, - shared_memory: multiprocessing.Array | dict[str, Any] | tuple[Any, ...], + shared_memory: Union[NDArray, dict[str, Any], tuple[Any, ...]], error_queue: Queue, -): +) -> None: env = env_fn() observation_space = env.observation_space action_space = env.action_space @@ -239,9 
+234,7 @@ def async_multiagent_worker( if command == "reset": observation, info = env.reset(**data) if shared_memory: - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) + write_to_shared_memory(observation_space, index, observation, shared_memory) observation = None autoreset = False pipe.send(((observation, info), True)) @@ -260,9 +253,7 @@ def async_multiagent_worker( autoreset = np.logical_or(terminated, truncated).all() if shared_memory: - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) + write_to_shared_memory(observation_space, index, observation, shared_memory) observation = None pipe.send(((observation, reward, terminated, truncated, info), True)) @@ -273,7 +264,8 @@ def async_multiagent_worker( name, args, kwargs = data if name in ["reset", "step", "close", "_setattr", "_check_spaces"]: raise ValueError( - f"Trying to call function `{name}` with `call`, use `{name}` directly instead." + f"Trying to call function `{name}` with \ + `call`, use `{name}` directly instead." ) attr = env.get_wrapper_attr(name) @@ -294,7 +286,8 @@ def async_multiagent_worker( ) else: raise RuntimeError( - f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." + f"Received unknown command `{command}`. Must be one of \ + [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." ) except (KeyboardInterrupt, Exception): error_type, error_message, _ = sys.exc_info() @@ -303,4 +296,4 @@ def async_multiagent_worker( error_queue.put((index, error_type, error_message, trace)) pipe.send((None, False)) finally: - env.close() \ No newline at end of file + env.close() From 16c0ac3645ed66c519c71b16fa8dd4f2092c9d08 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 14:22:27 +0100 Subject: [PATCH 061/125] fix: config file fixes --- mava/configs/env/lbf_gym.yaml | 4 +++- mava/configs/env/rware_gym.yaml | 4 +++- mava/configs/env/scenario/gym-rware-small-4ag.yaml | 4 ++++ mava/configs/env/scenario/gym-rware-tiny-2ag.yaml | 4 ++++ mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml | 4 ++++ mava/configs/env/scenario/gym-rware-tiny-4ag.yaml | 4 ++++ 6 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index 0c6016dd4..6981f3492 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -1,5 +1,7 @@ # ---Environment Configs--- -scenario: gym-2s-8x8-2p-2f-coop copy # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] +defaults: + - _self_ + - scenario: gym-2s-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] env_name: LevelBasedForaging # Used for logging purposes. diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml index 4d5e0c7f3..87bd3a473 100644 --- a/mava/configs/env/rware_gym.yaml +++ b/mava/configs/env/rware_gym.yaml @@ -1,5 +1,7 @@ # ---Environment Configs--- -scenario: gym-2s-8x8-2p-2f-coop # [gym-rware-tiny-2ag, gym-rware-tiny-4ag, gym-rware-tiny-4ag-easy, gym-rware-small-4ag] +defaults: + - _self_ + - scenario: gym-rware-tiny-2ag # [gym-rware-tiny-2ag, gym-rware-tiny-4ag, gym-rware-tiny-4ag-easy, gym-rware-small-4ag] env_name: RobotWarehouse # Used for logging purposes. 
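With the scenario split out into its own `env/scenario` config group, a task is now selected by composing the env config with one of the scenario files, and the defaults list above supplies a scenario when no override is given. The snippet below is a minimal sketch of that composition using Hydra's compose API; the config path and the printed field assume the config tree introduced in this series, and on the command line the equivalent would be overrides such as `env=rware_gym env/scenario=gym-rware-tiny-4ag`.

from hydra import compose, initialize

# Sketch only: compose the top-level config and swap the scenario group,
# mirroring `env=rware_gym env/scenario=gym-rware-tiny-4ag` on the CLI.
with initialize(config_path="mava/configs", version_base=None):
    cfg = compose(
        config_name="default_ff_ippo",
        overrides=["env=rware_gym", "env/scenario=gym-rware-tiny-4ag"],
    )
    print(cfg.env.scenario.task_config.n_agents)  # 4 for the tiny-4ag task
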
diff --git a/mava/configs/env/scenario/gym-rware-small-4ag.yaml b/mava/configs/env/scenario/gym-rware-small-4ag.yaml index af3eb830b..39f8efa4e 100644 --- a/mava/configs/env/scenario/gym-rware-small-4ag.yaml +++ b/mava/configs/env/scenario/gym-rware-small-4ag.yaml @@ -9,6 +9,10 @@ task_config: n_agents: 4 sensor_range: 1 request_queue_size: 4 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml index e648887a0..95ef11fc2 100644 --- a/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml +++ b/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml @@ -9,6 +9,10 @@ task_config: n_agents: 2 sensor_range: 1 request_queue_size: 2 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml index 7d8840882..7753b73ec 100644 --- a/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml +++ b/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml @@ -9,6 +9,10 @@ task_config: n_agents: 4 sensor_range: 1 request_queue_size: 8 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml index dbfe55bd4..c28cf92c5 100644 --- a/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml +++ b/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml @@ -9,6 +9,10 @@ task_config: n_agents: 4 sensor_range: 1 request_queue_size: 4 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 env_kwargs: {} # there are no scenario specific env_kwargs for this env From 18b928d22b5b5b2ddaae215c1f5fd8c07821ebe6 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 15:47:06 +0100 Subject: [PATCH 062/125] fix: rware import --- mava/utils/make_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 26197a289..95c8ea33f 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -21,7 +21,7 @@ import jaxmarl import jumanji import matrax -import rware.warehouse as gym_rware +from rware.warehouse import Warehouse as gym_rware from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment From 19a776599f0c46dcfbb92fa2275ec4880d54c6b8 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 18:48:45 +0100 Subject: [PATCH 063/125] fix: better agent ids wrapper? 
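For context, the wrapper keeps a one-hot ID matrix (`np.eye(num_agents)`) and prepends each agent's row to its observation, so every per-agent observation grows by `num_agents` features and the observation space returned by the new `modify_space` helper has to grow accordingly. A tiny NumPy illustration of the concatenation, with made-up sizes:

import numpy as np

num_agents, obs_dim = 3, 4                               # illustrative sizes only
agent_ids = np.eye(num_agents)                           # (3, 3) one-hot agent IDs
obs = np.zeros((num_agents, obs_dim))                    # stacked per-agent observations
obs_with_ids = np.concatenate([agent_ids, obs], axis=1)  # what the wrapper returns
print(obs_with_ids.shape)                                # (3, 7): ID block + original features
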
--- mava/utils/make_env.py | 4 ++-- mava/wrappers/gym.py | 25 ++++++++++++------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 95c8ea33f..e49d6344b 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -21,7 +21,6 @@ import jaxmarl import jumanji import matrax -from rware.warehouse import Warehouse as gym_rware from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario from jumanji.env import Environment @@ -39,6 +38,7 @@ ) from lbforaging.foraging import environment as gym_lbf from omegaconf import DictConfig +from rware.warehouse import Warehouse as gym_Warehouse from mava.wrappers import ( AgentIDWrapper, @@ -75,7 +75,7 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} _gym_registry = { - "RobotWarehouse": (gym_rware, GymWrapper), + "RobotWarehouse": (gym_Warehouse, GymWrapper), "LevelBasedForaging": (gym_lbf, GymLBFWrapper), } diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 556fba094..c175dedd7 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -21,7 +21,6 @@ import gymnasium import numpy as np -from gymnasium import spaces from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray @@ -182,18 +181,7 @@ def __init__(self, env: gymnasium.Env): super().__init__(env) self.agent_ids = np.eye(self.env.num_agents) - observation_space = self.env.observation_space[0] - _obs_low, _obs_high, _obs_dtype, _obs_shape = ( - observation_space.low[0], - observation_space.high[0], - observation_space.dtype, - observation_space.shape, - ) - _new_obs_shape = (_obs_shape[0] + self.env.num_agents,) - _observation_boxs = [ - spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype) - ] * self.env.num_agents - self.observation_space = spaces.Tuple(_observation_boxs) + self.observation_space = self.modify_space(self.env.observation_space) def reset( self, seed: Optional[int] = None, options: Optional[dict] = None @@ -209,6 +197,17 @@ def step(self, action: list) -> Tuple[NDArray, float, bool, bool, Dict]: obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, reward, terminated, truncated, info + def modify_space(self, space: gymnasium.spaces) -> gymnasium.spaces: + if isinstance(space, gymnasium.spaces.Box): + new_shape = space.shape[0] + len(self.agent_ids) + return gymnasium.spaces.Box( + low=space.low, high=space.high, shape=new_shape, dtype=space.dtype + ) + elif isinstance(space, gymnasium.spaces.Tuple): + return gymnasium.spaces.Tuple(self.modify_space(s) for s in space) + else: + raise ValueError(f"Space {type(space)} is not currently supported.") + # Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents From c4a05d69effec40cbdbfd33c700b0adeda52f69b Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 18:55:16 +0100 Subject: [PATCH 064/125] chore: bunch of minor changes --- mava/wrappers/gym.py | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index c175dedd7..dcaa6a5ad 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -29,7 +29,10 @@ class GymWrapper(gymnasium.Wrapper): - """Wrapper for gym environments.""" + """Base wrapper for multi-agent gym environments. + This wrapper works out of the box for RobotWarehouse. + See `GymLBFWrapper` for how it can be modified to work for other environments. 
+ """ def __init__( self, @@ -54,7 +57,6 @@ def __init__( def reset( self, seed: Optional[int] = None, options: Optional[dict] = None ) -> Tuple[NDArray, Dict]: - if seed is not None: self.env.seed(seed) @@ -67,7 +69,6 @@ def reset( return np.array(agents_view), info def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: - agents_view, reward, terminated, truncated, info = self._env.step(actions) info = {"actions_mask": self.get_actions_mask(info)} @@ -92,25 +93,9 @@ def get_global_obs(self, obs: NDArray) -> NDArray: class GymLBFWrapper(GymWrapper): - """Wrapper for LBF gym environments""" - - def __init__( - self, - env: gymnasium.Env, - use_shared_rewards: bool = True, - add_global_state: bool = False, - ): - """Initialise the gym wrapper - Args: - env (gymnasium.env): gymnasium env instance. - use_shared_rewards (bool, optional): Use individual or shared rewards. - Defaults to False. - add_global_state (bool, optional) : Create global observations. Defaults to False. - """ - super().__init__(env, use_shared_rewards, add_global_state) + """Wrapper for the gym level based foraging environment.""" def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: - agents_view, reward, terminated, truncated, info = super().step(actions) truncated = np.repeat(truncated, self.num_agents) @@ -131,8 +116,6 @@ def __init__(self, env: gymnasium.Env): def reset( self, seed: Optional[int] = None, options: Optional[dict] = None ) -> Tuple[NDArray, Dict]: - - # Reset the env agents_view, info = self._env.reset(seed, options) # Create the metrics dict @@ -154,8 +137,6 @@ def reset( return agents_view, info def step(self, actions: NDArray) -> Tuple: - - # Step the env agents_view, reward, terminated, truncated, info = self._env.step(actions) self.running_count_episode_return += float(np.mean(reward)) From 559581885bb520cde72fc5a46b4e11f21bec327f Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 18 Jul 2024 19:13:29 +0100 Subject: [PATCH 065/125] chore : annotation --- mava/wrappers/gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index dcaa6a5ad..e7576714d 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -136,7 +136,7 @@ def reset( return agents_view, info - def step(self, actions: NDArray) -> Tuple: + def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: agents_view, reward, terminated, truncated, info = self._env.step(actions) self.running_count_episode_return += float(np.mean(reward)) From 29b1303214c29bc3f129b027f6112432e885d662 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 12:05:35 +0100 Subject: [PATCH 066/125] chore: comments --- mava/wrappers/gym.py | 1 + requirements/requirements.txt | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index e7576714d..18d3ede73 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -230,6 +230,7 @@ def async_multiagent_worker( # noqa CCR001 truncated, info, ) = env.step(data) + # The autoreset was modified to work with boolean arrays. 
autoreset = np.logical_or(terminated, truncated).all() if shared_memory: diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 74b07af25..0c68a3ca5 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,7 +10,7 @@ jax jaxlib jaxmarl jumanji @ git+https://github.com/sash-a/jumanji -lbforaging @ git+https://github.com/LukasSchaefer/lb-foraging.git@gymnasium_integration +lbforaging @ git+https://github.com/LukasSchaefer/lb-foraging.git@gymnasium_integration # fixes: https://github.com/semitable/lb-foraging/issues/20 matrax @ git+https://github.com/instadeepai/matrax mujoco==3.1.3 mujoco-mjx==3.1.3 @@ -19,7 +19,7 @@ numpy omegaconf optax protobuf~=3.20 -rware @ git+https://github.com/RuanJohn/robotic-warehouse.git +rware @ git+https://github.com/RuanJohn/robotic-warehouse.git # compatibility with latest gymnasium scipy==1.12.0 tensorboard_logger tensorflow_probability From 669dfbd044998fedd961c3fbb0c192d5b07d8fd5 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 13:08:10 +0100 Subject: [PATCH 067/125] feat: restructured the folders --- mava/systems/{anakin => }/ppo/__init__.py | 0 mava/systems/{anakin/ppo => ppo/anakin}/ff_ippo.py | 2 +- mava/systems/{anakin/ppo => ppo/anakin}/ff_mappo.py | 2 +- mava/systems/{anakin/ppo => ppo/anakin}/rec_ippo.py | 2 +- mava/systems/{anakin/ppo => ppo/anakin}/rec_mappo.py | 2 +- mava/systems/{sebulba/ppo => ppo/sebulba}/ff_ippo.py | 0 mava/systems/{anakin => }/ppo/types.py | 0 mava/systems/{anakin => }/q_learning/__init__.py | 0 .../systems/{anakin/q_learning => q_learning/anakin}/rec_iql.py | 2 +- mava/systems/{anakin => }/q_learning/types.py | 0 mava/systems/{anakin => }/sac/__init__.py | 0 mava/systems/{anakin/sac => sac/anakin}/ff_isac.py | 2 +- mava/systems/{anakin/sac => sac/anakin}/ff_masac.py | 2 +- mava/systems/{anakin => }/sac/types.py | 0 mava/utils/checkpointing.py | 2 +- 15 files changed, 8 insertions(+), 8 deletions(-) rename mava/systems/{anakin => }/ppo/__init__.py (100%) rename mava/systems/{anakin/ppo => ppo/anakin}/ff_ippo.py (99%) rename mava/systems/{anakin/ppo => ppo/anakin}/ff_mappo.py (99%) rename mava/systems/{anakin/ppo => ppo/anakin}/rec_ippo.py (99%) rename mava/systems/{anakin/ppo => ppo/anakin}/rec_mappo.py (99%) rename mava/systems/{sebulba/ppo => ppo/sebulba}/ff_ippo.py (100%) rename mava/systems/{anakin => }/ppo/types.py (100%) rename mava/systems/{anakin => }/q_learning/__init__.py (100%) rename mava/systems/{anakin/q_learning => q_learning/anakin}/rec_iql.py (99%) rename mava/systems/{anakin => }/q_learning/types.py (100%) rename mava/systems/{anakin => }/sac/__init__.py (100%) rename mava/systems/{anakin/sac => sac/anakin}/ff_isac.py (99%) rename mava/systems/{anakin/sac => sac/anakin}/ff_masac.py (99%) rename mava/systems/{anakin => }/sac/types.py (100%) diff --git a/mava/systems/anakin/ppo/__init__.py b/mava/systems/ppo/__init__.py similarity index 100% rename from mava/systems/anakin/ppo/__init__.py rename to mava/systems/ppo/__init__.py diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py similarity index 99% rename from mava/systems/anakin/ppo/ff_ippo.py rename to mava/systems/ppo/anakin/ff_ippo.py index 51efd10e7..44e196535 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -32,7 +32,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types 
import LearnerState, OptStates, Params, PPOTransition +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/anakin/ppo/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py similarity index 99% rename from mava/systems/anakin/ppo/ff_mappo.py rename to mava/systems/ppo/anakin/ff_mappo.py index a9364fdfc..7f7dce965 100644 --- a/mava/systems/anakin/ppo/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -31,7 +31,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/anakin/ppo/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py similarity index 99% rename from mava/systems/anakin/ppo/rec_ippo.py rename to mava/systems/ppo/anakin/rec_ippo.py index a4d3df428..1f962aa38 100644 --- a/mava/systems/anakin/ppo/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.anakin.ppo.types import ( +from mava.systems.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/anakin/ppo/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py similarity index 99% rename from mava/systems/anakin/ppo/rec_mappo.py rename to mava/systems/ppo/anakin/rec_mappo.py index 93736cf10..0afb3a6c2 100644 --- a/mava/systems/anakin/ppo/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.anakin.ppo.types import ( +from mava.systems.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py similarity index 100% rename from mava/systems/sebulba/ppo/ff_ippo.py rename to mava/systems/ppo/sebulba/ff_ippo.py diff --git a/mava/systems/anakin/ppo/types.py b/mava/systems/ppo/types.py similarity index 100% rename from mava/systems/anakin/ppo/types.py rename to mava/systems/ppo/types.py diff --git a/mava/systems/anakin/q_learning/__init__.py b/mava/systems/q_learning/__init__.py similarity index 100% rename from mava/systems/anakin/q_learning/__init__.py rename to mava/systems/q_learning/__init__.py diff --git a/mava/systems/anakin/q_learning/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py similarity index 99% rename from mava/systems/anakin/q_learning/rec_iql.py rename to mava/systems/q_learning/anakin/rec_iql.py index 89139277a..c4d31aade 100644 --- a/mava/systems/anakin/q_learning/rec_iql.py +++ b/mava/systems/q_learning/anakin/rec_iql.py @@ -34,7 +34,7 @@ from mava.evaluator import make_eval_fns from mava.networks import RecQNetwork, ScannedRNN -from mava.systems.anakin.q_learning.types import ( +from mava.systems.q_learning.types import ( ActionSelectionState, ActionState, LearnerState, diff 
--git a/mava/systems/anakin/q_learning/types.py b/mava/systems/q_learning/types.py similarity index 100% rename from mava/systems/anakin/q_learning/types.py rename to mava/systems/q_learning/types.py diff --git a/mava/systems/anakin/sac/__init__.py b/mava/systems/sac/__init__.py similarity index 100% rename from mava/systems/anakin/sac/__init__.py rename to mava/systems/sac/__init__.py diff --git a/mava/systems/anakin/sac/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py similarity index 99% rename from mava/systems/anakin/sac/ff_isac.py rename to mava/systems/sac/anakin/ff_isac.py index 1642176f3..d6963ab5c 100644 --- a/mava/systems/anakin/sac/ff_isac.py +++ b/mava/systems/sac/anakin/ff_isac.py @@ -34,7 +34,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardQNet as QNetwork -from mava.systems.anakin.sac.types import ( +from mava.systems.sac.types import ( BufferState, LearnerState, Metrics, diff --git a/mava/systems/anakin/sac/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py similarity index 99% rename from mava/systems/anakin/sac/ff_masac.py rename to mava/systems/sac/anakin/ff_masac.py index 2367a67a4..c256018e9 100644 --- a/mava/systems/anakin/sac/ff_masac.py +++ b/mava/systems/sac/anakin/ff_masac.py @@ -34,7 +34,7 @@ from mava.evaluator import make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardQNet as QNetwork -from mava.systems.anakin.sac.types import ( +from mava.systems.sac.types import ( BufferState, LearnerState, Metrics, diff --git a/mava/systems/anakin/sac/types.py b/mava/systems/sac/types.py similarity index 100% rename from mava/systems/anakin/sac/types.py rename to mava/systems/sac/types.py diff --git a/mava/utils/checkpointing.py b/mava/utils/checkpointing.py index 230c4938d..8955f76ce 100644 --- a/mava/utils/checkpointing.py +++ b/mava/utils/checkpointing.py @@ -24,7 +24,7 @@ from jax.tree_util import tree_map from omegaconf import DictConfig, OmegaConf -from mava.systems.anakin.ppo.types import HiddenStates, Params +from mava.systems.ppo.types import HiddenStates, Params from mava.types import MavaState # Keep track of the version of the checkpointer From d1f8364cd3a70cfa7bebdea6709044f1f770fc42 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 14:03:18 +0100 Subject: [PATCH 068/125] update the gym wrappers --- mava/configs/arch/anakin.yaml | 3 +- mava/configs/arch/sebulba.yaml | 8 +- mava/configs/default_ff_ippo.yaml | 2 +- mava/configs/env/lbf_gym.yaml | 19 ++ mava/configs/env/rware_gym.yaml | 19 ++ .../env/scenario/gym-lbf-10x10-3p-3f.yaml | 15 ++ .../env/scenario/gym-lbf-15x15-3p-5f.yaml | 15 ++ .../env/scenario/gym-lbf-15x15-4p-3f.yaml | 15 ++ .../env/scenario/gym-lbf-15x15-4p-5f.yaml | 15 ++ .../env/scenario/gym-lbf-2s-10x10-3p-3f.yaml | 15 ++ .../scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml | 15 ++ .../env/scenario/gym-lbf-8x8-2p-2f-coop.yaml | 15 ++ .../env/scenario/gym-rware-small-4ag.yaml | 18 ++ .../env/scenario/gym-rware-tiny-2ag.yaml | 18 ++ .../env/scenario/gym-rware-tiny-4ag-easy.yaml | 18 ++ .../env/scenario/gym-rware-tiny-4ag.yaml | 18 ++ mava/configs/system/ppo/ff_ippo.yaml | 6 +- mava/utils/logger.py | 5 +- mava/utils/make_env.py | 45 ++-- mava/wrappers/__init__.py | 4 +- mava/wrappers/gym.py | 242 ++++++++---------- requirements/requirements.txt | 4 +- 22 files changed, 362 insertions(+), 172 deletions(-) create mode 100644 mava/configs/env/lbf_gym.yaml create mode 100644 mava/configs/env/rware_gym.yaml 
create mode 100644 mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml create mode 100644 mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml create mode 100644 mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml create mode 100644 mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml create mode 100644 mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml create mode 100644 mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml create mode 100644 mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml create mode 100644 mava/configs/env/scenario/gym-rware-small-4ag.yaml create mode 100644 mava/configs/env/scenario/gym-rware-tiny-2ag.yaml create mode 100644 mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml create mode 100644 mava/configs/env/scenario/gym-rware-tiny-4ag.yaml diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index d58d85286..eb948b7a1 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,5 +1,6 @@ # --- Anakin config --- -arch_name: "Anakin" +architecture_name: anakin + # --- Training --- num_envs: 16 # Number of vectorised environments per device. diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index e0305e2dc..0b539059b 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -1,6 +1,8 @@ # --- Sebulba config --- -arch_name: "Sebulba" -num_envs: 32 # number of envs per thread +architecture_name: sebulba + +# --- Training --- +num_envs: 32 # number of environments per thread. # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select @@ -12,6 +14,6 @@ absolute_metric: True # Whether the absolute metric should be computed. For more # on the absolute metric please see: https://arxiv.org/abs/2209.10485 # --- Sebulba devices config --- -n_threads_per_executor: 2 # num of different threads/env batches per actor +n_threads_per_executor: 1 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices diff --git a/mava/configs/default_ff_ippo.yaml b/mava/configs/default_ff_ippo.yaml index d942584ce..c4aa6ea49 100644 --- a/mava/configs/default_ff_ippo.yaml +++ b/mava/configs/default_ff_ippo.yaml @@ -3,5 +3,5 @@ defaults: - arch: anakin - system: ppo/ff_ippo - network: mlp - - env: rware + - env: rware_gym - _self_ diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml new file mode 100644 index 000000000..6981f3492 --- /dev/null +++ b/mava/configs/env/lbf_gym.yaml @@ -0,0 +1,19 @@ +# ---Environment Configs--- +defaults: + - _self_ + - scenario: gym-2s-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] + +env_name: LevelBasedForaging # Used for logging purposes. + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the add agents IDs to the observations returned by the environment. +add_agent_id : False + +# Whether or not to log the winrate of this environment. +log_win_rate: False + +# Weather or not to sum the returned rewards over all of the agents. 
+use_shared_rewards: True diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml new file mode 100644 index 000000000..87bd3a473 --- /dev/null +++ b/mava/configs/env/rware_gym.yaml @@ -0,0 +1,19 @@ +# ---Environment Configs--- +defaults: + - _self_ + - scenario: gym-rware-tiny-2ag # [gym-rware-tiny-2ag, gym-rware-tiny-4ag, gym-rware-tiny-4ag-easy, gym-rware-small-4ag] + +env_name: RobotWarehouse # Used for logging purposes. + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the add agents IDs to the observations returned by the environment. +add_agent_id : False + +# Whether or not to log the winrate of this environment. +log_win_rate: False + +# Weather or not to sum the returned rewards over all of the agents. +use_shared_rewards: True diff --git a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml new file mode 100644 index 000000000..386431be4 --- /dev/null +++ b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml @@ -0,0 +1,15 @@ +# The config of the 10x10-3p-3f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 10x10-3p-3f + +task_config: + field_size: [10,10] + sight: 10 + num_agents: 3 + max_food: 3 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml new file mode 100644 index 000000000..1a8380511 --- /dev/null +++ b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml @@ -0,0 +1,15 @@ +# The config of the 15x15-3p-5f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 15x15-3p-5f + +task_config: + field_size: [15, 15] + sight: 15 + num_agents: 3 + max_food: 5 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml new file mode 100644 index 000000000..fa22f737b --- /dev/null +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml @@ -0,0 +1,15 @@ +# The config of the 15x15-4p-3f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 15x15-4p-3f + +task_config: + field_size: [15, 15] + sight: 15 + num_agents: 4 + max_food: 3 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml new file mode 100644 index 000000000..28937215c --- /dev/null +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml @@ -0,0 +1,15 @@ +# The config of the 15x15-4p-5f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 15x15-4p-5f + +task_config: + field_size: [15, 15] + sight: 15 + num_agents: 4 + max_food: 5 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml new file mode 100644 index 000000000..f0262eb8d --- 
/dev/null +++ b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml @@ -0,0 +1,15 @@ +# The config of the 2s10x10-3p-3f scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 2s-10x10-3p-3f + +task_config: + field_size: [10, 10] + sight: 2 + num_agents: 3 + max_food: 3 + max_player_level: 2 + force_coop: False + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml new file mode 100644 index 000000000..ffdc5be0e --- /dev/null +++ b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml @@ -0,0 +1,15 @@ +# The config of the 2s-8x8-2p-2f-coop scenario with the VectorObserver set as default. +name: LevelBasedForaging +task_name: 2s-8x8-2p-2f-coop + +task_config: + field_size: [8, 8] # size of the grid to generate. + sight: 2 # field of view of an agent. + num_agents: 2 # number of agents on the grid. + max_food: 2 # number of food in the environment. + max_player_level: 2 # maximum level of the agents (inclusive). + force_coop: True # force cooperation between agents. + max_episode_steps: 50 # max number of steps per episode. + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml new file mode 100644 index 000000000..52519fecb --- /dev/null +++ b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml @@ -0,0 +1,15 @@ +# The config of the 8x8-2p-2f-coop scenario with the VectorObserver set as default +name: LevelBasedForaging +task_name: 8x8-2p-2f-coop + +task_config: + field_size: [8, 8] + sight: 8 + num_agents: 2 + max_food: 2 + max_player_level: 2 + force_coop: True + max_episode_steps: 50 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-small-4ag.yaml b/mava/configs/env/scenario/gym-rware-small-4ag.yaml new file mode 100644 index 000000000..39f8efa4e --- /dev/null +++ b/mava/configs/env/scenario/gym-rware-small-4ag.yaml @@ -0,0 +1,18 @@ +# The config of the small-4ag environment +name: RobotWarehouse +task_name: small-4ag + +task_config: + column_height: 8 + shelf_rows: 2 + shelf_columns: 3 + n_agents: 4 + sensor_range: 1 + request_queue_size: 4 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml new file mode 100644 index 000000000..95ef11fc2 --- /dev/null +++ b/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml @@ -0,0 +1,18 @@ +# The config of the tiny-2ag environment +name: RobotWarehouse +task_name: tiny-2ag + +task_config: + column_height: 8 + shelf_rows: 1 + shelf_columns: 3 + n_agents: 2 + sensor_range: 1 + request_queue_size: 2 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml new file mode 100644 index 000000000..7753b73ec --- /dev/null +++ b/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml @@ -0,0 +1,18 @@ +# The config of the tiny-4ag-easy environment +name: RobotWarehouse +task_name: tiny-4ag-easy + 
+task_config: + column_height: 8 + shelf_rows: 1 + shelf_columns: 3 + n_agents: 4 + sensor_range: 1 + request_queue_size: 8 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml new file mode 100644 index 000000000..c28cf92c5 --- /dev/null +++ b/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml @@ -0,0 +1,18 @@ +# The config of the tiny_4ag environment +name: RobotWarehouse +task_name: tiny-4ag + +task_config: + column_height: 8 + shelf_rows: 1 + shelf_columns: 3 + n_agents: 4 + sensor_range: 1 + request_queue_size: 4 + msg_bits : 0 + max_inactivity_steps : null + max_steps : 500 + reward_type : 0 + +env_kwargs: + {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/system/ppo/ff_ippo.yaml b/mava/configs/system/ppo/ff_ippo.yaml index c80b43ec8..9efb0611a 100644 --- a/mava/configs/system/ppo/ff_ippo.yaml +++ b/mava/configs/system/ppo/ff_ippo.yaml @@ -9,12 +9,12 @@ seed: 42 add_agent_id: True # --- RL hyperparameters --- -actor_lr: 0.0005 # Learning rate for actor network -critic_lr: 0.0005 # Learning rate for critic network +actor_lr: 2.5e-4 # Learning rate for actor network +critic_lr: 2.5e-4 # Learning rate for critic network update_batch_size: 2 # Number of vectorised gradient updates per device. rollout_length: 128 # Number of environment steps per vectorised environment. ppo_epochs: 4 # Number of ppo epochs per training data batch. -num_minibatches: 1 # Number of minibatches per ppo epoch. +num_minibatches: 2 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. gae_lambda: 0.95 # Lambda value for GAE computation. clip_eps: 0.2 # Clipping value for PPO updates and value function. 
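The hyperparameters above feed the standard PPO machinery: gamma and gae_lambda parameterise generalised advantage estimation, while clip_eps bounds the policy-ratio update. For reference only, a minimal NumPy sketch of GAE with these defaults; this is not the ff_ippo implementation, and the flat (rollout_length,) shapes are an assumption made for clarity:

    import numpy as np

    def gae_sketch(rewards, values, last_value, dones, gamma=0.99, gae_lambda=0.95):
        # rewards, values, dones: float arrays of shape (rollout_length,).
        advantages = np.zeros_like(rewards)
        running_gae = 0.0
        next_value = last_value  # bootstrap value for the step after the rollout
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + gamma * next_value * not_done - values[t]
            running_gae = delta + gamma * gae_lambda * not_done * running_gae
            advantages[t] = running_gae
            next_value = values[t]
        return advantages, advantages + values  # advantages and value targets

With clip_eps = 0.2, the policy loss then clips the importance ratio to [0.8, 1.2] before taking the elementwise minimum with the unclipped surrogate.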
diff --git a/mava/utils/logger.py b/mava/utils/logger.py index dc217f263..4edad361e 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -150,9 +150,8 @@ class NeptuneLogger(BaseLogger): def __init__(self, cfg: DictConfig, unique_token: str) -> None: tags = list(cfg.logger.kwargs.neptune_tag) project = cfg.logger.kwargs.neptune_project - mode = "sync" if cfg.arch.arch_name == "Sebulba" else "async" - self.logger = neptune.init_run(project=project, tags=tags, mode=mode) + self.logger = neptune.init_run(project=project, tags=tags) self.logger["config"] = stringify_unsupported(cfg) self.detailed_logging = cfg.logger.kwargs.detailed_neptune_logging @@ -338,7 +337,7 @@ def get_logger_path(config: DictConfig, logger_type: str) -> str: def describe(x: ArrayLike) -> Union[Dict[str, ArrayLike], ArrayLike]: """Generate summary statistics for an array of metrics (mean, std, min, max).""" - if not (isinstance(x, jax.Array) or isinstance(x, np.ndarray)) or x.size <= 1: + if not isinstance(x, jax.Array) or x.size <= 1: return x # np instead of jnp because we don't jit here diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 9828573e0..e49d6344b 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -14,10 +14,10 @@ from typing import Tuple -import gym -import gym.vector -import gym.wrappers -import gym.wrappers.compatibility +import gymnasium +import gymnasium.vector +import gymnasium.wrappers +import gymnasium.wrappers.compatibility import jaxmarl import jumanji import matrax @@ -36,7 +36,9 @@ from jumanji.environments.routing.robot_warehouse.generator import ( RandomGenerator as RwareRandomGenerator, ) +from lbforaging.foraging import environment as gym_lbf from omegaconf import DictConfig +from rware.warehouse import Warehouse as gym_Warehouse from mava.wrappers import ( AgentIDWrapper, @@ -47,14 +49,14 @@ GymAgentIDWrapper, GymLBFWrapper, GymRecordEpisodeMetrics, - GymRwareWrapper, + GymWrapper, LbfWrapper, MabraxWrapper, MatraxWrapper, RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, - _multiagent_worker_shared_memory, + async_multiagent_worker, ) # Registry mapping environment names to their generator and wrapper classes. @@ -72,7 +74,10 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} -_gym_registry = {"RobotWarehouse": GymRwareWrapper, "LevelBasedForaging": GymLBFWrapper} +_gym_registry = { + "RobotWarehouse": (gym_Warehouse, GymWrapper), + "LevelBasedForaging": (gym_lbf, GymLBFWrapper), +} def add_extra_wrappers( @@ -214,9 +219,9 @@ def make_gym_env( config: DictConfig, num_env: int, add_global_state: bool = False, -) -> Environment: # todo : create the appropriate annotation for the sync vector +) -> gymnasium.vector.AsyncVectorEnv: """ - Create a Gym environment. + Create a gymnasium environment. Args: config (Dict): The configuration of the environment. @@ -226,22 +231,20 @@ def make_gym_env( Returns: Async environments. """ - base_env_name = config.env.env_name - wrapper = _gym_registry[base_env_name] - - def create_gym_env( - config: DictConfig, add_global_state: bool = False - ) -> Environment: # todo: add the RecordEpisodeMetrics for gym. - env = gym.make(config.env.scenario) - wrapped_env = wrapper(env, config.env.use_individual_rewards, add_global_state) - if not config.env.implicit_agent_id: - wrapped_env = GymAgentIDWrapper(wrapped_env) # todo : add agent id wrapper for gym . 
+ base_env_name = config.env.scenario.name + env_maker, wrapper = _gym_registry[base_env_name] + + def create_gym_env(config: DictConfig, add_global_state: bool = False) -> Environment: + env = env_maker(**config.env.scenario.task_config) + wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) + if config.env.add_agent_id: + wrapped_env = GymAgentIDWrapper(wrapped_env) wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env - envs = gym.vector.AsyncVectorEnv( # todo : give them more descriptive names + envs = gymnasium.vector.AsyncVectorEnv( [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], - worker=_multiagent_worker_shared_memory, + worker=async_multiagent_worker, ) return envs diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 869e78053..80cbccc52 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -19,8 +19,8 @@ GymAgentIDWrapper, GymLBFWrapper, GymRecordEpisodeMetrics, - GymRwareWrapper, - _multiagent_worker_shared_memory, + GymWrapper, + async_multiagent_worker, ) from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper from mava.wrappers.jumanji import ( diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index a9bc5af8e..18d3ede73 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -13,46 +13,50 @@ # limitations under the License. import sys +import traceback import warnings -from typing import Any, Callable, Dict, Optional, Tuple +from multiprocessing import Queue +from multiprocessing.connection import Connection +from typing import Any, Callable, Dict, Optional, Tuple, Union -import gym +import gymnasium import numpy as np -from gym import spaces -from gym.vector.utils import write_to_shared_memory +from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray # Filter out the warnings -warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker") +warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") -class GymRwareWrapper(gym.Wrapper): - """Wrapper for rware gym environments.""" +class GymWrapper(gymnasium.Wrapper): + """Base wrapper for multi-agent gym environments. + This wrapper works out of the box for RobotWarehouse. + See `GymLBFWrapper` for how it can be modified to work for other environments. + """ def __init__( self, - env: gym.Env, - use_individual_rewards: bool = False, + env: gymnasium.Env, + use_shared_rewards: bool = True, add_global_state: bool = False, ): - """Initialize the gym wrapper + """Initialise the gym wrapper Args: - env (gym.env): gym env instance. - use_individual_rewards (bool, optional): Use individual or group rewards. + env (gymnasium.env): gymnasium env instance. + use_shared_rewards (bool, optional): Use individual or shared rewards. Defaults to False. add_global_state (bool, optional) : Create global observations. Defaults to False. 
""" super().__init__(env) self._env = env - self.use_individual_rewards = use_individual_rewards + self.use_shared_rewards = use_shared_rewards self.add_global_state = add_global_state self.num_agents = len(self._env.action_space) self.num_actions = self._env.action_space[0].n def reset( self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[np.ndarray, Dict]: - + ) -> Tuple[NDArray, Dict]: if seed is not None: self.env.seed(seed) @@ -64,18 +68,17 @@ def reset( return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple: - + def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: agents_view, reward, terminated, truncated, info = self._env.step(actions) info = {"actions_mask": self.get_actions_mask(info)} if self.add_global_state: info["global_obs"] = self.get_global_obs(agents_view) - if self.use_individual_rewards: - reward = np.array(reward) + if self.use_shared_rewards: + reward = np.array([np.array(reward).sum()] * self.num_agents) else: - reward = np.array([np.array(reward).mean()] * self.num_agents) + reward = np.array(reward) return agents_view, reward, terminated, truncated, info @@ -89,68 +92,22 @@ def get_global_obs(self, obs: NDArray) -> NDArray: return np.tile(global_obs, (self.num_agents, 1)) -class GymLBFWrapper(gym.Wrapper): - """Wrapper for rware gym environments""" +class GymLBFWrapper(GymWrapper): + """Wrapper for the gym level based foraging environment.""" - def __init__( - self, - env: gym.Env, - use_individual_rewards: bool = False, - add_global_state: bool = False, - ): - """Initialize the gym wrapper - Args: - env (gym.env): gym env instance. - use_individual_rewards (bool, optional): Use individual or group rewards. - Defaults to False. - add_global_state (bool, optional) : Create global observations. Defaults to False. - """ - super().__init__(env) - self._env = env # not having _env leaded tp self.env getting replaced --> circular called - self.use_individual_rewards = use_individual_rewards - self.add_global_state = add_global_state # todo : add the global observations - self.num_agents = len(self._env.action_space) - self.num_actions = self._env.action_space[ - 0 - ].n # todo: all the agents must have the same num_actions, add assertion? 
- - def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple: - - if seed is not None: - self.env.seed(seed) - - agents_view, info = self._env.reset() + def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + agents_view, reward, terminated, truncated, info = super().step(actions) - info = {"actions_mask": self.get_actions_mask(info)} - - return np.array(agents_view), info - - def step(self, actions: NDArray) -> Tuple: # Vect auto rest - - agents_view, reward, terminated, truncated, info = self._env.step(actions) - - info = {"actions_mask": self.get_actions_mask(info)} - - if self.use_individual_rewards: - reward = np.array(reward) - else: - reward = np.array([np.array(reward).sum()] * self.num_agents) - - truncated = [truncated] * self.num_agents - terminated = [terminated] * self.num_agents + truncated = np.repeat(truncated, self.num_agents) + terminated = np.repeat(terminated, self.num_agents) return agents_view, reward, terminated, truncated, info - def get_actions_mask(self, info: Dict) -> NDArray: - if "action_mask" in info: - return np.array(info["action_mask"]) - return np.ones((self.num_agents, self.num_actions), dtype=np.float32) - -class GymRecordEpisodeMetrics(gym.Wrapper): +class GymRecordEpisodeMetrics(gymnasium.Wrapper): """Record the episode returns and lengths.""" - def __init__(self, env: gym.Env): + def __init__(self, env: gymnasium.Env): super().__init__(env) self._env = env self.running_count_episode_return = 0.0 @@ -158,9 +115,7 @@ def __init__(self, env: gym.Env): def reset( self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[np.ndarray, Dict]: - - # Reset the env + ) -> Tuple[NDArray, Dict]: agents_view, info = self._env.reset(seed, options) # Create the metrics dict @@ -181,9 +136,7 @@ def reset( return agents_view, info - def step(self, actions: NDArray) -> Tuple: - - # Step the env + def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: agents_view, reward, terminated, truncated, info = self._env.step(actions) self.running_count_episode_return += float(np.mean(reward)) @@ -202,111 +155,126 @@ def step(self, actions: NDArray) -> Tuple: return agents_view, reward, terminated, truncated, info -class GymAgentIDWrapper(gym.Wrapper): +class GymAgentIDWrapper(gymnasium.Wrapper): """Add one hot agent IDs to observation.""" - def __init__(self, env: gym.Env): + def __init__(self, env: gymnasium.Env): super().__init__(env) self.agent_ids = np.eye(self.env.num_agents) - observation_space = self.env.observation_space[0] - _obs_low, _obs_high, _obs_dtype, _obs_shape = ( - observation_space.low[0], - observation_space.high[0], - observation_space.dtype, - observation_space.shape, - ) - _new_obs_shape = (_obs_shape[0] + self.env.num_agents,) - _observation_boxs = [ - spaces.Box(low=_obs_low, high=_obs_high, shape=_new_obs_shape, dtype=_obs_dtype) - ] * self.env.num_agents - self.observation_space = spaces.Tuple(_observation_boxs) + self.observation_space = self.modify_space(self.env.observation_space) def reset( self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[np.ndarray, Dict]: + ) -> Tuple[NDArray, Dict]: """Reset the environment.""" obs, info = self.env.reset(seed, options) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, info - def step(self, action: list) -> Tuple[np.ndarray, float, bool, bool, Dict]: + def step(self, action: list) -> Tuple[NDArray, float, bool, bool, Dict]: """Step the environment.""" obs, reward, 
terminated, truncated, info = self.env.step(action) obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, reward, terminated, truncated, info + def modify_space(self, space: gymnasium.spaces) -> gymnasium.spaces: + if isinstance(space, gymnasium.spaces.Box): + new_shape = space.shape[0] + len(self.agent_ids) + return gymnasium.spaces.Box( + low=space.low, high=space.high, shape=new_shape, dtype=space.dtype + ) + elif isinstance(space, gymnasium.spaces.Tuple): + return gymnasium.spaces.Tuple(self.modify_space(s) for s in space) + else: + raise ValueError(f"Space {type(space)} is not currently supported.") + -# Copied form https://github.com/openai/gym/blob/master/gym/vector/async_vector_env.py +# Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents -def _multiagent_worker_shared_memory( # noqa: CCR001 +def async_multiagent_worker( # noqa CCR001 index: int, - env_fn: Callable[[], Any], - pipe: Any, - parent_pipe: Any, - shared_memory: Any, - error_queue: Any, + env_fn: Callable, + pipe: Connection, + parent_pipe: Connection, + shared_memory: Union[NDArray, dict[str, Any], tuple[Any, ...]], + error_queue: Queue, ) -> None: - assert shared_memory is not None env = env_fn() observation_space = env.observation_space + action_space = env.action_space + autoreset = False + parent_pipe.close() + try: while True: command, data = pipe.recv() + if command == "reset": observation, info = env.reset(**data) - write_to_shared_memory(observation_space, index, observation, shared_memory) - pipe.send(((None, info), True)) - + if shared_memory: + write_to_shared_memory(observation_space, index, observation, shared_memory) + observation = None + autoreset = False + pipe.send(((observation, info), True)) elif command == "step": - ( - observation, - reward, - terminated, - truncated, - info, - ) = env.step(data) - # Handel the dones across all of envs and agents - if np.logical_or(terminated, truncated).all(): - old_observation, old_info = observation, info + if autoreset: observation, info = env.reset() - info["final_observation"] = old_observation - info["final_info"] = old_info - write_to_shared_memory(observation_space, index, observation, shared_memory) - pipe.send(((None, reward, terminated, truncated, info), True)) - elif command == "seed": - env.seed(data) - pipe.send((None, True)) + reward, terminated, truncated = 0, False, False + else: + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + # The autoreset was modified to work with boolean arrays. + autoreset = np.logical_or(terminated, truncated).all() + + if shared_memory: + write_to_shared_memory(observation_space, index, observation, shared_memory) + observation = None + + pipe.send(((observation, reward, terminated, truncated, info), True)) elif command == "close": pipe.send((None, True)) break elif command == "_call": name, args, kwargs = data - if name in ["reset", "step", "seed", "close"]: + if name in ["reset", "step", "close", "_setattr", "_check_spaces"]: raise ValueError( - f"Trying to call function `{name}` with " - f"`_call`. Use `{name}` directly instead." + f"Trying to call function `{name}` with \ + `call`, use `{name}` directly instead." 
) - function = getattr(env, name) - if callable(function): - pipe.send((function(*args, **kwargs), True)) + + attr = env.get_wrapper_attr(name) + if callable(attr): + pipe.send((attr(*args, **kwargs), True)) else: - pipe.send((function, True)) + pipe.send((attr, True)) elif command == "_setattr": name, value = data - setattr(env, name, value) + env.set_wrapper_attr(name, value) pipe.send((None, True)) elif command == "_check_spaces": - pipe.send(((data[0] == observation_space, data[1] == env.action_space), True)) + pipe.send( + ( + (data[0] == observation_space, data[1] == action_space), + True, + ) + ) else: raise RuntimeError( - f"Received unknown command `{command}`. Must " - "be one of {`reset`, `step`, `seed`, `close`, `_call`, " - "`_setattr`, `_check_spaces`}." + f"Received unknown command `{command}`. Must be one of \ + [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." ) except (KeyboardInterrupt, Exception): - error_queue.put((index,) + sys.exc_info()[:2]) + error_type, error_message, _ = sys.exc_info() + trace = traceback.format_exc() + + error_queue.put((index, error_type, error_message, trace)) pipe.send((None, False)) finally: env.close() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 3b3bc4c58..0c68a3ca5 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,12 +3,14 @@ distrax @ git+https://github.com/google-deepmind/distrax # distrax release does flashbax~=0.1.0 flax gigastep @ git+https://github.com/mlech26l/gigastep +gymnasium hydra-core==1.3.2 id-marl-eval @ git+https://github.com/instadeepai/marl-eval jax jaxlib jaxmarl jumanji @ git+https://github.com/sash-a/jumanji +lbforaging @ git+https://github.com/LukasSchaefer/lb-foraging.git@gymnasium_integration # fixes: https://github.com/semitable/lb-foraging/issues/20 matrax @ git+https://github.com/instadeepai/matrax mujoco==3.1.3 mujoco-mjx==3.1.3 @@ -17,7 +19,7 @@ numpy omegaconf optax protobuf~=3.20 -rware @ git+https://github.com/RuanJohn/robotic-warehouse.git +rware @ git+https://github.com/RuanJohn/robotic-warehouse.git # compatibility with latest gymnasium scipy==1.12.0 tensorboard_logger tensorflow_probability From dc641c6a6f2f16042304de47e00ba8523b7ce59b Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 14:48:56 +0100 Subject: [PATCH 069/125] folder re-structuring --- mava/configs/default_ff_ippo_seb.yaml | 2 +- mava/configs/env/gym.yaml | 21 ------------------- mava/systems/anakin/sac/__init__.py | 13 ------------ mava/systems/{anakin => ppo}/__init__.py | 0 .../{anakin/ppo => ppo/anakin}/ff_ippo.py | 2 +- .../{anakin/ppo => ppo/anakin}/ff_mappo.py | 2 +- .../{anakin/ppo => ppo/anakin}/rec_ippo.py | 2 +- .../{anakin/ppo => ppo/anakin}/rec_mappo.py | 2 +- .../{sebulba/ppo => ppo/sebulba}/ff_ippo.py | 4 ++-- mava/systems/{anakin => }/ppo/types.py | 0 .../{anakin/ppo => q_learning}/__init__.py | 0 .../anakin}/rec_iql.py | 0 mava/systems/{anakin => }/q_learning/types.py | 0 .../{anakin/q_learning => sac}/__init__.py | 0 .../{anakin/sac => sac/anakin}/ff_isac.py | 0 .../{anakin/sac => sac/anakin}/ff_masac.py | 0 mava/systems/{anakin => }/sac/types.py | 0 mava/utils/checkpointing.py | 2 +- 18 files changed, 8 insertions(+), 42 deletions(-) delete mode 100644 mava/configs/env/gym.yaml delete mode 100644 mava/systems/anakin/sac/__init__.py rename mava/systems/{anakin => ppo}/__init__.py (100%) rename mava/systems/{anakin/ppo => ppo/anakin}/ff_ippo.py (99%) rename mava/systems/{anakin/ppo => ppo/anakin}/ff_mappo.py (99%) 
rename mava/systems/{anakin/ppo => ppo/anakin}/rec_ippo.py (99%) rename mava/systems/{anakin/ppo => ppo/anakin}/rec_mappo.py (99%) rename mava/systems/{sebulba/ppo => ppo/sebulba}/ff_ippo.py (99%) rename mava/systems/{anakin => }/ppo/types.py (100%) rename mava/systems/{anakin/ppo => q_learning}/__init__.py (100%) rename mava/systems/{anakin/q_learning => q_learning/anakin}/rec_iql.py (100%) rename mava/systems/{anakin => }/q_learning/types.py (100%) rename mava/systems/{anakin/q_learning => sac}/__init__.py (100%) rename mava/systems/{anakin/sac => sac/anakin}/ff_isac.py (100%) rename mava/systems/{anakin/sac => sac/anakin}/ff_masac.py (100%) rename mava/systems/{anakin => }/sac/types.py (100%) diff --git a/mava/configs/default_ff_ippo_seb.yaml b/mava/configs/default_ff_ippo_seb.yaml index 1002d90c4..204719232 100644 --- a/mava/configs/default_ff_ippo_seb.yaml +++ b/mava/configs/default_ff_ippo_seb.yaml @@ -3,5 +3,5 @@ defaults: - arch: sebulba - system: ppo/ff_ippo - network: mlp - - env: gym + - env: rware_gym - _self_ diff --git a/mava/configs/env/gym.yaml b/mava/configs/env/gym.yaml deleted file mode 100644 index 9ddd16d41..000000000 --- a/mava/configs/env/gym.yaml +++ /dev/null @@ -1,21 +0,0 @@ -# ---Environment Configs--- - -scenario: rware:rware-tiny-4ag-v1 #Foraging-8x8-2p-1f-v2 #rware:rware-tiny-2ag-v1 # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag] - -env_name: RobotWarehouse #LevelBasedForaging # Used for logging purposes. - -# Defines the metric that will be used to evaluate the performance of the agent. -# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. -eval_metric: episode_return - -# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. -# This should not be changed. -implicit_agent_id: False -# Whether or not to log the winrate of this environment. This should not be changed as not all -# environments have a winrate metric. -log_win_rate: False - -use_individual_rewards: True - -kwargs: - time_limit: 500 diff --git a/mava/systems/anakin/sac/__init__.py b/mava/systems/anakin/sac/__init__.py deleted file mode 100644 index 21db9ec1c..000000000 --- a/mava/systems/anakin/sac/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
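The renames that follow invert the package layout: code is now grouped by algorithm first and architecture second, with shared type definitions hoisted to the algorithm level. Purely as an illustration of the resulting import paths (this snippet is not itself part of the patch):

    # Before this commit: architecture first, algorithm second.
    # from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition

    # After this commit: algorithm first; anakin/ and sebulba/ hold only the
    # architecture-specific entry points, while types.py is shared between them.
    from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition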
diff --git a/mava/systems/anakin/__init__.py b/mava/systems/ppo/__init__.py similarity index 100% rename from mava/systems/anakin/__init__.py rename to mava/systems/ppo/__init__.py diff --git a/mava/systems/anakin/ppo/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py similarity index 99% rename from mava/systems/anakin/ppo/ff_ippo.py rename to mava/systems/ppo/anakin/ff_ippo.py index 408bdf36d..7c93f887d 100644 --- a/mava/systems/anakin/ppo/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -32,7 +32,7 @@ from mava.evaluator import make_anakin_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/anakin/ppo/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py similarity index 99% rename from mava/systems/anakin/ppo/ff_mappo.py rename to mava/systems/ppo/anakin/ff_mappo.py index 93d3f2c0b..17a5cbfcf 100644 --- a/mava/systems/anakin/ppo/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -31,7 +31,7 @@ from mava.evaluator import make_anakin_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer diff --git a/mava/systems/anakin/ppo/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py similarity index 99% rename from mava/systems/anakin/ppo/rec_ippo.py rename to mava/systems/ppo/anakin/rec_ippo.py index 583cd7acc..75f751dd1 100644 --- a/mava/systems/anakin/ppo/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.anakin.ppo.types import ( +from mava.systems.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/anakin/ppo/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py similarity index 99% rename from mava/systems/anakin/ppo/rec_mappo.py rename to mava/systems/ppo/anakin/rec_mappo.py index 74179ab34..3534b96b8 100644 --- a/mava/systems/anakin/ppo/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -33,7 +33,7 @@ from mava.networks import RecurrentActor as Actor from mava.networks import RecurrentValueNet as Critic from mava.networks import ScannedRNN -from mava.systems.anakin.ppo.types import ( +from mava.systems.ppo.types import ( HiddenStates, OptStates, Params, diff --git a/mava/systems/sebulba/ppo/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py similarity index 99% rename from mava/systems/sebulba/ppo/ff_ippo.py rename to mava/systems/ppo/sebulba/ff_ippo.py index 42d2732ae..316ef0533 100644 --- a/mava/systems/sebulba/ppo/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -36,7 +36,7 @@ from mava.evaluator import make_sebulba_eval_fns as make_eval_fns from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet 
as Critic -from mava.systems.anakin.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition from mava.types import ( ActorApply, CriticApply, @@ -479,7 +479,7 @@ def learner_setup( # Get number of agents and actions. action_space = env.single_action_space config.system.num_agents = len(action_space) - config.system.num_actions = action_space[0].n + config.system.num_actions = int(action_space[0].n) # PRNG keys. key, actor_net_key, critic_net_key = keys diff --git a/mava/systems/anakin/ppo/types.py b/mava/systems/ppo/types.py similarity index 100% rename from mava/systems/anakin/ppo/types.py rename to mava/systems/ppo/types.py diff --git a/mava/systems/anakin/ppo/__init__.py b/mava/systems/q_learning/__init__.py similarity index 100% rename from mava/systems/anakin/ppo/__init__.py rename to mava/systems/q_learning/__init__.py diff --git a/mava/systems/anakin/q_learning/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py similarity index 100% rename from mava/systems/anakin/q_learning/rec_iql.py rename to mava/systems/q_learning/anakin/rec_iql.py diff --git a/mava/systems/anakin/q_learning/types.py b/mava/systems/q_learning/types.py similarity index 100% rename from mava/systems/anakin/q_learning/types.py rename to mava/systems/q_learning/types.py diff --git a/mava/systems/anakin/q_learning/__init__.py b/mava/systems/sac/__init__.py similarity index 100% rename from mava/systems/anakin/q_learning/__init__.py rename to mava/systems/sac/__init__.py diff --git a/mava/systems/anakin/sac/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py similarity index 100% rename from mava/systems/anakin/sac/ff_isac.py rename to mava/systems/sac/anakin/ff_isac.py diff --git a/mava/systems/anakin/sac/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py similarity index 100% rename from mava/systems/anakin/sac/ff_masac.py rename to mava/systems/sac/anakin/ff_masac.py diff --git a/mava/systems/anakin/sac/types.py b/mava/systems/sac/types.py similarity index 100% rename from mava/systems/anakin/sac/types.py rename to mava/systems/sac/types.py diff --git a/mava/utils/checkpointing.py b/mava/utils/checkpointing.py index 230c4938d..8955f76ce 100644 --- a/mava/utils/checkpointing.py +++ b/mava/utils/checkpointing.py @@ -24,7 +24,7 @@ from jax.tree_util import tree_map from omegaconf import DictConfig, OmegaConf -from mava.systems.anakin.ppo.types import HiddenStates, Params +from mava.systems.ppo.types import HiddenStates, Params from mava.types import MavaState # Keep track of the version of the checkpointer From 0881d2f1ae12ee3e686dbdf7e53ed7d1cc209ce8 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 16:51:02 +0100 Subject: [PATCH 070/125] fix: removed deprecated jax call --- mava/systems/ppo/sebulba/ff_ippo.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 316ef0533..288249af5 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -148,7 +148,7 @@ def get_action_and_value( # Prepare the data storage_time_start = time.time() next_dones = np.logical_or(terminated, truncated) - metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) # Stack the metrics + metrics = jax.tree_util.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) # Stack the metrics # Append data to storage storage.append( @@ -170,11 +170,11 @@ def get_action_and_value( # Prepare 
data to share with learner # [PPOTransition() * rollout_len] --> PPOTransition[done=(rollout_len, num_envs, num_agents) # , action=(rollout_len, num_envs, num_agents, num_actions), ...] - stacked_storage = jax.tree_map(lambda *xs: jnp.stack(xs), *storage) + stacked_storage = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *storage) # Split the arrays over the different learner_devices on the num_envs axis - sharded_storage = jax.tree_map( + sharded_storage = jax.tree_util.tree_map( lambda x: shard_split_payload(x, 1), stacked_storage ) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) @@ -700,10 +700,10 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 rollout_times.append(time.time() - rollout_start_time) # Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_map( + sharded_storages = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=2), *sharded_storages ) - sharded_next_obss = jax.tree_map( + sharded_next_obss = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=1), *sharded_next_obss ) sharded_next_dones = jnp.concatenate(sharded_next_dones, axis=1) @@ -730,7 +730,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 # Log the results of the training. elapsed_time = time.time() - training_start_time t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics = jax.tree_map(lambda *x: np.asarray(x), *episode_metrics) + episode_metrics = jax.tree_util.tree_map(lambda *x: np.asarray(x), *episode_metrics) episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time @@ -744,7 +744,7 @@ def run_experiment(_config: DictConfig) -> float: # noqa: CCR001 logger.log(speed_info, t, eval_step, LogEvent.MISC) if ep_completed: # only log episode metrics if an episode was completed in the rollout. 
logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - train_metrics = jax.tree_map(lambda *x: np.asarray(x), *train_metrics) + train_metrics = jax.tree_util.tree_map(lambda *x: np.asarray(x), *train_metrics) logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) # Evaluation on the learner From b60cefe8e93797f47d66bf3ff23daadf934f5a9e Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 16:51:50 +0100 Subject: [PATCH 071/125] fix: env wrappers fix --- mava/utils/make_env.py | 4 ++-- mava/wrappers/gym.py | 22 ++++++++++------------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index e49d6344b..5755cc03c 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -36,7 +36,7 @@ from jumanji.environments.routing.robot_warehouse.generator import ( RandomGenerator as RwareRandomGenerator, ) -from lbforaging.foraging import environment as gym_lbf +from lbforaging.foraging import ForagingEnv as gym_ForagingEnv from omegaconf import DictConfig from rware.warehouse import Warehouse as gym_Warehouse @@ -76,7 +76,7 @@ _gym_registry = { "RobotWarehouse": (gym_Warehouse, GymWrapper), - "LevelBasedForaging": (gym_lbf, GymLBFWrapper), + "LevelBasedForaging": (gym_ForagingEnv, GymLBFWrapper), } diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 18d3ede73..35f3d2335 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -219,19 +219,17 @@ def async_multiagent_worker( # noqa CCR001 autoreset = False pipe.send(((observation, info), True)) elif command == "step": - if autoreset: + # Modified the step function to align with 'AutoResetWrapper'. + # The environment resets immediately upon termination or truncation. + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + if np.logical_or(terminated, truncated).all(): observation, info = env.reset() - reward, terminated, truncated = 0, False, False - else: - ( - observation, - reward, - terminated, - truncated, - info, - ) = env.step(data) - # The autoreset was modified to work with boolean arrays. 
- autoreset = np.logical_or(terminated, truncated).all() if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory) From 21aafbffdc1740e99d9ad703e8adc4b5bb3cc8ef Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 16:53:02 +0100 Subject: [PATCH 072/125] fix: config changes --- mava/configs/env/lbf_gym.yaml | 2 +- mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml | 7 +++++-- mava/utils/logger.py | 2 +- 9 files changed, 37 insertions(+), 16 deletions(-) diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index 6981f3492..b0d783a7e 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -1,7 +1,7 @@ # ---Environment Configs--- defaults: - _self_ - - scenario: gym-2s-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] + - scenario: gym-lbf-2s-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] env_name: LevelBasedForaging # Used for logging purposes. diff --git a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml index 386431be4..904d94197 100644 --- a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml @@ -5,11 +5,14 @@ task_name: 10x10-3p-3f task_config: field_size: [10,10] sight: 10 - num_agents: 3 - max_food: 3 + players: 3 + max_num_food: 3 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml index 1a8380511..6b24e8de8 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml @@ -5,11 +5,14 @@ task_name: 15x15-3p-5f task_config: field_size: [15, 15] sight: 15 - num_agents: 3 - max_food: 5 + players: 3 + max_num_food: 5 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml index fa22f737b..acbb1f6de 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml @@ -5,11 +5,14 @@ task_name: 15x15-4p-3f task_config: field_size: [15, 15] sight: 15 - num_agents: 4 - max_food: 3 + players: 4 + max_num_food: 3 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml 
b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml index 28937215c..465385909 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml @@ -5,11 +5,14 @@ task_name: 15x15-4p-5f task_config: field_size: [15, 15] sight: 15 - num_agents: 4 - max_food: 5 + players: 4 + max_num_food: 5 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml index f0262eb8d..e6af1860f 100644 --- a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml @@ -5,11 +5,14 @@ task_name: 2s-10x10-3p-3f task_config: field_size: [10, 10] sight: 2 - num_agents: 3 - max_food: 3 + players: 3 + max_num_food: 3 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml index ffdc5be0e..3c318d3cf 100644 --- a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml +++ b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml @@ -5,11 +5,14 @@ task_name: 2s-8x8-2p-2f-coop task_config: field_size: [8, 8] # size of the grid to generate. sight: 2 # field of view of an agent. - num_agents: 2 # number of agents on the grid. - max_food: 2 # number of food in the environment. + players: 2 # number of agents on the grid. + max_num_food: 2 # number of food in the environment. max_player_level: 2 # maximum level of the agents (inclusive). force_coop: True # force cooperation between agents. max_episode_steps: 50 # max number of steps per episode. + min_player_level : 1 # minimum level of the agents (inclusive). 
+ min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml index 52519fecb..308b891dd 100644 --- a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml +++ b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml @@ -5,11 +5,14 @@ task_name: 8x8-2p-2f-coop task_config: field_size: [8, 8] sight: 8 - num_agents: 2 - max_food: 2 + players: 2 + max_num_food: 2 max_player_level: 2 force_coop: True max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/utils/logger.py b/mava/utils/logger.py index 4edad361e..1416c6061 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -337,7 +337,7 @@ def get_logger_path(config: DictConfig, logger_type: str) -> str: def describe(x: ArrayLike) -> Union[Dict[str, ArrayLike], ArrayLike]: """Generate summary statistics for an array of metrics (mean, std, min, max).""" - if not isinstance(x, jax.Array) or x.size <= 1: + if not isinstance(x, (jax.Array, np.ndarray)) or x.size <= 1: return x # np instead of jnp because we don't jit here From e09fd60f226f3de52ff4da949b7a53e069e9de21 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 16:55:06 +0100 Subject: [PATCH 073/125] chore: pre-commits --- mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 4 +++- mava/wrappers/gym.py | 2 -- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml index 904d94197..3aceaf74f 100644 --- a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml @@ -10,7 +10,7 @@ task_config: max_player_level: 2 force_coop: False max_episode_steps: 50 - min_player_level : 1 + min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml index 6b24e8de8..14953f3fc 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml @@ -10,7 +10,7 @@ task_config: max_player_level: 2 force_coop: False max_episode_steps: 50 - min_player_level : 1 + min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml index acbb1f6de..ef678025b 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml @@ -10,7 +10,7 @@ task_config: max_player_level: 2 force_coop: False max_episode_steps: 50 - min_player_level : 1 + min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml index 465385909..c4dcfb979 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml 
@@ -10,7 +10,7 @@ task_config: max_player_level: 2 force_coop: False max_episode_steps: 50 - min_player_level : 1 + min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml index e6af1860f..b094cda72 100644 --- a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml @@ -10,7 +10,7 @@ task_config: max_player_level: 2 force_coop: False max_episode_steps: 50 - min_player_level : 1 + min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml index 308b891dd..840bbf9f4 100644 --- a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml +++ b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml @@ -10,7 +10,7 @@ task_config: max_player_level: 2 force_coop: True max_episode_steps: 50 - min_player_level : 1 + min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 288249af5..0fe20165e 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -148,7 +148,9 @@ def get_action_and_value( # Prepare the data storage_time_start = time.time() next_dones = np.logical_or(terminated, truncated) - metrics = jax.tree_util.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) # Stack the metrics + metrics = jax.tree_util.tree_map( + lambda *x: jnp.asarray(x), *info["metrics"] + ) # Stack the metrics # Append data to storage storage.append( diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 35f3d2335..7ecfb4b27 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -203,7 +203,6 @@ def async_multiagent_worker( # noqa CCR001 env = env_fn() observation_space = env.observation_space action_space = env.action_space - autoreset = False parent_pipe.close() @@ -216,7 +215,6 @@ def async_multiagent_worker( # noqa CCR001 if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory) observation = None - autoreset = False pipe.send(((observation, info), True)) elif command == "step": # Modified the step function to align with 'AutoResetWrapper'. 
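At this point in the series the gym pipeline works end to end: make_gym_env builds a gymnasium AsyncVectorEnv whose async_multiagent_worker resets a sub-environment as soon as all of its agents are terminated or truncated. A hypothetical driver loop, assuming config is the merged Hydra config and using random actions in place of a trained policy (a sketch only, not code from the patch):

    from mava.utils.make_env import make_gym_env

    envs = make_gym_env(config, num_env=config.arch.num_envs)  # gymnasium AsyncVectorEnv
    obs, info = envs.reset(seed=42)  # observations stack to (num_envs, num_agents, ...)
    for _ in range(config.system.rollout_length):
        # A real actor would consume info["actions_mask"] provided by GymWrapper;
        # sampling from the batched action space stands in for the policy here.
        actions = envs.action_space.sample()
        obs, reward, terminated, truncated, info = envs.step(actions)
        # No manual reset is needed: async_multiagent_worker restarts a
        # sub-environment once terminated/truncated is True for all of its agents.
    envs.close()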
From 2a6452d93b818cfb640e5e1939222ba9b79c3b36 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 17:01:00 +0100 Subject: [PATCH 074/125] fix: config file fixes --- mava/configs/env/lbf_gym.yaml | 2 +- mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml | 7 +++++-- mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml | 7 +++++-- 8 files changed, 36 insertions(+), 15 deletions(-) diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index 6981f3492..b0d783a7e 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -1,7 +1,7 @@ # ---Environment Configs--- defaults: - _self_ - - scenario: gym-2s-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] + - scenario: gym-lbf-2s-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] env_name: LevelBasedForaging # Used for logging purposes. diff --git a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml index 386431be4..3aceaf74f 100644 --- a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml @@ -5,11 +5,14 @@ task_name: 10x10-3p-3f task_config: field_size: [10,10] sight: 10 - num_agents: 3 - max_food: 3 + players: 3 + max_num_food: 3 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml index 1a8380511..14953f3fc 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml @@ -5,11 +5,14 @@ task_name: 15x15-3p-5f task_config: field_size: [15, 15] sight: 15 - num_agents: 3 - max_food: 5 + players: 3 + max_num_food: 5 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml index fa22f737b..ef678025b 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml @@ -5,11 +5,14 @@ task_name: 15x15-4p-3f task_config: field_size: [15, 15] sight: 15 - num_agents: 4 - max_food: 3 + players: 4 + max_num_food: 3 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml index 28937215c..c4dcfb979 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml @@ -5,11 +5,14 @@ 
task_name: 15x15-4p-5f task_config: field_size: [15, 15] sight: 15 - num_agents: 4 - max_food: 5 + players: 4 + max_num_food: 5 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml index f0262eb8d..b094cda72 100644 --- a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml @@ -5,11 +5,14 @@ task_name: 2s-10x10-3p-3f task_config: field_size: [10, 10] sight: 2 - num_agents: 3 - max_food: 3 + players: 3 + max_num_food: 3 max_player_level: 2 force_coop: False max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml index ffdc5be0e..3c318d3cf 100644 --- a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml +++ b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml @@ -5,11 +5,14 @@ task_name: 2s-8x8-2p-2f-coop task_config: field_size: [8, 8] # size of the grid to generate. sight: 2 # field of view of an agent. - num_agents: 2 # number of agents on the grid. - max_food: 2 # number of food in the environment. + players: 2 # number of agents on the grid. + max_num_food: 2 # number of food in the environment. max_player_level: 2 # maximum level of the agents (inclusive). force_coop: True # force cooperation between agents. max_episode_steps: 50 # max number of steps per episode. + min_player_level : 1 # minimum level of the agents (inclusive). 
+ min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml index 52519fecb..840bbf9f4 100644 --- a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml +++ b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml @@ -5,11 +5,14 @@ task_name: 8x8-2p-2f-coop task_config: field_size: [8, 8] sight: 8 - num_agents: 2 - max_food: 2 + players: 2 + max_num_food: 2 max_player_level: 2 force_coop: True max_episode_steps: 50 + min_player_level : 1 + min_food_level : null + max_food_level : null env_kwargs: {} # there are no scenario specific env_kwargs for this env From e2f36f91e19c4f67510824939e0d909bdf96b22c Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 17:01:15 +0100 Subject: [PATCH 075/125] fix: LBF import --- mava/utils/make_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index e49d6344b..5755cc03c 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -36,7 +36,7 @@ from jumanji.environments.routing.robot_warehouse.generator import ( RandomGenerator as RwareRandomGenerator, ) -from lbforaging.foraging import environment as gym_lbf +from lbforaging.foraging import ForagingEnv as gym_ForagingEnv from omegaconf import DictConfig from rware.warehouse import Warehouse as gym_Warehouse @@ -76,7 +76,7 @@ _gym_registry = { "RobotWarehouse": (gym_Warehouse, GymWrapper), - "LevelBasedForaging": (gym_lbf, GymLBFWrapper), + "LevelBasedForaging": (gym_ForagingEnv, GymLBFWrapper), } From 29396c98dc474447a6512e3a39bae8738c2cc453 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 17:01:28 +0100 Subject: [PATCH 076/125] fix: Async worker auto-resetting --- mava/wrappers/gym.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 18d3ede73..7b76fc157 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -203,8 +203,6 @@ def async_multiagent_worker( # noqa CCR001 env = env_fn() observation_space = env.observation_space action_space = env.action_space - autoreset = False - parent_pipe.close() try: @@ -216,22 +214,19 @@ def async_multiagent_worker( # noqa CCR001 if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory) observation = None - autoreset = False pipe.send(((observation, info), True)) elif command == "step": - if autoreset: + # Modified the step function to align with 'AutoResetWrapper'. + # The environment resets immediately upon termination or truncation. + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + if np.logical_or(terminated, truncated).all(): observation, info = env.reset() - reward, terminated, truncated = 0, False, False - else: - ( - observation, - reward, - terminated, - truncated, - info, - ) = env.step(data) - # The autoreset was modified to work with boolean arrays. 
- autoreset = np.logical_or(terminated, truncated).all() if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory) From 6de0b1e1d999b3e2dbea3264c02a4be33cf2512d Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 19 Jul 2024 17:11:57 +0100 Subject: [PATCH 077/125] chore: minor changes --- mava/configs/default_ff_ippo.yaml | 2 +- mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml | 2 +- mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml | 2 +- mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml | 2 +- mava/utils/make_env.py | 3 +-- 9 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mava/configs/default_ff_ippo.yaml b/mava/configs/default_ff_ippo.yaml index c4aa6ea49..d942584ce 100644 --- a/mava/configs/default_ff_ippo.yaml +++ b/mava/configs/default_ff_ippo.yaml @@ -3,5 +3,5 @@ defaults: - arch: anakin - system: ppo/ff_ippo - network: mlp - - env: rware_gym + - env: rware - _self_ diff --git a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml index 3aceaf74f..a2150115b 100644 --- a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml @@ -9,7 +9,7 @@ task_config: max_num_food: 3 max_player_level: 2 force_coop: False - max_episode_steps: 50 + max_episode_steps: 100 min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml index 14953f3fc..70031bad0 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml @@ -9,7 +9,7 @@ task_config: max_num_food: 5 max_player_level: 2 force_coop: False - max_episode_steps: 50 + max_episode_steps: 100 min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml index ef678025b..b1fe6e4be 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml @@ -9,7 +9,7 @@ task_config: max_num_food: 3 max_player_level: 2 force_coop: False - max_episode_steps: 50 + max_episode_steps: 100 min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml index c4dcfb979..9ce0100f5 100644 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml +++ b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml @@ -9,7 +9,7 @@ task_config: max_num_food: 5 max_player_level: 2 force_coop: False - max_episode_steps: 50 + max_episode_steps: 100 min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml index b094cda72..fea817887 100644 --- a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml +++ b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml @@ -9,7 +9,7 @@ task_config: max_num_food: 3 max_player_level: 2 force_coop: False - max_episode_steps: 50 + max_episode_steps: 100 min_player_level : 1 min_food_level : null max_food_level : null diff --git 
a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml index 3c318d3cf..b0cacb95c 100644 --- a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml +++ b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml @@ -9,7 +9,7 @@ task_config: max_num_food: 2 # number of food in the environment. max_player_level: 2 # maximum level of the agents (inclusive). force_coop: True # force cooperation between agents. - max_episode_steps: 50 # max number of steps per episode. + max_episode_steps: 100 # max number of steps per episode. min_player_level : 1 # minimum level of the agents (inclusive). min_food_level : null max_food_level : null diff --git a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml index 840bbf9f4..3b9cee314 100644 --- a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml +++ b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml @@ -9,7 +9,7 @@ task_config: max_num_food: 2 max_player_level: 2 force_coop: True - max_episode_steps: 50 + max_episode_steps: 100 min_player_level : 1 min_food_level : null max_food_level : null diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 5755cc03c..21b595c06 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -231,8 +231,7 @@ def make_gym_env( Returns: Async environments. """ - base_env_name = config.env.scenario.name - env_maker, wrapper = _gym_registry[base_env_name] + env_maker, wrapper = _gym_registry[config.env.scenario.name] def create_gym_env(config: DictConfig, add_global_state: bool = False) -> Environment: env = env_maker(**config.env.scenario.task_config) From 7584ce5976fcdd5efda95a95a350438de77da8f0 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 22 Jul 2024 09:29:46 +0100 Subject: [PATCH 078/125] fixed: annotations and add agent id spaces --- mava/wrappers/gym.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 7b76fc157..0e1cf6529 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -21,6 +21,7 @@ import gymnasium import numpy as np +from gymnasium import spaces from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray @@ -178,14 +179,14 @@ def step(self, action: list) -> Tuple[NDArray, float, bool, bool, Dict]: obs = np.concatenate([self.agent_ids, obs], axis=1) return obs, reward, terminated, truncated, info - def modify_space(self, space: gymnasium.spaces) -> gymnasium.spaces: - if isinstance(space, gymnasium.spaces.Box): - new_shape = space.shape[0] + len(self.agent_ids) - return gymnasium.spaces.Box( - low=space.low, high=space.high, shape=new_shape, dtype=space.dtype + def modify_space(self, space: spaces.Space) -> spaces.Space: + if isinstance(space, spaces.Box): + new_shape = (space.shape[0] + len(self.agent_ids),) + return spaces.Box( + low=space.low[0], high=space.high[0], shape=new_shape, dtype=space.dtype ) - elif isinstance(space, gymnasium.spaces.Tuple): - return gymnasium.spaces.Tuple(self.modify_space(s) for s in space) + elif isinstance(space, spaces.Tuple): + return spaces.Tuple(self.modify_space(s) for s in space) else: raise ValueError(f"Space {type(space)} is not currently supported.") From e638e9fd36c793efd33ddd827843df3ef87f99ab Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 22 Jul 2024 09:35:54 +0100 Subject: [PATCH 079/125] fix: fixed the logging deadlock for sebulba --- mava/utils/logger.py | 5 ++++- 1 
file changed, 4 insertions(+), 1 deletion(-) diff --git a/mava/utils/logger.py b/mava/utils/logger.py index 4edad361e..bf502e25c 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -150,8 +150,11 @@ class NeptuneLogger(BaseLogger): def __init__(self, cfg: DictConfig, unique_token: str) -> None: tags = list(cfg.logger.kwargs.neptune_tag) project = cfg.logger.kwargs.neptune_project + mode = ( + "async" if cfg.arch.architecture_name == "anakin" else "sync" + ) # async logging leads to deadlocks in sebulba - self.logger = neptune.init_run(project=project, tags=tags) + self.logger = neptune.init_run(project=project, tags=tags, mode=mode) self.logger["config"] = stringify_unsupported(cfg) self.detailed_logging = cfg.logger.kwargs.detailed_neptune_logging From a85aa2fcbb373d554149a40bf0a29441ae15bad1 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 23 Jul 2024 09:34:42 +0100 Subject: [PATCH 080/125] chore: pre-commits --- mava/utils/make_env.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 3cf7982ea..887a987cb 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -228,40 +228,6 @@ def make_gigastep_env( return train_env, eval_env -def make_gym_env( - config: DictConfig, - num_env: int, - add_global_state: bool = False, -) -> gymnasium.vector.AsyncVectorEnv: - """ - Create a gymnasium environment. - - Args: - config (Dict): The configuration of the environment. - num_env (int) : The number of parallel envs to create. - add_global_state (bool): Whether to add the global state to the observation. Default False. - - Returns: - Async environments. - """ - env_maker, wrapper = _gym_registry[config.env.scenario.name] - - def create_gym_env(config: DictConfig, add_global_state: bool = False) -> Environment: - env = env_maker(**config.env.scenario.task_config) - wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) - if config.env.add_agent_id: - wrapped_env = GymAgentIDWrapper(wrapped_env) - wrapped_env = GymRecordEpisodeMetrics(wrapped_env) - return wrapped_env - - envs = gymnasium.vector.AsyncVectorEnv( - [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], - worker=async_multiagent_worker, - ) - - return envs - - def make_gym_env( config: DictConfig, num_env: int, From e504b478c7108d024a90f696e07b4e016a3a7ada Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 23 Jul 2024 09:54:13 +0100 Subject: [PATCH 081/125] pre-commit --- mava/wrappers/gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 0e1cf6529..520243e92 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -193,7 +193,7 @@ def modify_space(self, space: spaces.Space) -> spaces.Space: # Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents -def async_multiagent_worker( # noqa CCR001 +def async_multiagent_worker( # CCR001 index: int, env_fn: Callable, pipe: Connection, From a19056b431fd93c0a5926988b8c8b08b2e9ddf59 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Thu, 25 Jul 2024 22:47:27 +0100 Subject: [PATCH 082/125] feat : major code restructer, non-blocking evalutors --- mava/configs/arch/sebulba.yaml | 6 +- mava/configs/default_ff_ippo.yaml | 2 +- mava/evaluator.py | 176 ++++++------ mava/systems/ppo/sebulba/ff_ippo.py | 418 ++++++++++------------------ mava/utils/make_env.py | 5 +- mava/utils/sebulba_utils.py | 166 
+++++++++++ mava/wrappers/__init__.py | 1 + mava/wrappers/gym.py | 63 +++++ 8 files changed, 466 insertions(+), 371 deletions(-) create mode 100644 mava/utils/sebulba_utils.py diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 0b539059b..9d21a51d3 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -2,18 +2,18 @@ architecture_name: sebulba # --- Training --- -num_envs: 32 # number of environments per thread. +num_envs: 2 # number of environments per thread. # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. -num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. +num_eval_episodes: 2 # Number of episodes to evaluate per evaluation. num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 # --- Sebulba devices config --- -n_threads_per_executor: 1 # num of different threads/env batches per actor +n_threads_per_executor: 2 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices diff --git a/mava/configs/default_ff_ippo.yaml b/mava/configs/default_ff_ippo.yaml index c4aa6ea49..d942584ce 100644 --- a/mava/configs/default_ff_ippo.yaml +++ b/mava/configs/default_ff_ippo.yaml @@ -3,5 +3,5 @@ defaults: - arch: anakin - system: ppo/ff_ippo - network: mlp - - env: rware_gym + - env: rware - _self_ diff --git a/mava/evaluator.py b/mava/evaluator.py index 2d0183878..e754899ae 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -17,7 +17,6 @@ import warnings from typing import Any, Callable, Dict, Protocol, Tuple, Union -import gymnasium import jax import jax.numpy as jnp import numpy as np @@ -35,7 +34,6 @@ Observation, ObservationGlobalState, RecActorApply, - SebulbaEvalFn, State, ) @@ -211,121 +209,109 @@ def eval_act_fn( return eval_act_fn -# todo : Update -def get_sebulba_ff_evaluator_fn( - env: gymnasium.Env, - apply_fn: ActorApply, +def get_sebulba_eval_fn( + env_maker: Callable, + act_fn: EvalActFn, config: DictConfig, np_rng: np.random.Generator, - log_win_rate: bool = False, -) -> SebulbaEvalFn: - """Get the evaluator function for feedforward networks. + absolute_metric: bool, +) -> EvalFn: + """Creates a function that can be used to evaluate agents on a given environment. Args: - env (Environment): An evironment instance for evaluation. - apply_fn (callable): Network forward pass method. - config (dict): Experiment configuration. + ---- + env: an environment that conforms to the mava environment spec. + act_fn: a function that takes in params, timestep, key and optionally a state + and returns actions and optionally a state (see `EvalActFn`). + config: the system config. + absolute_metric: whether or not this evaluator calculates the absolute_metric. + This determines how many evaluation episodes it does. """ + n_devices = jax.device_count() + eval_episodes = ( + config.arch.num_absolute_metric_eval_episodes + if absolute_metric + else config.arch.num_eval_episodes + ) - @jax.jit - def get_action( # todo explicetly put these on the learner? 
they should already be there - params: FrozenDict, - observation: Observation, - key: PRNGKey, - ) -> Array: - """Get action.""" - - pi = apply_fn(params, observation) - - if config.arch.evaluation_greedy: - action = pi.mode() - else: - action = pi.sample(seed=key) - - return action + n_parallel_envs = min(eval_episodes, config.arch.num_envs) + episode_loops = math.ceil(eval_episodes / n_parallel_envs) + env = env_maker(config, n_parallel_envs) - def eval_episodes(params: FrozenDict, key: PRNGKey) -> Any: - seeds = np_rng.integers(np.iinfo(np.int64).max, size=env.num_envs).tolist() - obs, info = env.reset(seed=seeds) - dones = np.full(env.num_envs, False) - eval_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) + # Warnings if num eval episodes is not divisible by num parallel envs. + if eval_episodes % n_parallel_envs != 0: + warnings.warn( + f"Number of evaluation episodes ({eval_episodes}) is not divisible by `num_envs` * " + f"`num_devices` ({n_parallel_envs} * {n_devices}). Some extra evaluations will be " + f"executed. New number of evaluation episodes = {episode_loops * n_parallel_envs}", + stacklevel=2, + ) - while not dones.all(): - key, policy_key = jax.random.split(key) + def eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics: + """Evaluates the given params on an environment and returns relevent metrics. - obs = jax.device_put(jnp.stack(obs, axis=1)) - action_mask = jax.device_put(np.stack(info["actions_mask"])) + Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length, + also win rate for environments that support it. - actions = get_action(params, Observation(obs, action_mask), policy_key) - cpu_action = jax.device_get(actions) + Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode. 
+ """ - obs, reward, terminated, truncated, info = env.step(cpu_action.swapaxes(0, 1)) + def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: + """Simulates `num_envs` episodes.""" - next_metrics = jax.tree_map(lambda *x: jnp.asarray(x), *info["metrics"]) + seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist() + ts = env.reset(seed=seeds) - next_dones = next_metrics["is_terminal_step"] + timesteps = [ts] - update_flags = np.logical_and(next_dones, np.invert(dones)) + actor_state = init_act_state + finished_eps = ts.last() - update_metrics = lambda new_metric, old_metric, update_flags=update_flags: np.where( - (update_flags), new_metric, old_metric - ) + while not finished_eps.all(): + key, act_key = jax.random.split(key) + action, actor_state = act_fn(params, ts, act_key, actor_state) + cpu_action = jax.device_get(action).swapaxes(0, 1) + ts = env.step(cpu_action) + timesteps.append(ts) - eval_metrics = jax.tree_map(update_metrics, next_metrics, eval_metrics) + finished_eps = np.logical_or(finished_eps, ts.last()) - dones = np.logical_or(dones, next_dones) - eval_metrics.pop("is_terminal_step") + timesteps = jax.tree.map(lambda *x: np.stack(x), *timesteps) - return eval_metrics + metrics = timesteps.extras + if config.env.log_win_rate: + metrics["won_episode"] = timesteps.extras["won_episode"] - return eval_episodes + # find the first instance of done to get the metrics at that timestep, we don't + # care about subsequent steps because we only the results from the first episode + done_idx = jnp.argmax(timesteps.last(), axis=0) + metrics = jax.tree_map(lambda m: m[done_idx, jnp.arange(n_parallel_envs)], metrics) + del metrics["is_terminal_step"] # uneeded for logging + return key, metrics -def make_sebulba_eval_fns( - eval_env_fn: Callable, - network_apply_fn: Union[ActorApply, RecActorApply], - config: DictConfig, - np_rng: np.random.Generator, - add_global_state: bool = False, -) -> Tuple[SebulbaEvalFn, SebulbaEvalFn]: - """Initialize evaluator functions for reinforcement learning. + # This loop is important because we don't want too many parallel envs. + # So in evaluation we have num_envs parallel envs and loop enough times + # so that we do at least `eval_episodes` number of episodes. + metrics = [] + for _ in range(episode_loops): + key, metric = _episode(key) + metrics.append(metric) + + metrics: Metrics = jax.tree_map( + lambda *x: jnp.array(x).reshape(-1), *metrics + ) # flatten metrics + return metrics - Args: - eval_env_fn (Environment): The function to Create the eval envs. - network_apply_fn (Union[ActorApply,RecActorApply]): Creates a policy to sample. - config (DictConfig): The configuration settings for the evaluation. - use_recurrent_net (bool, optional): Whether to use a rnn. Defaults to False. - scanned_rnn (Optional[nn.Module], optional): The rnn module. - Required if `use_recurrent_net` is True. Defaults to None. - - Returns: - Tuple[SebulbaEvalFn, SebulbaEvalFn]: A tuple of two evaluation functions: - one for use during training and one for absolute metrics. - - Raises: - AssertionError: If `use_recurrent_net` is True but `scanned_rnn` is not provided. 
- """ - eval_env, absolute_eval_env = ( - eval_env_fn(config, config.arch.num_eval_episodes, add_global_state=add_global_state), - eval_env_fn(config, config.arch.num_eval_episodes * 10, add_global_state=add_global_state), - ) + def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics: + """Wrapper around eval function to time it and add in steps per second metric.""" + start_time = time.time() - # Check if win rate is required for evaluation. - log_win_rate = config.env.log_win_rate + metrics = eval_fn(params, key, init_act_state) - evaluator = get_sebulba_ff_evaluator_fn( - eval_env, - network_apply_fn, # type: ignore - config, - np_rng, - log_win_rate, # type: ignore - ) - absolute_metric_evaluator = get_sebulba_ff_evaluator_fn( - absolute_eval_env, - network_apply_fn, # type: ignore - config, - np_rng, - log_win_rate, # type: ignore - ) + end_time = time.time() + total_timesteps = jnp.sum(metrics["episode_length"]) + metrics["steps_per_second"] = total_timesteps / (end_time - start_time) + return metrics - return evaluator, absolute_metric_evaluator + return timed_eval_fn diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index f3a912f5d..fedc7f31d 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -15,9 +15,8 @@ import copy import queue import threading -import time -from collections import deque -from typing import Any, Dict, List, Tuple +from queue import Queue +from typing import Any, Dict, List, Sequence, Tuple import chex import flax @@ -33,7 +32,8 @@ from optax._src.base import OptState from rich.pretty import pprint -from mava.evaluator import make_sebulba_eval_fns as make_eval_fns +from mava.evaluator import get_sebulba_eval_fn as get_eval_fn +from mava.evaluator import make_ff_eval_act_fn from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition @@ -42,12 +42,14 @@ CriticApply, ExperimentOutput, Observation, + SebulbaEvalFn, SebulbaLearnerFn, ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.jax_utils import merge_leading_dims, unreplicate_n_dims +from mava.utils.jax_utils import merge_leading_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.sebulba_utils import ParamsSource, Pipeline, ThreadLifetime from mava.utils.total_timestep_checker import sebulba_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -56,12 +58,12 @@ def rollout( key: chex.PRNGKey, config: DictConfig, - rollout_queue: queue.Queue, - params_queue: queue.Queue, + rollout_pipeline: Pipeline, + params_source: ParamsSource, apply_fns: Tuple, - learner_devices: List, actor_device_id: int, seeds: List[int], + thread_lifetime: ThreadLifetime, ) -> None: # setup env = environments.make_gym_env(config, config.arch.num_envs) @@ -78,137 +80,80 @@ def get_action_and_value( """Get action and value.""" key, subkey = jax.random.split(key) - actor_policy = actor_apply_fn(params.actor_params, observation) # TODO: check vmapiing + actor_policy = actor_apply_fn(params.actor_params, observation) action = actor_policy.sample(seed=subkey) log_prob = actor_policy.log_prob(action) value = critic_apply_fn(params.critic_params, observation).squeeze() return action, log_prob, value, key - # Define queues to track 
time - params_queue_get_time: deque = deque(maxlen=1) - rollout_time: deque = deque(maxlen=1) - rollout_queue_put_time: deque = deque(maxlen=1) - next_obs, info = env.reset(seed=seeds) - next_dones = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jax.numpy.bool_) + timestep = env.reset(seed=seeds) + next_dones = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), + timestep.last(), + ) move_to_device = lambda x: jax.device_put(x, device=current_actor_device) - shard_split_payload = lambda x, axis: jax.device_put_sharded( - jnp.split(x, len(learner_devices), axis=axis), devices=learner_devices - ) - # Loop till the learner has finished training - for _update in range(config.system.num_updates): - inference_time: float = 0 - storage_time: float = 0 - env_send_time: float = 0 - - # Get the latest parameters from the learner - params_queue_get_time_start = time.time() - params = params_queue.get() - params_queue_get_time.append(time.time() - params_queue_get_time_start) - + while not thread_lifetime.should_stop(): # Rollout - rollout_time_start = time.time() - storage: List = [] - + traj: List = [] # Loop over the rollout length - for _ in range(0, config.system.rollout_length): - # Cached for transition - cached_next_obs = move_to_device( - jnp.stack(next_obs, axis=1) - ) # (num_envs, num_agents, ...) - cached_next_dones = move_to_device(next_dones) # (num_envs, num_agents) - cashed_action_mask = move_to_device( - np.stack(info["actions_mask"]) - ) # (num_envs, num_agents, num_actions) - - full_observation = Observation(cached_next_obs, cashed_action_mask) + for _ in range(config.system.rollout_length): + # Get the latest parameters from the learner + params = params_source.get() + + cached_next_obs = jax.tree.map(move_to_device, timestep.observation) + cached_next_dones = move_to_device(next_dones) + # Get action and value - inference_time_start = time.time() ( action, log_prob, value, key, - ) = get_action_and_value(params, full_observation, key) + ) = get_action_and_value(params, cached_next_obs, key) # Step the environment - inference_time += time.time() - inference_time_start - env_send_time_start = time.time() cpu_action = jax.device_get(action) - next_obs, next_reward, terminated, truncated, info = env.step( + timestep = env.step( cpu_action.swapaxes(0, 1) ) # (num_env, num_agents) --> (num_agents, num_env) - env_send_time += time.time() - env_send_time_start - # Prepare the data - storage_time_start = time.time() - next_dones = np.logical_or(terminated, truncated) - metrics = jax.tree_util.tree_map( - lambda *x: jnp.asarray(x), *info["metrics"] - ) # Stack the metrics + next_dones = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), + timestep.last(), + ) # Append data to storage - storage.append( + traj.append( PPOTransition( done=cached_next_dones, action=action, value=value, - reward=next_reward, + reward=timestep.reward, log_prob=log_prob, - obs=full_observation, - info=metrics, + obs=cached_next_obs, + info=timestep.extras, ) ) - storage_time += time.time() - storage_time_start - rollout_time.append(time.time() - rollout_time_start) - - parse_timer = time.time() - - # Prepare data to share with learner - # [PPOTransition() * rollout_len] --> PPOTransition[done=(rollout_len, num_envs, num_agents) - # , action=(rollout_len, num_envs, num_agents, num_actions), ...] 
- stacked_storage = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *storage) - - # Split the arrays over the different learner_devices on the num_envs axis - - sharded_storage = jax.tree_util.tree_map( - lambda x: shard_split_payload(x, 1), stacked_storage - ) # (num_learner_devices, rollout_len, num_envs, num_agents, ...) - - # (num_learner_devices, num_envs, num_agents, ...) - sharded_next_obs = shard_split_payload(jnp.stack(next_obs, axis=1), 0) - sharded_next_action_mask = shard_split_payload(np.stack(info["actions_mask"]), 0) - sharded_next_done = shard_split_payload(next_dones, 0) - - # Pack the obs and action mask - payload_obs = Observation(sharded_next_obs, sharded_next_action_mask) - - # For debugging - speed_info = { # noqa F841 - "rollout_time": np.mean(rollout_time), - "params_queue_get_time": np.mean(params_queue_get_time), - "action_inference": inference_time, - "storage_time": storage_time, - "env_step_time": env_send_time, - "rollout_queue_put_time": ( - np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0 - ), - "parse_time": time.time() - parse_timer, - } - - payload = ( - sharded_storage, - payload_obs, - sharded_next_done, - ) + + # todo: replace with the record timer + # speed_info = { # F841 + # "rollout_time": np.mean(rollout_time), + # "params_queue_get_time": np.mean(params_queue_get_time), + # "action_inference": inference_time, + # "storage_time": storage_time, + # W "env_step_time": env_send_time, + # "rollout_queue_put_time": ( + # np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0 + # ), + # "parse_time": time.time() - parse_timer, + # } # Put data in the rollout queue to share it with the learner - rollout_queue_put_time_start = time.time() - rollout_queue.put(payload) - rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) + rollout_pipeline.put(traj, timestep.observation, next_dones) def get_learner_fn( @@ -397,11 +342,8 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) # SHUFFLE MINIBATCHES - batch_size = ( - config.system.rollout_length - * (config.arch.num_envs // len(config.arch.learner_device_ids)) - * len(config.arch.executor_device_ids) - * config.arch.n_threads_per_executor + batch_size = config.system.rollout_length * ( + config.arch.num_envs // len(config.arch.learner_device_ids) ) permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) @@ -435,7 +377,7 @@ def _critic_loss_fn( def learner_fn( learner_state: LearnerState, traj_batch: PPOTransition, - last_obs: chex.Array, + last_obs: Observation, last_dones: chex.Array, ) -> ExperimentOutput[LearnerState]: """Learner function. 
@@ -467,6 +409,37 @@ def learner_fn( return learner_fn +def evaluate( + logger: MavaLogger, + payload_queue: Queue, + evaluator: SebulbaEvalFn, + thread_lifetime: ThreadLifetime, + steps_per_rollout: int, + key: chex.PRNGKey, +): + eval_step = 1 + + while not thread_lifetime.should_stop(): + metrics, params = payload_queue.get() + t = int(steps_per_rollout * (eval_step + 1)) + + episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics) + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + + if ep_completed: + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + + key, eval_key = jax.random.split(key, 2) + episode_metrics = evaluator(params.actor_params, eval_key, {}) + logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) + + # todo add checkpointing + episode_return = jnp.mean(episode_metrics["episode_return"]) + + eval_step += 1 + + def learner_setup( keys: chex.Array, config: DictConfig, learner_devices: List ) -> Tuple[ @@ -572,14 +545,14 @@ def run_experiment(_config: DictConfig) -> float: # Sanity check of config assert ( config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "The number of environments must to be divisible by the number of learners " + ), "The number of environments must to be divisible by the number of learners." assert ( int(config.arch.num_envs / len(config.arch.learner_device_ids)) * config.arch.n_threads_per_executor % config.system.num_minibatches == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" + ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches." # Setup learner. learn, apply_fns, learner_state = learner_setup( @@ -590,8 +563,10 @@ def run_experiment(_config: DictConfig) -> float: np_rng = np.random.default_rng(config.system.seed) # Setup evaluator. - evaluator, absolute_metric_evaluator = make_eval_fns( - environments.make_gym_env, apply_fns[0], config, np_rng + # One key per device for evaluation. + eval_act_fn = make_ff_eval_act_fn(apply_fns[0], config) + evaluator = get_eval_fn( + environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False ) # Calculate total timesteps. @@ -601,18 +576,9 @@ def run_experiment(_config: DictConfig) -> float: ), "Number of updates per evaluation must be less than total number of updates." # Calculate number of updates per evaluation. config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation - config.arch.num_evaluation, remaining_updates = divmod( - config.system.num_updates, config.system.num_updates_per_eval - ) - config.arch.num_evaluation += ( - remaining_updates != 0 - ) # Add an evaluation step if the num_updates is not a multiple of num_evaluation + steps_per_rollout = ( - len(config.arch.executor_device_ids) - * config.arch.n_threads_per_executor - * config.system.rollout_length - * config.arch.num_envs - * config.system.num_updates_per_eval + config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval ) # Logger setup @@ -632,167 +598,77 @@ def run_experiment(_config: DictConfig) -> float: # Executor setup and launch. 
unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) - params_queues: List = [] - rollout_queues: List = [] - - for _d_idx, d_id in enumerate( # Loop through each executor device - config.arch.executor_device_ids - ): - # Replicate params per executor device - device_params = jax.device_put(unreplicated_params, devices[d_id]) + params_sources: Sequence[ParamsSource] = [] + thread_lifetimes: Sequence[ThreadLifetime] = [] + pipeline = Pipeline(128, learner_devices) # TODO: ADD THE MAX PIPILINE QUEUE SIZE TO THE CONFIG + pipeline.start() + + # Create the actor threads + for d_idx, d_id in enumerate(config.arch.executor_device_ids): # Loop through each executor thread - for _thread_id in range(config.arch.n_threads_per_executor): - seeds = np_rng.integers(np.iinfo(np.int64).max, size=config.arch.num_envs).tolist() - params_queues.append(queue.Queue(maxsize=1)) - rollout_queues.append(queue.Queue(maxsize=1)) - params_queues[-1].put(device_params) + for thread_id in range(config.arch.n_threads_per_executor): + seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() + + params_source = ParamsSource(unreplicated_params, devices[d_id]) + params_source.start() + params_sources.append(params_source) + + lifetime = ThreadLifetime() + thread_lifetimes.append(lifetime) + threading.Thread( target=rollout, args=( jax.device_put(key, devices[d_id]), config, - rollout_queues[-1], - params_queues[-1], + pipeline, + params_sources[-1], apply_fns, - learner_devices, d_id, seeds, + lifetime, ), + name=f"Actor-{thread_id + d_idx * config.arch.n_threads_per_executor}", ).start() - # Run experiment for the total number of updates. - max_episode_return = jnp.float32(0.0) - best_params = None - for eval_step in range(config.arch.num_evaluation): - training_start_time = time.time() - learner_speeds = [] - rollout_times = [] - - episode_metrics = [] - train_metrics = [] - - # Full or partial last eval step. 
- num_updates_in_eval = ( - remaining_updates - if eval_step == config.arch.num_evaluation - 1 and remaining_updates - else config.system.num_updates_per_eval - ) - for _update in range(num_updates_in_eval): - sharded_storages = [] - sharded_next_obss = [] - sharded_next_dones = [] - - rollout_start_time = time.time() - # Loop through each executor device - for d_idx, _ in enumerate(config.arch.executor_device_ids): - # Loop through each executor thread - for thread_id in range(config.arch.n_threads_per_executor): - # Get data from rollout queue - ( - sharded_storage, - sharded_next_obs, - sharded_next_done, - ) = rollout_queues[d_idx * config.arch.n_threads_per_executor + thread_id].get() - sharded_storages.append(sharded_storage) - sharded_next_obss.append(sharded_next_obs) - sharded_next_dones.append(sharded_next_done) - - rollout_times.append(time.time() - rollout_start_time) - - # Concatinate the returned trajectories on the n_env axis - sharded_storages = jax.tree_util.tree_map( - lambda *x: jnp.concatenate(x, axis=2), *sharded_storages - ) - sharded_next_obss = jax.tree_util.tree_map( - lambda *x: jnp.concatenate(x, axis=1), *sharded_next_obss - ) - sharded_next_dones = jnp.concatenate(sharded_next_dones, axis=1) - - learner_start_time = time.time() - learner_output = learn( - learner_state, sharded_storages, sharded_next_obss, sharded_next_dones - ) - learner_speeds.append(time.time() - learner_start_time) - - # Stack the metrics - episode_metrics.append(learner_output.episode_metrics) - train_metrics.append(learner_output.train_metrics) - - # Send updated params to executors - unreplicated_params = flax.jax_utils.unreplicate(learner_output.learner_state.params) - for d_idx, d_id in enumerate(config.arch.executor_device_ids): - device_params = jax.device_put(unreplicated_params, devices[d_id]) - for thread_id in range(config.arch.n_threads_per_executor): - params_queues[d_idx * config.arch.n_threads_per_executor + thread_id].put( - device_params - ) - - # Log the results of the training. - elapsed_time = time.time() - training_start_time - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics = jax.tree_util.tree_map(lambda *x: np.asarray(x), *episode_metrics) - episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - - # Separately log timesteps, actoring metrics and training metrics. - speed_info = { - "total_time": elapsed_time, - "rollout_time": np.sum(rollout_times), - "learner_time": np.sum(learner_speeds), - "timestep": t, - } - logger.log(speed_info, t, eval_step, LogEvent.MISC) - if ep_completed: # only log episode metrics if an episode was completed in the rollout. - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - train_metrics = jax.tree_util.tree_map(lambda *x: np.asarray(x), *train_metrics) - logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) - - # Evaluation on the learner - evaluation_start_timer = time.time() - key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = evaluator( - unreplicate_n_dims(learner_output.learner_state.params.actor_params, 1), eval_key - ) - - # Log the results of the evaluation. 
- elapsed_time = time.time() - evaluation_start_timer - episode_return = jnp.mean(episode_metrics["episode_return"]) - - steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) - episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) - - if save_checkpoint: - # Save checkpoint of learner state - checkpointer.save( - timestep=steps_per_rollout * (eval_step + 1), - unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state, 1), - episode_return=episode_return, - ) - - if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(learner_output.learner_state.params.actor_params) - max_episode_return = episode_return - - # Update runner state to continue training. - learner_state = learner_output.learner_state - - # Record the performance for the final evaluation run. - eval_performance = float(jnp.mean(episode_metrics[config.env.eval_metric])) - - # Measure absolute metric. - if config.arch.absolute_metric: - start_time = time.time() + lifetime = ThreadLifetime() + evaluator_queue = Queue() # maxsize=1) + threading.Thread( + target=evaluate, + name="Evaluator", + args=(logger, evaluator_queue, evaluator, lifetime, steps_per_rollout, key), + ).start() + thread_lifetimes.append(lifetime) + + for eval_step in range( + config.arch.num_evaluation + ): # todo : replace :) if comment 3 is the way then this can be replaced with num_evaluation and the try catch in naother loop called num_updates per eval? + # should we have a loop over num actors? how much should we get? + # rn it trains over the output of a single actor + # we can leave it this way and think of other actor threads / devices as just a speed boost? I.e you should get ur desired batch sized base only on the num_envs * rollour_len ? + metrics: Sequence[Tuple[Dict, Dict]] = [] + _update = 0 + while _update != config.system.num_updates_per_eval: + try: + traj_batch, last_obs, last_dones = pipeline.get(block=True, timeout=1) + except queue.Empty: + continue + else: + learner_state, episode_metrics, train_metrics = learn( + learner_state, traj_batch, last_obs, last_dones + ) + metrics.append((episode_metrics, train_metrics)) + unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) - key_e, eval_key = jax.random.split(key_e, 2) - episode_metrics = absolute_metric_evaluator(unreplicate_n_dims(best_params, 1), eval_key) + for source in params_sources: + source.update(unreplicated_params) + _update += 1 - elapsed_time = time.time() - start_time - steps_per_eval = int(jnp.sum(episode_metrics["episode_length"])) + # Run the evaluator + evaluator_queue.put((metrics, unreplicated_params)) - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + for thread_lifetime in thread_lifetimes: + thread_lifetime.stop() # Stop the logger. logger.stop() diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 887a987cb..405cb73b8 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -49,6 +49,7 @@ GymAgentIDWrapper, GymLBFWrapper, GymRecordEpisodeMetrics, + GymToJumanji, GymWrapper, LbfWrapper, MabraxWrapper, @@ -232,7 +233,7 @@ def make_gym_env( config: DictConfig, num_env: int, add_global_state: bool = False, -) -> gymnasium.vector.AsyncVectorEnv: +) -> GymToJumanji: """ Create a gymnasium environment. 
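Since make_gym_env now returns the vector envs wrapped in GymToJumanji (next hunk), callers interact with the Gymnasium envs through Jumanji-style TimeSteps rather than raw (obs, reward, ...) tuples. A hedged usage sketch based on the wrapper added later in this patch; cfg and the zero actions are placeholders and the exact shapes may differ:

    import numpy as np

    envs = make_gym_env(cfg, num_env=4)               # AsyncVectorEnv wrapped in GymToJumanji
    ts = envs.reset(seed=[0, 1, 2, 3])                # TimeStep, not (obs, info)
    agents_view = ts.observation.agents_view          # (num_envs, num_agents, ...)
    action_mask = ts.observation.action_mask          # (num_envs, num_agents, num_actions)

    num_agents = agents_view.shape[1]
    actions = np.zeros((4, num_agents), dtype=int)    # (num_envs, num_agents) no-op actions
    ts = envs.step(actions.swapaxes(0, 1))            # the rollout fn passes (num_agents, num_envs)
    ep_done = ts.last()                               # per-env end-of-episode flags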
@@ -259,6 +260,8 @@ def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnas worker=async_multiagent_worker, ) + envs = GymToJumanji(envs) + return envs diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py new file mode 100644 index 000000000..073f735c5 --- /dev/null +++ b/mava/utils/sebulba_utils.py @@ -0,0 +1,166 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import queue +import threading +import time +from typing import Any, List, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp +from chex import Array + +from mava.systems.ppo.types import Params, PPOTransition # todo: remove the ppo dependencies +from mava.types import Observation, ObservationGlobalState + + +# Copied from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py +class Pipeline(threading.Thread): + """ + The `Pipeline` shards trajectories into `learner_devices`, + ensuring trajectories are consumed in the right order to avoid being off-policy + and limit the max number of samples in device memory at one time to avoid OOM issues. + """ + + def __init__(self, max_size: int, learner_devices: List[jax.Device]): + """ + Initializes the pipeline with a maximum size and the devices to shard trajectories across. + + Args: + max_size: The maximum number of trajectories to keep in the pipeline. + learner_devices: The devices to shard trajectories across. + """ + super().__init__(name="Pipeline") + self.learner_devices = learner_devices + self.tickets_queue: queue.Queue = queue.Queue() + self._queue: queue.Queue = queue.Queue(maxsize=max_size) + + def run(self) -> None: + """ + This function ensures that trajectories on the queue are consumed in the right order. The + start_condition and end_condition are used to ensure that only 1 thread is processing an + item from the queue at one time, ensuring predictable memory usage. + """ + while True: # todo Thread lifetime + start_condition, end_condition = self.tickets_queue.get() + with end_condition: + with start_condition: + start_condition.notify() + end_condition.wait() + + def put( + self, + traj: Sequence[PPOTransition], + next_obs: Union[Observation, ObservationGlobalState], + next_dones: Array, + ) -> None: + """ + Put a trajectory on the queue to be consumed by the learner. + """ + start_condition, end_condition = (threading.Condition(), threading.Condition()) + with start_condition: + self.tickets_queue.put((start_condition, end_condition)) + start_condition.wait() # wait to be allowed to start + + # [PPOTransition()] * rollout_len --> PPOTransition[done=(rollout_len, num_envs, num_agents) + sharded_traj = jax.tree.map(lambda *x: self.shard_split_playload(jnp.stack(x), 1), *traj) + + # obs Tuple[(num_envs, num_agents, ...), ...] 
--> [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices + sharded_next_obs = jax.tree.map(self.shard_split_playload, next_obs) + + # dones (num_envs, num_agents) --> [(num_envs / num_learner_devices, num_agents)] * num_learner_devices + sharded_next_dones = self.shard_split_playload(next_dones, 0) + + self._queue.put((sharded_traj, sharded_next_obs, sharded_next_dones)) + + with end_condition: + end_condition.notify() # tell we have finish + + def qsize(self) -> int: + """Returns the number of trajectories in the pipeline.""" + return self._queue.qsize() + + def get( + self, block: bool = True, timeout: Union[float, None] = None + ) -> Tuple[PPOTransition, Union[Observation, ObservationGlobalState], Array]: + """Get a trajectory from the pipeline.""" + return self._queue.get(block, timeout) # type: ignore + + def shard_split_playload(self, payload: Any, axis: int = 0): + split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) + return jax.device_put_sharded(split_payload, devices=self.learner_devices) + + +class ParamsSource(threading.Thread): + """ + A `ParamSource` is a component that allows networks params to be passed from a + `Learner` component to `Actor` components. + """ + + def __init__(self, init_value: Params, device: jax.Device): + super().__init__(name=f"ParamsSource-{device.id}") + self.value = jax.device_put(init_value, device) + self.device = device + self.new_value: queue.Queue = queue.Queue() + + def run(self) -> None: + """ + This function is responsible for updating the value of the `ParamSource` when a new value + is available. + """ + while True: + try: + waiting = self.new_value.get(block=True, timeout=1) + self.value = jax.device_put(jax.block_until_ready(waiting), self.device) + except queue.Empty: + continue + + def update(self, new_params: Params) -> None: + """ + Update the value of the `ParamSource` with a new value. + + Args: + new_params: The new value to update the `ParamSource` with. 
+ """ + self.new_value.put(new_params) + + def get(self) -> Params: + """Get the current value of the `ParamSource`.""" + return self.value + + +class RecordTimeTo: + def __init__(self, to: Any): + self.to = to + + def __enter__(self) -> None: + self.start = time.monotonic() + + def __exit__(self, *args: Any) -> None: + end = time.monotonic() + self.to.append(end - self.start) + + +class ThreadLifetime: + """Simple class for a mutable boolean that can be used to signal a thread to stop.""" + + def __init__(self): + self._stop = False + + def should_stop(self): + return self._stop + + def stop(self): + self._stop = True diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 550180ee5..a7b56c5da 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -20,6 +20,7 @@ GymAgentIDWrapper, GymLBFWrapper, GymRecordEpisodeMetrics, + GymToJumanji, GymWrapper, async_multiagent_worker, ) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 520243e92..5bfb24e8c 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -20,11 +20,15 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import gymnasium +import jax import numpy as np from gymnasium import spaces from gymnasium.vector.utils import write_to_shared_memory +from jumanji.types import StepType, TimeStep from numpy.typing import NDArray +from mava.types import Observation, ObservationGlobalState + # Filter out the warnings warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") @@ -191,6 +195,65 @@ def modify_space(self, space: spaces.Space) -> spaces.Space: raise ValueError(f"Space {type(space)} is not currently supported.") +class GymToJumanji(gymnasium.Wrapper): + """Converts Gym outputs to Jumanji timesteps""" + + def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> TimeStep: + obs, info = self.env.reset(seed=seed, options=options) + + num_agents = len(self.env.single_action_space) + num_envs = self.env.num_envs + + ep_done = np.zeros(num_envs, dtype=float) + rewards = np.zeros((num_envs, num_agents), dtype=float) + + timestep = self._create_timestep(obs, ep_done, rewards, info) + + return timestep + + def step(self, action: list) -> TimeStep: + obs, rewards, terminated, truncated, info = self.env.step(action) + + ep_done = np.logical_or(terminated, truncated).all(axis=1) + + timestep = self._create_timestep(obs, ep_done, rewards, info) + + return timestep + + def _format_observation( + self, obs: NDArray, info: Dict + ) -> Union[Observation, ObservationGlobalState]: + """Create an observation from the raw observation and environment state.""" + + obs = np.array(obs).swapaxes( + 0, 1 + ) # (num_agents, num_envs, ...) -> (num_envs, num_agents, ...) 
+ action_mask = np.stack(info["actions_mask"]) + obs_data = {"agents_view": obs, "action_mask": action_mask} + + if "global_obs" in info: + global_obs = np.array(info["global_obs"]).swapaxes(0, 1) + obs_data["global_state"] = global_obs + return ObservationGlobalState(**obs_data) + else: + return Observation(**obs_data) + + def _create_timestep( + self, obs: NDArray, ep_done: NDArray, rewards: NDArray, info: Dict + ) -> TimeStep: + obs = self._format_observation(obs, info) + extras = jax.tree.map(lambda *x: np.stack(x), *info["metrics"]) + step_type = np.where(ep_done, StepType.LAST, StepType.MID) + + return TimeStep( + step_type=step_type, + reward=rewards, + discount=1.0 - ep_done, + observation=obs, + extras=extras, + ) + + # Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents def async_multiagent_worker( # CCR001 From fc80b91def01524c0ce7d333c012393bfb52325f Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Fri, 26 Jul 2024 22:37:36 +0100 Subject: [PATCH 083/125] chore: code cleanup and sps calcs and learner threads --- mava/configs/arch/sebulba.yaml | 9 +- mava/systems/ppo/sebulba/ff_ippo.py | 272 +++++++++++++++------------- mava/utils/sebulba_utils.py | 23 ++- mava/wrappers/gym.py | 4 +- 4 files changed, 169 insertions(+), 139 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 9d21a51d3..e38691780 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -2,13 +2,13 @@ architecture_name: sebulba # --- Training --- -num_envs: 2 # number of environments per thread. +num_envs: 32 # number of environments per thread. # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. -num_eval_episodes: 2 # Number of episodes to evaluate per evaluation. +num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 @@ -17,3 +17,8 @@ absolute_metric: True # Whether the absolute metric should be computed. For more n_threads_per_executor: 2 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices +Pilpeline_queue_size : 2 +# The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training. +# Too large of a value with too many actors will lead to all of the updates getting wasted in old episodes +# Too small of a value and the utility of having multiple actors is lost. +# A value of 1 leads to almost strictly on-policy training. diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index fedc7f31d..3f07adda8 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -13,7 +13,6 @@ # limitations under the License. 
import copy -import queue import threading from queue import Queue from typing import Any, Dict, List, Sequence, Tuple @@ -42,14 +41,13 @@ CriticApply, ExperimentOutput, Observation, - SebulbaEvalFn, SebulbaLearnerFn, ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import merge_leading_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.sebulba_utils import ParamsSource, Pipeline, ThreadLifetime +from mava.utils.sebulba_utils import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime from mava.utils.total_timestep_checker import sebulba_check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -69,6 +67,7 @@ def rollout( env = environments.make_gym_env(config, config.arch.num_envs) current_actor_device = jax.devices()[actor_device_id] actor_apply_fn, critic_apply_fn = apply_fns + num_agents, num_envs = config.system.num_agents, config.arch.num_envs # Define the util functions: select action function and prepare data to share it with learner. @jax.jit @@ -88,8 +87,9 @@ def get_action_and_value( return action, log_prob, value, key timestep = env.reset(seed=seeds) + next_dones = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), + lambda x: jnp.repeat(x, num_agents).reshape(num_envs, -1), timestep.last(), ) @@ -99,61 +99,52 @@ def get_action_and_value( while not thread_lifetime.should_stop(): # Rollout traj: List = [] - # Loop over the rollout length - for _ in range(config.system.rollout_length): - # Get the latest parameters from the learner - params = params_source.get() - - cached_next_obs = jax.tree.map(move_to_device, timestep.observation) - cached_next_dones = move_to_device(next_dones) - - # Get action and value - ( - action, - log_prob, - value, - key, - ) = get_action_and_value(params, cached_next_obs, key) - - # Step the environment - cpu_action = jax.device_get(action) - timestep = env.step( - cpu_action.swapaxes(0, 1) - ) # (num_env, num_agents) --> (num_agents, num_env) - - next_dones = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) + time_dict: Dict[str, List[float]] = {"single_rollout": [], "env_step_time": []} - # Append data to storage - traj.append( - PPOTransition( - done=cached_next_dones, - action=action, - value=value, - reward=timestep.reward, - log_prob=log_prob, - obs=cached_next_obs, - info=timestep.extras, + # Loop over the rollout length + with RecordTimeTo(time_dict["single_rollout"]): + for _ in range(config.system.rollout_length): + # Get the latest parameters from the learner + params = params_source.get() + + cached_next_obs = jax.tree.map(move_to_device, timestep.observation) + cached_next_dones = move_to_device(next_dones) + + # Get action and value + ( + action, + log_prob, + value, + key, + ) = get_action_and_value(params, cached_next_obs, key) + + # Step the environment + cpu_action = jax.device_get(action) + + with RecordTimeTo(time_dict["env_step_time"]): + timestep = env.step( + cpu_action.swapaxes(0, 1) + ) # (num_env, num_agents) --> (num_agents, num_env) + + next_dones = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, num_agents).reshape(num_envs, -1), + timestep.last(), ) - ) - # todo: replace with the record timer - # speed_info = { # F841 - # "rollout_time": np.mean(rollout_time), - # "params_queue_get_time": 
np.mean(params_queue_get_time), - # "action_inference": inference_time, - # "storage_time": storage_time, - # W "env_step_time": env_send_time, - # "rollout_queue_put_time": ( - # np.mean(rollout_queue_put_time) if rollout_queue_put_time else 0 - # ), - # "parse_time": time.time() - parse_timer, - # } + # Append data to storage + traj.append( + PPOTransition( + done=cached_next_dones, + action=action, + value=value, + reward=timestep.reward, + log_prob=log_prob, + obs=cached_next_obs, + info=timestep.extras, + ) + ) - # Put data in the rollout queue to share it with the learner - rollout_pipeline.put(traj, timestep.observation, next_dones) + rollout_pipeline.put(traj, timestep.observation, next_dones, time_dict) def get_learner_fn( @@ -190,7 +181,7 @@ def _update_step( _ (Any): The current metrics info. """ - def _calculate_gae( # todo: lake sure this is appropriate + def _calculate_gae( traj_batch: PPOTransition, last_val: chex.Array, last_done: chex.Array ) -> Tuple[chex.Array, chex.Array]: def _get_advantages( @@ -303,7 +294,7 @@ def _critic_loss_fn( # pmean over devices. actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), - axis_name="device", # todo: pmean over learner devices not all + axis_name="device", ) # pmean over devices. @@ -394,8 +385,6 @@ def learner_fn( - env_state (LogEnvState): The environment state. - timesteps (TimeStep): The initial timestep in the initial trajectory. """ - - # todo: add update_batch_size learner_state, (episode_info, loss_info) = _update_step( learner_state, traj_batch, last_obs, last_dones ) @@ -409,37 +398,6 @@ def learner_fn( return learner_fn -def evaluate( - logger: MavaLogger, - payload_queue: Queue, - evaluator: SebulbaEvalFn, - thread_lifetime: ThreadLifetime, - steps_per_rollout: int, - key: chex.PRNGKey, -): - eval_step = 1 - - while not thread_lifetime.should_stop(): - metrics, params = payload_queue.get() - t = int(steps_per_rollout * (eval_step + 1)) - - episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics) - episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - - if ep_completed: - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) - - key, eval_key = jax.random.split(key, 2) - episode_metrics = evaluator(params.actor_params, eval_key, {}) - logger.log(episode_metrics, t, eval_step, LogEvent.EVAL) - - # todo add checkpointing - episode_return = jnp.mean(episode_metrics["episode_return"]) - - eval_step += 1 - - def learner_setup( keys: chex.Array, config: DictConfig, learner_devices: List ) -> Tuple[ @@ -530,6 +488,46 @@ def learner_setup( return learn, apply_fns, init_learner_state +def learner( + learn: SebulbaLearnerFn[LearnerState, PPOTransition], + learner_state: LearnerState, + config: DictConfig, + learner_queue: Queue, + pipeline: Pipeline, + params_sources: Sequence[ParamsSource], +) -> None: + for _eval_step in range(config.arch.num_evaluation): + metrics: List[Tuple[Dict, Dict]] = [] + rollout_times: List[Dict] = [] + eval_times: Dict[str, List[float]] = {"evaluator_blocked_time": [], "evaluation_time": []} + + for _update in range(config.system.num_updates_per_eval): + with RecordTimeTo(eval_times["evaluator_blocked_time"]): + traj_batch, last_obs, last_dones, rollout_time = pipeline.get(block=True) + + with RecordTimeTo(eval_times["evaluation_time"]): + learner_state, episode_metrics, train_metrics = learn( + learner_state, traj_batch, last_obs, last_dones + ) + + 
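            # Each pipeline.get() returns one actor thread's sharded rollout
            # plus its timing dict, so evaluator_blocked_time above measures
            # how long the learner sat waiting for fresh data.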
metrics.append((episode_metrics, train_metrics)) + rollout_times.append(rollout_time) + + unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) + + for source in params_sources: + source.update(unreplicated_params) + + # Pass to the evaluator + episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics) + + rollout_times = jax.tree.map(lambda *x: np.mean(x), *rollout_times) + times_dict = rollout_times | eval_times + times_dict = jax.tree.map(np.mean, times_dict, is_leaf=lambda x: isinstance(x, list)) + + learner_queue.put((episode_metrics, train_metrics, learner_state, times_dict)) + + def run_experiment(_config: DictConfig) -> float: """Runs experiment.""" config = copy.deepcopy(_config) @@ -597,10 +595,10 @@ def run_experiment(_config: DictConfig) -> float: ) # Executor setup and launch. - unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) - params_sources: Sequence[ParamsSource] = [] - thread_lifetimes: Sequence[ThreadLifetime] = [] - pipeline = Pipeline(128, learner_devices) # TODO: ADD THE MAX PIPILINE QUEUE SIZE TO THE CONFIG + unreplicated_inital_params = flax.jax_utils.unreplicate(learner_state.params) + params_sources: List[ParamsSource] = [] + thread_lifetimes: List[ThreadLifetime] = [] + pipeline = Pipeline(config.arh.Pilpeline_queue_size, learner_devices) pipeline.start() # Create the actor threads @@ -609,7 +607,7 @@ def run_experiment(_config: DictConfig) -> float: for thread_id in range(config.arch.n_threads_per_executor): seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() - params_source = ParamsSource(unreplicated_params, devices[d_id]) + params_source = ParamsSource(unreplicated_inital_params, devices[d_id]) params_source.start() params_sources.append(params_source) @@ -631,45 +629,67 @@ def run_experiment(_config: DictConfig) -> float: name=f"Actor-{thread_id + d_idx * config.arch.n_threads_per_executor}", ).start() - lifetime = ThreadLifetime() - evaluator_queue = Queue() # maxsize=1) + learner_queue: Queue = Queue() threading.Thread( - target=evaluate, - name="Evaluator", - args=(logger, evaluator_queue, evaluator, lifetime, steps_per_rollout, key), + target=learner, + name="Learner", + args=(learn, learner_state, config, learner_queue, pipeline, params_sources), ).start() - thread_lifetimes.append(lifetime) - - for eval_step in range( - config.arch.num_evaluation - ): # todo : replace :) if comment 3 is the way then this can be replaced with num_evaluation and the try catch in naother loop called num_updates per eval? - # should we have a loop over num actors? how much should we get? - # rn it trains over the output of a single actor - # we can leave it this way and think of other actor threads / devices as just a speed boost? I.e you should get ur desired batch sized base only on the num_envs * rollour_len ? 
- metrics: Sequence[Tuple[Dict, Dict]] = [] - _update = 0 - while _update != config.system.num_updates_per_eval: - try: - traj_batch, last_obs, last_dones = pipeline.get(block=True, timeout=1) - except queue.Empty: - continue - else: - learner_state, episode_metrics, train_metrics = learn( - learner_state, traj_batch, last_obs, last_dones - ) - metrics.append((episode_metrics, train_metrics)) - unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) - for source in params_sources: - source.update(unreplicated_params) - _update += 1 + max_episode_return = -jnp.inf + best_params = unreplicated_inital_params.actor_params + + for eval_step in range(config.arch.num_evaluation): + # Get the next set of params and metrics from the evaluator + episode_metrics, train_metrics, learner_state, times_dict = learner_queue.get() - # Run the evaluator - evaluator_queue.put((metrics, unreplicated_params)) + t = int(steps_per_rollout * (eval_step + 1)) + + times_dict["timestep"] = t + logger.log(times_dict, t, eval_step, LogEvent.MISC) + + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / times_dict["single_rollout"] + if ep_completed: + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + + unreplicated_actor_params = flax.jax_utils.unreplicate(learner_state.params.actor_params) + key, eval_key = jax.random.split(key, 2) + eval_metrics = evaluator(unreplicated_actor_params, eval_key, {}) + logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) + + episode_return = jnp.mean(eval_metrics["episode_return"]) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=learner_state, + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(unreplicated_actor_params) + max_episode_return = episode_return for thread_lifetime in thread_lifetimes: thread_lifetime.stop() + eval_performance = float(jnp.mean(eval_metrics[config.env.eval_metric])) + + # Measure absolute metric. + if config.arch.absolute_metric: + abs_metric_evaluator = get_eval_fn( + environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True + ) + key, eval_key = jax.random.split(key, 2) + eval_metrics = abs_metric_evaluator(best_params, eval_key, {}) + + t = int(steps_per_rollout * (eval_step + 1)) + logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) + # Stop the logger. logger.stop() diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py index 073f735c5..a5c0bdc14 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba_utils.py @@ -16,7 +16,7 @@ import queue import threading import time -from typing import Any, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -65,6 +65,7 @@ def put( traj: Sequence[PPOTransition], next_obs: Union[Observation, ObservationGlobalState], next_dones: Array, + time_dict: Dict, ) -> None: """ Put a trajectory on the queue to be consumed by the learner. @@ -77,13 +78,15 @@ def put( # [PPOTransition()] * rollout_len --> PPOTransition[done=(rollout_len, num_envs, num_agents) sharded_traj = jax.tree.map(lambda *x: self.shard_split_playload(jnp.stack(x), 1), *traj) - # obs Tuple[(num_envs, num_agents, ...), ...] 
--> [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices + # obs Tuple[(num_envs, num_agents, ...), ...] --> + # [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices sharded_next_obs = jax.tree.map(self.shard_split_playload, next_obs) - # dones (num_envs, num_agents) --> [(num_envs / num_learner_devices, num_agents)] * num_learner_devices + # dones (num_envs, num_agents) --> + # [(num_envs / num_learner_devices, num_agents)] * num_learner_devices sharded_next_dones = self.shard_split_playload(next_dones, 0) - self._queue.put((sharded_traj, sharded_next_obs, sharded_next_dones)) + self._queue.put((sharded_traj, sharded_next_obs, sharded_next_dones, time_dict)) with end_condition: end_condition.notify() # tell we have finish @@ -94,11 +97,11 @@ def qsize(self) -> int: def get( self, block: bool = True, timeout: Union[float, None] = None - ) -> Tuple[PPOTransition, Union[Observation, ObservationGlobalState], Array]: + ) -> Tuple[PPOTransition, Union[Observation, ObservationGlobalState], Array, Dict]: """Get a trajectory from the pipeline.""" return self._queue.get(block, timeout) # type: ignore - def shard_split_playload(self, payload: Any, axis: int = 0): + def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) return jax.device_put_sharded(split_payload, devices=self.learner_devices) @@ -111,7 +114,7 @@ class ParamsSource(threading.Thread): def __init__(self, init_value: Params, device: jax.Device): super().__init__(name=f"ParamsSource-{device.id}") - self.value = jax.device_put(init_value, device) + self.value: Params = jax.device_put(init_value, device) self.device = device self.new_value: queue.Queue = queue.Queue() @@ -156,11 +159,11 @@ def __exit__(self, *args: Any) -> None: class ThreadLifetime: """Simple class for a mutable boolean that can be used to signal a thread to stop.""" - def __init__(self): + def __init__(self) -> None: self._stop = False - def should_stop(self): + def should_stop(self) -> bool: return self._stop - def stop(self): + def stop(self) -> None: self._stop = True diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 5bfb24e8c..35bd674bd 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -198,7 +198,9 @@ def modify_space(self, space: spaces.Space) -> spaces.Space: class GymToJumanji(gymnasium.Wrapper): """Converts Gym outputs to Jumanji timesteps""" - def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> TimeStep: + def reset( + self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None + ) -> TimeStep: obs, info = self.env.reset(seed=seed, options=options) num_agents = len(self.env.single_action_space) From 18ec08f843460ca200f20d5cb40694bf87aac50b Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 29 Jul 2024 11:33:47 +0100 Subject: [PATCH 084/125] feat: shared time steps checker --- mava/configs/arch/sebulba.yaml | 2 +- mava/systems/ppo/anakin/ff_ippo.py | 4 +-- mava/systems/ppo/anakin/ff_mappo.py | 4 +-- mava/systems/ppo/anakin/rec_ippo.py | 4 +-- mava/systems/ppo/anakin/rec_mappo.py | 4 +-- mava/systems/ppo/sebulba/ff_ippo.py | 11 +++--- mava/systems/q_learning/anakin/rec_iql.py | 4 +-- mava/systems/sac/anakin/ff_isac.py | 4 +-- mava/systems/sac/anakin/ff_masac.py | 4 +-- mava/utils/total_timestep_checker.py | 44 ++++++----------------- 10 files changed, 29 insertions(+), 56 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml 
index e38691780..e9865460a 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -17,7 +17,7 @@ absolute_metric: True # Whether the absolute metric should be computed. For more n_threads_per_executor: 2 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices -Pilpeline_queue_size : 2 +pilpeline_queue_size : 5 # The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training. # Too large of a value with too many actors will lead to all of the updates getting wasted in old episodes # Too small of a value and the utility of having multiple actors is lost. diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index d0fb9c30f..49c969cdb 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -41,7 +41,7 @@ unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -475,7 +475,7 @@ def run_experiment(_config: DictConfig) -> float: evaluator = get_eval_fn(eval_env, eval_act_fn, config, absolute_metric=False) # Calculate total timesteps. - config = anakin_check_total_timesteps(config) + config = check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 20ae3272e..cafa42888 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -36,7 +36,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -459,7 +459,7 @@ def run_experiment(_config: DictConfig) -> float: evaluator = get_eval_fn(eval_env, eval_act_fn, config, absolute_metric=False) # Calculate total timesteps. - config = anakin_check_total_timesteps(config) + config = check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." 
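The pilpeline_queue_size setting added to sebulba.yaml above is essentially the maxsize of a bounded producer/consumer queue: once that many rollouts are waiting, actor threads block until the learner drains one, which caps how stale the consumed data can be. A minimal standalone sketch of that backpressure behaviour, using a plain queue.Queue rather than Mava's Pipeline class (all names and numbers are illustrative):

import queue
import threading
import time

rollouts: queue.Queue = queue.Queue(maxsize=5)  # mirrors pilpeline_queue_size: 5

def actor_thread() -> None:
    for step in range(20):
        # Blocks once 5 rollouts are waiting, so the actor can never run far
        # ahead of the learner and consumed data is at most ~5 updates stale.
        rollouts.put(f"rollout-{step}")

def learner_loop() -> None:
    for _ in range(20):
        rollouts.get()    # consume the oldest rollout
        time.sleep(0.01)  # stand-in for a training step

threading.Thread(target=actor_thread).start()
learner_loop()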
diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index a073d6dcb..230756295 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -50,7 +50,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -619,7 +619,7 @@ def run_experiment(_config: DictConfig) -> float: evaluator = get_eval_fn(eval_env, eval_act_fn, config, absolute_metric=False) # Calculate total timesteps. - config = anakin_check_total_timesteps(config) + config = check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index 3e741f5c1..53ae7c65d 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -50,7 +50,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -615,7 +615,7 @@ def run_experiment(_config: DictConfig) -> float: evaluator = get_eval_fn(eval_env, eval_act_fn, config, absolute_metric=False) # Calculate total timesteps. - config = anakin_check_total_timesteps(config) + config = check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 3f07adda8..b9f83f20b 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -48,7 +48,7 @@ from mava.utils.jax_utils import merge_leading_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.sebulba_utils import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime -from mava.utils.total_timestep_checker import sebulba_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -95,7 +95,7 @@ def get_action_and_value( move_to_device = lambda x: jax.device_put(x, device=current_actor_device) - # Loop till the learner has finished training + # Loop till the desired num_updates is reached. while not thread_lifetime.should_stop(): # Rollout traj: List = [] @@ -568,7 +568,7 @@ def run_experiment(_config: DictConfig) -> float: ) # Calculate total timesteps. - config = sebulba_check_total_timesteps(config) + config = check_total_timesteps(config) assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." 
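As a rough worked example of the integer arithmetic behind the unified check_total_timesteps call above (the function body itself appears in the total_timestep_checker.py diff later in this patch), using made-up values rather than repo defaults:

# Sebulba branch of check_total_timesteps: n_devices = 1 and
# update_batch_size = 1, since only one learner device's output is used
# per update. All numbers below are illustrative assumptions.
total_timesteps = 1_000_000
rollout_length = 128
num_envs = 32
num_evaluation = 10

num_updates = total_timesteps // rollout_length // num_envs  # 244
env_steps_per_update = rollout_length * num_envs             # 4096
# The assert above then simply guarantees num_updates > num_evaluation,
# i.e. at least one update happens between consecutive evaluations.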
@@ -598,7 +598,7 @@ def run_experiment(_config: DictConfig) -> float: unreplicated_inital_params = flax.jax_utils.unreplicate(learner_state.params) params_sources: List[ParamsSource] = [] thread_lifetimes: List[ThreadLifetime] = [] - pipeline = Pipeline(config.arh.Pilpeline_queue_size, learner_devices) + pipeline = Pipeline(config.arch.pilpeline_queue_size, learner_devices) pipeline.start() # Create the actor threads @@ -712,6 +712,3 @@ def hydra_entry_point(cfg: DictConfig) -> float: if __name__ == "__main__": hydra_entry_point() - -# learner_output.episode_metrics.keys() -# dict_keys(['episode_length', 'episode_return']) diff --git a/mava/systems/q_learning/anakin/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py index a8fa7964b..05b860d85 100644 --- a/mava/systems/q_learning/anakin/rec_iql.py +++ b/mava/systems/q_learning/anakin/rec_iql.py @@ -54,7 +54,7 @@ unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics @@ -533,7 +533,7 @@ def update_step( def run_experiment(cfg: DictConfig) -> float: # Add runtime variables to config cfg.arch.n_devices = len(jax.devices()) - cfg = anakin_check_total_timesteps(cfg) + cfg = check_total_timesteps(cfg) # Number of env steps before evaluating/logging. steps_per_rollout = int(cfg.system.total_timesteps // cfg.arch.num_evaluation) diff --git a/mava/systems/sac/anakin/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py index d0f243b3f..955725e00 100644 --- a/mava/systems/sac/anakin/ff_isac.py +++ b/mava/systems/sac/anakin/ff_isac.py @@ -51,7 +51,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics @@ -488,7 +488,7 @@ def update_step(carry: LearnerState, _: Any) -> Tuple[LearnerState, Tuple[Metric def run_experiment(cfg: DictConfig) -> float: # Add runtime variables to config cfg.arch.n_devices = len(jax.devices()) - cfg = anakin_check_total_timesteps(cfg) + cfg = check_total_timesteps(cfg) # Number of env steps before evaluating/logging. steps_per_rollout = int(cfg.system.total_timesteps // cfg.arch.num_evaluation) diff --git a/mava/systems/sac/anakin/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py index bf45f4b83..2df296be4 100644 --- a/mava/systems/sac/anakin/ff_masac.py +++ b/mava/systems/sac/anakin/ff_masac.py @@ -52,7 +52,7 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import anakin_check_total_timesteps +from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics @@ -506,7 +506,7 @@ def update_step(carry: LearnerState, _: Any) -> Tuple[LearnerState, Tuple[Metric def run_experiment(cfg: DictConfig) -> float: # Add runtime variables to config cfg.arch.n_devices = len(jax.devices()) - cfg = anakin_check_total_timesteps(cfg) + cfg = check_total_timesteps(cfg) # Number of env steps before evaluating/logging. 
steps_per_rollout = int(cfg.system.total_timesteps // cfg.arch.num_evaluation) diff --git a/mava/utils/total_timestep_checker.py b/mava/utils/total_timestep_checker.py index 744451d1b..e48e40923 100644 --- a/mava/utils/total_timestep_checker.py +++ b/mava/utils/total_timestep_checker.py @@ -18,47 +18,23 @@ from omegaconf import DictConfig -def anakin_check_total_timesteps(config: DictConfig) -> DictConfig: +def check_total_timesteps(config: DictConfig) -> DictConfig: """Check if total_timesteps is set, if not, set it based on the other parameters""" - n_devices = len(jax.devices()) - if config.system.total_timesteps is None: - config.system.num_updates = int(config.system.num_updates) - config.system.total_timesteps = int( - n_devices - * config.system.num_updates - * config.system.rollout_length - * config.system.update_batch_size - * config.arch.num_envs - ) + if config.arch.architecture_name == "anakin": + n_devices = len(jax.devices()) + update_batch_size = config.system.update_batch_size else: - config.system.total_timesteps = int(config.system.total_timesteps) - config.system.num_updates = int( - config.system.total_timesteps - // config.system.rollout_length - // config.system.update_batch_size - // config.arch.num_envs - // n_devices - ) - print( - f"{Fore.RED}{Style.BRIGHT} Changing the number of updates " - + f"to {config.system.num_updates}: If you want to train" - + " for a specific number of updates, please set total_timesteps to None!" - + f"{Style.RESET_ALL}" - ) - return config - - -def sebulba_check_total_timesteps(config: DictConfig) -> DictConfig: - """Check if total_timesteps is set, if not, set it based on the other parameters""" + n_devices = 1 # We only use a single device's output when updating. + update_batch_size = 1 if config.system.total_timesteps is None: config.system.num_updates = int(config.system.num_updates) config.system.total_timesteps = int( - len(config.arch.executor_device_ids) - * config.arch.n_threads_per_executor + n_devices * config.system.num_updates * config.system.rollout_length + * update_batch_size * config.arch.num_envs ) else: @@ -66,9 +42,9 @@ def sebulba_check_total_timesteps(config: DictConfig) -> DictConfig: config.system.num_updates = int( config.system.total_timesteps // config.system.rollout_length + // update_batch_size // config.arch.num_envs - // config.arch.n_threads_per_executor - // len(config.arch.executor_device_ids) + // n_devices ) print( f"{Fore.RED}{Style.BRIGHT} Changing the number of updates " From 38e72291073fa9abeeffa719d48b661f861f18c4 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 29 Jul 2024 11:49:58 +0100 Subject: [PATCH 085/125] chore: removed unused eval type --- mava/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mava/types.py b/mava/types.py index 1c9f64640..1d5878c5a 100644 --- a/mava/types.py +++ b/mava/types.py @@ -157,7 +157,6 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): [MavaState, MavaTransition, chex.Array, chex.Array], ExperimentOutput[MavaState] ] EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[MavaState]] -SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict] ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] From 5a5e542c6b135bcc86d2d40c06ac6905e5f7b435 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 29 Jul 2024 11:53:22 +0100 Subject: [PATCH 086/125] chore: config file changes --- mava/configs/arch/sebulba.yaml | 3 ++- .../{default_ff_ippo_seb.yaml => 
default_ff_ippo_sebulba.yaml} | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) rename mava/configs/{default_ff_ippo_seb.yaml => default_ff_ippo_sebulba.yaml} (84%) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index e9865460a..5934bb3d5 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -10,8 +10,9 @@ evaluation_greedy: False # Evaluate the policy greedily. If True the policy will # from the logits. num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. +num_absolute_metric_eval_episodes: 320 # Number of episodes to evaluate the absolute metric (the final evaluation). absolute_metric: True # Whether the absolute metric should be computed. For more details - # on the absolute metric please see: https://arxiv.org/abs/2209.10485 +# on the absolute metric please see: https://arxiv.org/abs/2209.10485 # --- Sebulba devices config --- n_threads_per_executor: 2 # num of different threads/env batches per actor diff --git a/mava/configs/default_ff_ippo_seb.yaml b/mava/configs/default_ff_ippo_sebulba.yaml similarity index 84% rename from mava/configs/default_ff_ippo_seb.yaml rename to mava/configs/default_ff_ippo_sebulba.yaml index 204719232..3a7386969 100644 --- a/mava/configs/default_ff_ippo_seb.yaml +++ b/mava/configs/default_ff_ippo_sebulba.yaml @@ -3,5 +3,5 @@ defaults: - arch: sebulba - system: ppo/ff_ippo - network: mlp - - env: rware_gym + - env: lbf_gym - _self_ diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index b9f83f20b..946d92315 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -697,7 +697,7 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../../configs", config_name="default_ff_ippo_seb.yaml", version_base="1.2" + config_path="../../../configs", config_name="default_ff_ippo_sebulba.yaml", version_base="1.2" ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" From dcff2a1c2f4a60272a13404a854ddb563b0b460c Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 29 Jul 2024 15:42:31 +0100 Subject: [PATCH 087/125] fix: fixed stalling at the end of training --- mava/configs/arch/sebulba.yaml | 8 ++--- mava/evaluator.py | 4 +-- mava/systems/ppo/sebulba/ff_ippo.py | 48 +++++++++++++++++----------- mava/types.py | 2 -- mava/utils/sebulba_utils.py | 49 ++++++++++++++++------------- 5 files changed, 63 insertions(+), 48 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 5934bb3d5..342e0ee29 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -9,8 +9,8 @@ evaluation_greedy: False # Evaluate the policy greedily. If True the policy will # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. -num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. -num_absolute_metric_eval_episodes: 320 # Number of episodes to evaluate the absolute metric (the final evaluation). +num_evaluation: 10 # Number of evenly spaced evaluations to perform during training. +num_absolute_metric_eval_episodes: 32 # Number of episodes to evaluate the absolute metric (the final evaluation). 
absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 @@ -18,8 +18,8 @@ absolute_metric: True # Whether the absolute metric should be computed. For more n_threads_per_executor: 2 # num of different threads/env batches per actor executor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices -pilpeline_queue_size : 5 +rollout_queue_size : 5 # The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training. # Too large of a value with too many actors will lead to all of the updates getting wasted in old episodes # Too small of a value and the utility of having multiple actors is lost. -# A value of 1 leads to almost strictly on-policy training. +# A value of 1 with a single actor leads to almost strictly on-policy training. diff --git a/mava/evaluator.py b/mava/evaluator.py index e754899ae..83e8841c3 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -215,7 +215,7 @@ def get_sebulba_eval_fn( config: DictConfig, np_rng: np.random.Generator, absolute_metric: bool, -) -> EvalFn: +) -> Tuple[EvalFn, Any]: """Creates a function that can be used to evaluate agents on a given environment. Args: @@ -314,4 +314,4 @@ def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) metrics["steps_per_second"] = total_timesteps / (end_time - start_time) return metrics - return timed_eval_fn + return timed_eval_fn, env diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 946d92315..2fd098a5d 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -145,6 +145,7 @@ def get_action_and_value( ) rollout_pipeline.put(traj, timestep.observation, next_dones, time_dict) + env.close() def get_learner_fn( @@ -408,7 +409,7 @@ def learner_setup( # create temporory envoirnments. env = environments.make_gym_env(config, config.arch.num_envs) # Get number of agents and actions. - action_space = env.single_action_space + action_space = env.unwrapped.single_action_space config.system.num_agents = len(action_space) config.system.num_actions = int(action_space[0].n) @@ -438,7 +439,7 @@ def learner_setup( ) # Initialise observation: Select only obs for a single agent. - init_obs = jnp.array([env.single_observation_space.sample()]) + init_obs = jnp.array([env.unwrapped.single_observation_space.sample()]) init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) init_x = Observation(init_obs, init_action_mask) @@ -563,7 +564,7 @@ def run_experiment(_config: DictConfig) -> float: # Setup evaluator. # One key per device for evaluation. eval_act_fn = make_ff_eval_act_fn(apply_fns[0], config) - evaluator = get_eval_fn( + evaluator, evaluator_envs = get_eval_fn( environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False ) @@ -596,25 +597,29 @@ def run_experiment(_config: DictConfig) -> float: # Executor setup and launch. 
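    # Thread topology at this point: every executor device gets
    # n_threads_per_executor actor threads, each with its own ParamsSource;
    # actors push sharded rollouts into the bounded Pipeline, a single learner
    # thread drains it, and the main thread below only evaluates, logs and
    # checkpoints.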
unreplicated_inital_params = flax.jax_utils.unreplicate(learner_state.params) - params_sources: List[ParamsSource] = [] - thread_lifetimes: List[ThreadLifetime] = [] - pipeline = Pipeline(config.arch.pilpeline_queue_size, learner_devices) + + pipeline_lifetime = ThreadLifetime() + pipeline = Pipeline(config.arch.rollout_queue_size, learner_devices, pipeline_lifetime) pipeline.start() + params_sources: List[ParamsSource] = [] + actor_threads: List[threading.Thread] = [] + actors_lifetime = ThreadLifetime() + params_sources_lifetime = ThreadLifetime() + # Create the actor threads for d_idx, d_id in enumerate(config.arch.executor_device_ids): # Loop through each executor thread for thread_id in range(config.arch.n_threads_per_executor): seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() - params_source = ParamsSource(unreplicated_inital_params, devices[d_id]) + params_source = ParamsSource( + unreplicated_inital_params, devices[d_id], params_sources_lifetime + ) params_source.start() params_sources.append(params_source) - lifetime = ThreadLifetime() - thread_lifetimes.append(lifetime) - - threading.Thread( + actor = threading.Thread( target=rollout, args=( jax.device_put(key, devices[d_id]), @@ -624,10 +629,12 @@ def run_experiment(_config: DictConfig) -> float: apply_fns, d_id, seeds, - lifetime, + actors_lifetime, ), name=f"Actor-{thread_id + d_idx * config.arch.n_threads_per_executor}", - ).start() + ) + actor.start() + actor_threads.append(actor) learner_queue: Queue = Queue() threading.Thread( @@ -674,14 +681,19 @@ def run_experiment(_config: DictConfig) -> float: best_params = copy.deepcopy(unreplicated_actor_params) max_episode_return = episode_return - for thread_lifetime in thread_lifetimes: - thread_lifetime.stop() - + evaluator_envs.close() eval_performance = float(jnp.mean(eval_metrics[config.env.eval_metric])) + # Make sure all of the actors are done befor closing the pipeline + actors_lifetime.stop() + for actor in actor_threads: + actor.join() + pipeline_lifetime.stop() + params_sources_lifetime.stop() + # Measure absolute metric. if config.arch.absolute_metric: - abs_metric_evaluator = get_eval_fn( + abs_metric_evaluator, abs_metric_evaluator_envs = get_eval_fn( environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True ) key, eval_key = jax.random.split(key, 2) @@ -689,7 +701,7 @@ def run_experiment(_config: DictConfig) -> float: t = int(steps_per_rollout * (eval_step + 1)) logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) - + abs_metric_evaluator_envs.close() # Stop the logger. 
logger.stop() diff --git a/mava/types.py b/mava/types.py index 1d5878c5a..fe51ce293 100644 --- a/mava/types.py +++ b/mava/types.py @@ -156,8 +156,6 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): SebulbaLearnerFn = Callable[ [MavaState, MavaTransition, chex.Array, chex.Array], ExperimentOutput[MavaState] ] -EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[MavaState]] - ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py index a5c0bdc14..e1fd34f37 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba_utils.py @@ -27,6 +27,19 @@ # Copied from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py +class ThreadLifetime: + """Simple class for a mutable boolean that can be used to signal a thread to stop.""" + + def __init__(self) -> None: + self._stop = False + + def should_stop(self) -> bool: + return self._stop + + def stop(self) -> None: + self._stop = True + + class Pipeline(threading.Thread): """ The `Pipeline` shards trajectories into `learner_devices`, @@ -34,7 +47,7 @@ class Pipeline(threading.Thread): and limit the max number of samples in device memory at one time to avoid OOM issues. """ - def __init__(self, max_size: int, learner_devices: List[jax.Device]): + def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: ThreadLifetime): """ Initializes the pipeline with a maximum size and the devices to shard trajectories across. @@ -46,6 +59,7 @@ def __init__(self, max_size: int, learner_devices: List[jax.Device]): self.learner_devices = learner_devices self.tickets_queue: queue.Queue = queue.Queue() self._queue: queue.Queue = queue.Queue(maxsize=max_size) + self.lifetime = lifetime def run(self) -> None: """ @@ -53,12 +67,15 @@ def run(self) -> None: start_condition and end_condition are used to ensure that only 1 thread is processing an item from the queue at one time, ensuring predictable memory usage. """ - while True: # todo Thread lifetime - start_condition, end_condition = self.tickets_queue.get() - with end_condition: - with start_condition: - start_condition.notify() - end_condition.wait() + while not self.lifetime.should_stop(): + try: + start_condition, end_condition = self.tickets_queue.get(timeout=1) + with end_condition: + with start_condition: + start_condition.notify() + end_condition.wait() + except queue.Empty: + continue def put( self, @@ -112,18 +129,19 @@ class ParamsSource(threading.Thread): `Learner` component to `Actor` components. """ - def __init__(self, init_value: Params, device: jax.Device): + def __init__(self, init_value: Params, device: jax.Device, lifetime: ThreadLifetime): super().__init__(name=f"ParamsSource-{device.id}") self.value: Params = jax.device_put(init_value, device) self.device = device self.new_value: queue.Queue = queue.Queue() + self.lifetime = lifetime def run(self) -> None: """ This function is responsible for updating the value of the `ParamSource` when a new value is available. 
""" - while True: + while not self.lifetime.should_stop(): try: waiting = self.new_value.get(block=True, timeout=1) self.value = jax.device_put(jax.block_until_ready(waiting), self.device) @@ -154,16 +172,3 @@ def __enter__(self) -> None: def __exit__(self, *args: Any) -> None: end = time.monotonic() self.to.append(end - self.start) - - -class ThreadLifetime: - """Simple class for a mutable boolean that can be used to signal a thread to stop.""" - - def __init__(self) -> None: - self._stop = False - - def should_stop(self) -> bool: - return self._stop - - def stop(self) -> None: - self._stop = True From d926c54f4b043f19e616cf28cfa5d4e1e09456c5 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 29 Jul 2024 16:51:55 +0100 Subject: [PATCH 088/125] chore: code cleanup --- mava/configs/arch/sebulba.yaml | 6 +-- mava/configs/system/ppo/ff_ippo.yaml | 2 +- mava/evaluator.py | 6 +-- mava/systems/ppo/sebulba/ff_ippo.py | 76 +++++++++++++++++----------- mava/wrappers/gym.py | 7 ++- 5 files changed, 57 insertions(+), 40 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 342e0ee29..65be6e68a 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -2,15 +2,15 @@ architecture_name: sebulba # --- Training --- -num_envs: 32 # number of environments per thread. +num_envs: 2 # number of environments per thread. # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. -num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. +num_eval_episodes: 2 # Number of episodes to evaluate per evaluation. num_evaluation: 10 # Number of evenly spaced evaluations to perform during training. -num_absolute_metric_eval_episodes: 32 # Number of episodes to evaluate the absolute metric (the final evaluation). +num_absolute_metric_eval_episodes: 2 # Number of episodes to evaluate the absolute metric (the final evaluation). absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 diff --git a/mava/configs/system/ppo/ff_ippo.yaml b/mava/configs/system/ppo/ff_ippo.yaml index 9efb0611a..622d94ca2 100644 --- a/mava/configs/system/ppo/ff_ippo.yaml +++ b/mava/configs/system/ppo/ff_ippo.yaml @@ -2,7 +2,7 @@ total_timesteps: ~ # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. 
-num_updates: 1000 # Number of updates +num_updates: 200 # Number of updates seed: 42 # --- Agent observations --- diff --git a/mava/evaluator.py b/mava/evaluator.py index 83e8841c3..b16f43c75 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -284,8 +284,8 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: # find the first instance of done to get the metrics at that timestep, we don't # care about subsequent steps because we only the results from the first episode - done_idx = jnp.argmax(timesteps.last(), axis=0) - metrics = jax.tree_map(lambda m: m[done_idx, jnp.arange(n_parallel_envs)], metrics) + done_idx = np.argmax(timesteps.last(), axis=0) + metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) del metrics["is_terminal_step"] # uneeded for logging return key, metrics @@ -299,7 +299,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: metrics.append(metric) metrics: Metrics = jax.tree_map( - lambda *x: jnp.array(x).reshape(-1), *metrics + lambda *x: np.array(x).reshape(-1), *metrics ) # flatten metrics return metrics diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 2fd098a5d..00c699512 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -56,13 +56,28 @@ def rollout( key: chex.PRNGKey, config: DictConfig, - rollout_pipeline: Pipeline, + rollout_queue: Pipeline, params_source: ParamsSource, - apply_fns: Tuple, + apply_fns: Tuple[ActorApply, CriticApply], actor_device_id: int, seeds: List[int], thread_lifetime: ThreadLifetime, ) -> None: + """Runs rollouts to collect trajectories from the environment. + + Args: + key (chex.PRNGKey): The PRNGkey. + config (DictConfig): Configuration settings for the environment and rollout. + rollout_queue (Pipeline): Queue for sending collected rollouts. + params_source (ParamsSource): Source for fetching the latest network parameters. + apply_fns (Tuple): Functions for running the actor and critic networks. + actor_device_id (int): Actor device id for the current thread. + seeds (List[int]): Seeds for initializing the environment. + thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. + + Returns: + None: This function updates the rollout queue with collected data. 
+ """ # setup env = environments.make_gym_env(config, config.arch.num_envs) current_actor_device = jax.devices()[actor_device_id] @@ -88,10 +103,7 @@ def get_action_and_value( timestep = env.reset(seed=seeds) - next_dones = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, num_agents).reshape(num_envs, -1), - timestep.last(), - ) + next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) move_to_device = lambda x: jax.device_put(x, device=current_actor_device) @@ -99,13 +111,20 @@ def get_action_and_value( while not thread_lifetime.should_stop(): # Rollout traj: List = [] - time_dict: Dict[str, List[float]] = {"single_rollout": [], "env_step_time": []} + time_dict: Dict[str, List[float]] = { + "single_rollout_time": [], + "env_step_time": [], + "getting_params_time": [], + "putting_rollout_time": [], + } # Loop over the rollout length - with RecordTimeTo(time_dict["single_rollout"]): + with RecordTimeTo(time_dict["single_rollout_time"]): for _ in range(config.system.rollout_length): # Get the latest parameters from the learner - params = params_source.get() + + with RecordTimeTo(time_dict["getting_params_time"]): + params = params_source.get() cached_next_obs = jax.tree.map(move_to_device, timestep.observation) cached_next_dones = move_to_device(next_dones) @@ -126,10 +145,7 @@ def get_action_and_value( cpu_action.swapaxes(0, 1) ) # (num_env, num_agents) --> (num_agents, num_env) - next_dones = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, num_agents).reshape(num_envs, -1), - timestep.last(), - ) + next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) # Append data to storage traj.append( @@ -143,8 +159,9 @@ def get_action_and_value( info=timestep.extras, ) ) - - rollout_pipeline.put(traj, timestep.observation, next_dones, time_dict) + # send trajectories to learner + with RecordTimeTo(time_dict["putting_rollout_time"]): + rollout_queue.put(traj, timestep.observation, next_dones, time_dict) env.close() @@ -167,10 +184,9 @@ def _update_step( ) -> Tuple[LearnerState, Tuple]: """A single update of the network. - This function steps the environment and records the trajectory batch for - training. It then calculates advantages and targets based on the recorded - trajectory and updates the actor and critic networks based on the calculated - losses. + This function calculates advantages and targets based on the trajectories + from the actor and updates the actor and critic networks based on the + calculated losses. Args: learner_state (NamedTuple): @@ -295,12 +311,12 @@ def _critic_loss_fn( # pmean over devices. actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), - axis_name="device", + axis_name="learner_devices", ) - # pmean over devices. + # pmean over learner devices. critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + (critic_grads, critic_loss_info), axis_name="learner_devices" ) # UPDATE ACTOR PARAMS AND OPTIMISER STATE @@ -460,7 +476,7 @@ def learner_setup( # Get batched iterated update and replicate it to pmap it over learner cores. learn = get_learner_fn(apply_fns, update_fns, config) - learn = jax.pmap(learn, axis_name="device", devices=learner_devices) + learn = jax.pmap(learn, axis_name="learner_devices", devices=learner_devices) # Load model from checkpoint if specified. 
if config.logger.checkpointing.load_model: @@ -523,10 +539,10 @@ def learner( episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics) rollout_times = jax.tree.map(lambda *x: np.mean(x), *rollout_times) - times_dict = rollout_times | eval_times - times_dict = jax.tree.map(np.mean, times_dict, is_leaf=lambda x: isinstance(x, list)) + timing_dict = rollout_times | eval_times + timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) - learner_queue.put((episode_metrics, train_metrics, learner_state, times_dict)) + learner_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) def run_experiment(_config: DictConfig) -> float: @@ -646,17 +662,19 @@ def run_experiment(_config: DictConfig) -> float: max_episode_return = -jnp.inf best_params = unreplicated_inital_params.actor_params + # This is the main loop, all it does is evaluation and logging. + # Acting and learning is happening in their own threads. + # This loop waits for the learner to finish an update before evaluation and logging. for eval_step in range(config.arch.num_evaluation): - # Get the next set of params and metrics from the evaluator + # Get the next set of params and metrics from the learner episode_metrics, train_metrics, learner_state, times_dict = learner_queue.get() t = int(steps_per_rollout * (eval_step + 1)) - times_dict["timestep"] = t logger.log(times_dict, t, eval_step, LogEvent.MISC) episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / times_dict["single_rollout"] + episode_metrics["steps_per_second"] = steps_per_rollout / times_dict["single_rollout_time"] if ep_completed: logger.log(episode_metrics, t, eval_step, LogEvent.ACT) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 35bd674bd..6dcbf9963 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -196,7 +196,7 @@ def modify_space(self, space: spaces.Space) -> spaces.Space: class GymToJumanji(gymnasium.Wrapper): - """Converts Gym outputs to Jumanji timesteps""" + """Converts from the Gym API to the dm_env API, Jumanji's Timestep type.""" def reset( self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None @@ -227,9 +227,8 @@ def _format_observation( ) -> Union[Observation, ObservationGlobalState]: """Create an observation from the raw observation and environment state.""" - obs = np.array(obs).swapaxes( - 0, 1 - ) # (num_agents, num_envs, ...) -> (num_envs, num_agents, ...) + # (num_agents, num_envs, ...) -> (num_envs, num_agents, ...) + obs = np.array(obs).swapaxes(0, 1) action_mask = np.stack(info["actions_mask"]) obs_data = {"agents_view": obs, "action_mask": action_mask} From 7e4698a1446bc55d1e5d7aa17ad294d8a2142865 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 29 Jul 2024 17:25:29 +0100 Subject: [PATCH 089/125] chore : various changes --- mava/configs/arch/sebulba.yaml | 6 +++--- mava/configs/system/ppo/ff_ippo.yaml | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 14 +++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 65be6e68a..0c1c8880d 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -2,15 +2,15 @@ architecture_name: sebulba # --- Training --- -num_envs: 2 # number of environments per thread. +num_envs: 32 # number of environments per thread. # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. 
If True the policy will select # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. -num_eval_episodes: 2 # Number of episodes to evaluate per evaluation. +num_eval_episodes: 200 # Number of episodes to evaluate per evaluation. num_evaluation: 10 # Number of evenly spaced evaluations to perform during training. -num_absolute_metric_eval_episodes: 2 # Number of episodes to evaluate the absolute metric (the final evaluation). +num_absolute_metric_eval_episodes: 32 # Number of episodes to evaluate the absolute metric (the final evaluation). absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 diff --git a/mava/configs/system/ppo/ff_ippo.yaml b/mava/configs/system/ppo/ff_ippo.yaml index 622d94ca2..9efb0611a 100644 --- a/mava/configs/system/ppo/ff_ippo.yaml +++ b/mava/configs/system/ppo/ff_ippo.yaml @@ -2,7 +2,7 @@ total_timesteps: ~ # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. -num_updates: 200 # Number of updates +num_updates: 1000 # Number of updates seed: 42 # --- Agent observations --- diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 00c699512..31c7e26af 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -27,6 +27,7 @@ import optax from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict +from jax import tree from omegaconf import DictConfig, OmegaConf from optax._src.base import OptState from rich.pretty import pprint @@ -126,7 +127,7 @@ def get_action_and_value( with RecordTimeTo(time_dict["getting_params_time"]): params = params_source.get() - cached_next_obs = jax.tree.map(move_to_device, timestep.observation) + cached_next_obs = tree.map(move_to_device, timestep.observation) cached_next_dones = move_to_device(next_dones) # Get action and value @@ -474,7 +475,6 @@ def learner_setup( apply_fns = (actor_network.apply, critic_network.apply) update_fns = (actor_optim.update, critic_optim.update) - # Get batched iterated update and replicate it to pmap it over learner cores. learn = get_learner_fn(apply_fns, update_fns, config) learn = jax.pmap(learn, axis_name="learner_devices", devices=learner_devices) @@ -536,11 +536,11 @@ def learner( source.update(unreplicated_params) # Pass to the evaluator - episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics) + episode_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) - rollout_times = jax.tree.map(lambda *x: np.mean(x), *rollout_times) + rollout_times = tree.map(lambda *x: np.mean(x), *rollout_times) timing_dict = rollout_times | eval_times - timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) + timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) learner_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) @@ -553,8 +553,8 @@ def run_experiment(_config: DictConfig) -> float: learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] # PRNG keys. 
- key, key_e, actor_net_key, critic_net_key = jax.random.split( - jax.random.PRNGKey(config.system.seed), num=4 + key, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.system.seed), num=3 ) # Sanity check of config From 6dac8c3206b806db26d598158b8de1aad571c755 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 30 Jul 2024 12:26:49 +0100 Subject: [PATCH 090/125] fix: prevent the pipeline from stalling and a lot of cleanup --- mava/systems/ppo/sebulba/ff_ippo.py | 88 +++++++++++++---------------- mava/utils/sebulba_utils.py | 25 ++++++++ mava/wrappers/gym.py | 2 +- 3 files changed, 65 insertions(+), 50 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 31c7e26af..04aeda480 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -48,7 +48,13 @@ from mava.utils.checkpointing import Checkpointer from mava.utils.jax_utils import merge_leading_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.sebulba_utils import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime +from mava.utils.sebulba_utils import ( + ParamsSource, + Pipeline, + RecordTimeTo, + ThreadLifetime, + check_config, +) from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -69,15 +75,13 @@ def rollout( Args: key (chex.PRNGKey): The PRNGkey. config (DictConfig): Configuration settings for the environment and rollout. - rollout_queue (Pipeline): Queue for sending collected rollouts. - params_source (ParamsSource): Source for fetching the latest network parameters. + rollout_queue (Pipeline): Queue for sending collected rollouts to the learner. + params_source (ParamsSource): Source for fetching the latest network parameters + from the learner. apply_fns (Tuple): Functions for running the actor and critic networks. - actor_device_id (int): Actor device id for the current thread. + actor_device_id (int): Device ID for this actor thread. seeds (List[int]): Seeds for initializing the environment. thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. - - Returns: - None: This function updates the rollout queue with collected data. 
""" # setup env = environments.make_gym_env(config, config.arch.num_envs) @@ -115,8 +119,8 @@ def get_action_and_value( time_dict: Dict[str, List[float]] = { "single_rollout_time": [], "env_step_time": [], - "getting_params_time": [], - "putting_rollout_time": [], + "get_params_time": [], + "put_rollout_time": [], } # Loop over the rollout length @@ -124,7 +128,7 @@ def get_action_and_value( for _ in range(config.system.rollout_length): # Get the latest parameters from the learner - with RecordTimeTo(time_dict["getting_params_time"]): + with RecordTimeTo(time_dict["get_params_time"]): params = params_source.get() cached_next_obs = tree.map(move_to_device, timestep.observation) @@ -142,9 +146,8 @@ def get_action_and_value( cpu_action = jax.device_get(action) with RecordTimeTo(time_dict["env_step_time"]): - timestep = env.step( - cpu_action.swapaxes(0, 1) - ) # (num_env, num_agents) --> (num_agents, num_env) + # (num_env, num_agents) --> (num_agents, num_env) + timestep = env.step(cpu_action.swapaxes(0, 1)) next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) @@ -161,7 +164,7 @@ def get_action_and_value( ) ) # send trajectories to learner - with RecordTimeTo(time_dict["putting_rollout_time"]): + with RecordTimeTo(time_dict["put_rollout_time"]): rollout_queue.put(traj, timestep.observation, next_dones, time_dict) env.close() @@ -190,12 +193,10 @@ def _update_step( calculated losses. Args: - learner_state (NamedTuple): - - params (Params): The current model parameters. - - opt_states (OptStates): The current optimizer states. - - key (PRNGKey): The random number generator state. - - env_state (State): The environment state. - - last_timestep (TimeStep): The last timestep in the current trajectory. + learner_state (LearnerState): contains all the items needed for learning. + traj_batch (PPOTransition): the batch of data to learn with. + last_obs (Observation): the final observations (for bootstrapping in GAE). + last_dones (Array): the final dones (for bootstrapping in GAE). _ (Any): The current metrics info. """ @@ -309,7 +310,7 @@ def _critic_loss_fn( # Compute the parallel mean (pmean) over the batch. # This calculation is inspired by the Anakin architecture demo notebook. # available at https://tinyurl.com/26tdzs5x - # pmean over devices. + # pmean over learner devices. 
actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="learner_devices", @@ -509,20 +510,20 @@ def learner( learn: SebulbaLearnerFn[LearnerState, PPOTransition], learner_state: LearnerState, config: DictConfig, - learner_queue: Queue, + eval_queue: Queue, pipeline: Pipeline, params_sources: Sequence[ParamsSource], ) -> None: for _eval_step in range(config.arch.num_evaluation): metrics: List[Tuple[Dict, Dict]] = [] rollout_times: List[Dict] = [] - eval_times: Dict[str, List[float]] = {"evaluator_blocked_time": [], "evaluation_time": []} + eval_times: Dict[str, List[float]] = {"rollout_get_time": [], "learning_time": []} for _update in range(config.system.num_updates_per_eval): - with RecordTimeTo(eval_times["evaluator_blocked_time"]): + with RecordTimeTo(eval_times["rollout_get_time"]): traj_batch, last_obs, last_dones, rollout_time = pipeline.get(block=True) - with RecordTimeTo(eval_times["evaluation_time"]): + with RecordTimeTo(eval_times["learning_time"]): learner_state, episode_metrics, train_metrics = learn( learner_state, traj_batch, last_obs, last_dones ) @@ -542,7 +543,7 @@ def learner( timing_dict = rollout_times | eval_times timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) - learner_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) + eval_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) def run_experiment(_config: DictConfig) -> float: @@ -557,26 +558,14 @@ def run_experiment(_config: DictConfig) -> float: jax.random.PRNGKey(config.system.seed), num=3 ) - # Sanity check of config - assert ( - config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "The number of environments must to be divisible by the number of learners." - - assert ( - int(config.arch.num_envs / len(config.arch.learner_device_ids)) - * config.arch.n_threads_per_executor - % config.system.num_minibatches - == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches." + # Numpy RNG. + np_rng = np.random.default_rng(config.system.seed) # Setup learner. learn, apply_fns, learner_state = learner_setup( (key, actor_net_key, critic_net_key), config, learner_devices ) - # Generate Numpy RNG for reproducibility - np_rng = np.random.default_rng(config.system.seed) - # Setup evaluator. # One key per device for evaluation. eval_act_fn = make_ff_eval_act_fn(apply_fns[0], config) @@ -586,11 +575,7 @@ def run_experiment(_config: DictConfig) -> float: # Calculate total timesteps. config = check_total_timesteps(config) - assert ( - config.system.num_updates > config.arch.num_evaluation - ), "Number of updates per evaluation must be less than total number of updates." - # Calculate number of updates per evaluation. 
- config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + check_config(config) steps_per_rollout = ( config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval @@ -652,11 +637,11 @@ def run_experiment(_config: DictConfig) -> float: actor.start() actor_threads.append(actor) - learner_queue: Queue = Queue() + eval_queue: Queue = Queue() threading.Thread( target=learner, name="Learner", - args=(learn, learner_state, config, learner_queue, pipeline, params_sources), + args=(learn, learner_state, config, eval_queue, pipeline, params_sources), ).start() max_episode_return = -jnp.inf @@ -667,7 +652,7 @@ def run_experiment(_config: DictConfig) -> float: # This loop waits for the learner to finish an update before evaluation and logging. for eval_step in range(config.arch.num_evaluation): # Get the next set of params and metrics from the learner - episode_metrics, train_metrics, learner_state, times_dict = learner_queue.get() + episode_metrics, train_metrics, learner_state, times_dict = eval_queue.get() t = int(steps_per_rollout * (eval_step + 1)) times_dict["timestep"] = t @@ -702,12 +687,17 @@ def run_experiment(_config: DictConfig) -> float: evaluator_envs.close() eval_performance = float(jnp.mean(eval_metrics[config.env.eval_metric])) - # Make sure all of the actors are done befor closing the pipeline + # Make sure all of the Threads are closed. actors_lifetime.stop() for actor in actor_threads: actor.join() + pipeline_lifetime.stop() + pipeline.join() + params_sources_lifetime.stop() + for params_source in params_sources: + params_source.join() # Measure absolute metric. if config.arch.absolute_metric: diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py index e1fd34f37..8e84b4267 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba_utils.py @@ -21,6 +21,7 @@ import jax import jax.numpy as jnp from chex import Array +from omegaconf import DictConfig from mava.systems.ppo.types import Params, PPOTransition # todo: remove the ppo dependencies from mava.types import Observation, ObservationGlobalState @@ -103,6 +104,12 @@ def put( # [(num_envs / num_learner_devices, num_agents)] * num_learner_devices sharded_next_dones = self.shard_split_playload(next_dones, 0) + # If the queue gets full at any point we prioritize taking new episodes. + # This also prevents the pipeline from stalling if the learner thread terminates + # before the actors finish putting the episodes in it. + if self._queue.full(): + self._queue.get() + self._queue.put((sharded_traj, sharded_next_obs, sharded_next_dones, time_dict)) with end_condition: @@ -172,3 +179,21 @@ def __enter__(self) -> None: def __exit__(self, *args: Any) -> None: end = time.monotonic() self.to.append(end - self.start) + + +def check_config(config: DictConfig) -> None: + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + + assert ( + config.arch.num_envs % len(config.arch.learner_device_ids) == 0 + ), "The number of environments must be divisible by the number of learners." + + assert ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.arch.n_threads_per_executor + % config.system.num_minibatches + == 0 + ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches." 
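As a rough illustration of the arithmetic these config checks enforce, assuming purely illustrative values (32 envs, 2 learner devices, a rollout length of 128 and 4 minibatches); the exact formula used by the check changes slightly in a later commit:

    # Illustrative values only; not Mava defaults.
    num_envs, learner_device_ids = 32, [0, 1]
    rollout_length, num_minibatches = 128, 4

    # Each learner device must receive an equal share of the environments.
    assert num_envs % len(learner_device_ids) == 0
    envs_per_learner = num_envs // len(learner_device_ids)    # 16

    # The transitions a learner sees per update must split evenly into minibatches.
    samples_per_learner = envs_per_learner * rollout_length   # 2048
    assert samples_per_learner % num_minibatches == 0         # 2048 / 4 = 512 per minibatch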
diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 6dcbf9963..e14389b24 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -196,7 +196,7 @@ def modify_space(self, space: spaces.Space) -> spaces.Space: class GymToJumanji(gymnasium.Wrapper): - """Converts from the Gym API to the dm_env API, Jumanji's Timestep type.""" + """Converts from the Gym API to the dm_env API, using Jumanji's Timestep type.""" def reset( self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None From 23b582c6359d995f18f41b1c590e1146efc14c49 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 30 Jul 2024 12:44:26 +0100 Subject: [PATCH 091/125] chore : better error messeages --- mava/utils/sebulba_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py index 8e84b4267..9077925a8 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba_utils.py @@ -187,13 +187,16 @@ def check_config(config: DictConfig) -> None: ), "Number of updates per evaluation must be less than total number of updates." config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation - assert ( - config.arch.num_envs % len(config.arch.learner_device_ids) == 0 - ), "The number of environments must be divisible by the number of learners." + assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, ( + "Number of environments must be divisible by the number of learner." + + "The output of each actor is equally split across the learners." + ) - assert ( + num_eval_samples = ( int(config.arch.num_envs / len(config.arch.learner_device_ids)) - * config.arch.n_threads_per_executor - % config.system.num_minibatches - == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches." + * config.system.rollout_length + ) + assert num_eval_samples % config.system.num_minibatches == 0, ( + f"Number of training samples per evaluator ({num_eval_samples})" + + f"must be divisible by num_minibatches ({config.system.num_minibatches})." 
+ ) From c71dad86a0fd3d6c13f9ce2bdc173c73d88939fd Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 30 Jul 2024 13:23:44 +0100 Subject: [PATCH 092/125] fix: changed the timestep discount --- mava/wrappers/gym.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index e14389b24..ee4339afd 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -208,8 +208,9 @@ def reset( ep_done = np.zeros(num_envs, dtype=float) rewards = np.zeros((num_envs, num_agents), dtype=float) + teminated = np.zeros((num_envs, num_agents), dtype=float) - timestep = self._create_timestep(obs, ep_done, rewards, info) + timestep = self._create_timestep(obs, ep_done, teminated, rewards, info) return timestep @@ -218,7 +219,7 @@ def step(self, action: list) -> TimeStep: ep_done = np.logical_or(terminated, truncated).all(axis=1) - timestep = self._create_timestep(obs, ep_done, rewards, info) + timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) return timestep @@ -240,16 +241,17 @@ def _format_observation( return Observation(**obs_data) def _create_timestep( - self, obs: NDArray, ep_done: NDArray, rewards: NDArray, info: Dict + self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict ) -> TimeStep: obs = self._format_observation(obs, info) extras = jax.tree.map(lambda *x: np.stack(x), *info["metrics"]) step_type = np.where(ep_done, StepType.LAST, StepType.MID) + terminated = np.all(terminated, axis=1) return TimeStep( step_type=step_type, reward=rewards, - discount=1.0 - ep_done, + discount=1.0 - terminated, observation=obs, extras=extras, ) From bfea3aab662646a0a1dd71aaf4d433fefe5c2116 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 30 Jul 2024 16:03:03 +0200 Subject: [PATCH 093/125] chore: very nitpicky clean ups --- mava/systems/ppo/sebulba/ff_ippo.py | 171 +++++++++++----------------- 1 file changed, 67 insertions(+), 104 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 04aeda480..38cb2905b 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -27,9 +27,9 @@ import optax from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_sebulba_eval_fn as get_eval_fn @@ -85,9 +85,10 @@ def rollout( """ # setup env = environments.make_gym_env(config, config.arch.num_envs) - current_actor_device = jax.devices()[actor_device_id] actor_apply_fn, critic_apply_fn = apply_fns num_agents, num_envs = config.system.num_agents, config.arch.num_envs + current_actor_device = jax.devices()[actor_device_id] + move_to_device = lambda x: jax.device_put(x, device=current_actor_device) # Define the util functions: select action function and prepare data to share it with learner. @jax.jit @@ -107,40 +108,31 @@ def get_action_and_value( return action, log_prob, value, key timestep = env.reset(seed=seeds) - next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) - move_to_device = lambda x: jax.device_put(x, device=current_actor_device) - # Loop till the desired num_updates is reached. 
while not thread_lifetime.should_stop(): # Rollout - traj: List = [] + traj: List[PPOTransition] = [] time_dict: Dict[str, List[float]] = { "single_rollout_time": [], "env_step_time": [], "get_params_time": [], - "put_rollout_time": [], + "rollout_put_time": [], } # Loop over the rollout length with RecordTimeTo(time_dict["single_rollout_time"]): for _ in range(config.system.rollout_length): - # Get the latest parameters from the learner - with RecordTimeTo(time_dict["get_params_time"]): + # Get the latest parameters from the learner params = params_source.get() cached_next_obs = tree.map(move_to_device, timestep.observation) cached_next_dones = move_to_device(next_dones) # Get action and value - ( - action, - log_prob, - value, - key, - ) = get_action_and_value(params, cached_next_obs, key) + action, log_prob, value, key = get_action_and_value(params, cached_next_obs, key) # Step the environment cpu_action = jax.device_get(action) @@ -152,19 +144,15 @@ def get_action_and_value( next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) # Append data to storage + reward = timestep.reward + info = timestep.extras traj.append( PPOTransition( - done=cached_next_dones, - action=action, - value=value, - reward=timestep.reward, - log_prob=log_prob, - obs=cached_next_obs, - info=timestep.extras, + cached_next_dones, action, value, reward, log_prob, cached_next_obs, info ) ) # send trajectories to learner - with RecordTimeTo(time_dict["put_rollout_time"]): + with RecordTimeTo(time_dict["rollout_put_time"]): rollout_queue.put(traj, timestep.observation, next_dones, time_dict) env.close() @@ -189,8 +177,7 @@ def _update_step( """A single update of the network. This function calculates advantages and targets based on the trajectories - from the actor and updates the actor and critic networks based on the - calculated losses. + from the actor and updates the actor and critic networks based on the losses. Args: learner_state (LearnerState): contains all the items needed for learning. 
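A small shape sketch of the cpu_action.swapaxes(0, 1) call in the rollout above, assuming 4 parallel envs and 3 agents (both sizes are only for illustration):

    import numpy as np

    actions = np.zeros((4, 3), dtype=np.int32)   # policy output is (num_envs, num_agents)
    env_actions = actions.swapaxes(0, 1)         # the vector env expects (num_agents, num_envs)
    assert env_actions.shape == (3, 4)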
@@ -222,7 +209,7 @@ def _get_advantages( ) return advantages, advantages + traj_batch.value - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, _, _ = learner_state last_val = critic_apply_fn(params.critic_params, last_obs) advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) @@ -233,23 +220,22 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO + # Unpack train state and batch info params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: PPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK + # Rerun network actor_policy = actor_apply_fn(actor_params, traj_batch.obs) log_prob = actor_policy.log_prob(traj_batch.action) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) gae = (gae - gae.mean()) / (gae.std() + 1e-8) loss_actor1 = ratio * gae @@ -270,16 +256,13 @@ def _actor_loss_fn( return total_loss_actor, (loss_actor, entropy) def _critic_loss_fn( - critic_params: FrozenDict, - critic_opt_state: OptState, - traj_batch: PPOTransition, - targets: chex.Array, + critic_params: FrozenDict, traj_batch: PPOTransition, targets: chex.Array ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network value = critic_apply_fn(critic_params, traj_batch.obs) - # CALCULATE VALUE LOSS + # Calculate value loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -290,21 +273,17 @@ def _critic_loss_fn( critic_total_loss = config.system.vf_coef * value_loss return critic_total_loss, (value_loss) - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( - params.actor_params, - opt_states.actor_opt_state, - traj_batch, - advantages, - entropy_key, + params.actor_params, traj_batch, advantages, entropy_key ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. 
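A self-contained sketch of the clipped surrogate computed in _actor_loss_fn above; clip_eps=0.2 is an assumed value here, whereas the system reads it from config.system.clip_eps:

    import jax.numpy as jnp

    def clipped_surrogate_loss(ratio, gae, clip_eps=0.2):
        # Normalise advantages, then take the pessimistic (clipped) objective.
        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
        unclipped = ratio * gae
        clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
        return -jnp.minimum(unclipped, clipped).mean()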
@@ -321,22 +300,22 @@ def _critic_loss_fn( (critic_grads, critic_loss_info), axis_name="learner_devices" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update actor params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE + # Update critic params and optimiser state critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - # PACK NEW PARAMS AND OPTIMISER STATE + # Pack new params and optimiser state new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO + # Pack loss info total_loss = actor_loss_info[0] + critic_loss_info[0] value_loss = critic_loss_info[1] actor_loss = actor_loss_info[1][0] @@ -351,21 +330,19 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch_size = config.system.rollout_length * ( config.arch.num_envs // len(config.arch.learner_device_ids) ) permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=0), batch - ) - minibatches = jax.tree_util.tree_map( + batch = tree.map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = tree.map(lambda x: jnp.take(x, permutation, axis=0), batch) + minibatches = tree.map( lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -374,7 +351,7 @@ def _critic_loss_fn( return update_state, loss_info update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -418,7 +395,7 @@ def learner_fn( def learner_setup( - keys: chex.Array, config: DictConfig, learner_devices: List + key: chex.PRNGKey, config: DictConfig, learner_devices: List ) -> Tuple[ SebulbaLearnerFn[LearnerState, PPOTransition], Tuple[ActorApply, CriticApply], LearnerState ]: @@ -432,7 +409,7 @@ def learner_setup( config.system.num_actions = int(action_space[0].n) # PRNG keys. - key, actor_net_key, critic_net_key = keys + key, actor_key, critic_key = jax.random.split(key, 3) # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) @@ -462,11 +439,11 @@ def learner_setup( init_x = Observation(init_obs, init_action_mask) # Initialise actor params and optimiser state. - actor_params = actor_network.init(actor_net_key, init_x) + actor_params = actor_network.init(actor_key, init_x) actor_opt_state = actor_optim.init(actor_params) # Initialise critic params and optimiser state. - critic_params = critic_network.init(critic_net_key, init_x) + critic_params = critic_network.init(critic_key, init_x) critic_opt_state = critic_optim.init(critic_params) # Pack params. 
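A shape sketch of the minibatching above, with assumed sizes (rollout length 8, 4 envs per learner, 2 minibatches); numpy stands in for the jnp/tree calls:

    import numpy as np

    rollout_len, envs_per_learner, num_minibatches = 8, 4, 2
    batch = np.arange(rollout_len * envs_per_learner).reshape(rollout_len, envs_per_learner)

    flat = batch.reshape(-1, *batch.shape[2:])               # merge leading dims -> (32,)
    perm = np.random.default_rng(0).permutation(flat.shape[0])
    shuffled = flat[perm]                                    # shuffle the full batch
    minibatches = shuffled.reshape(num_minibatches, -1, *flat.shape[1:])   # -> (2, 16)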
@@ -517,13 +494,13 @@ def learner( for _eval_step in range(config.arch.num_evaluation): metrics: List[Tuple[Dict, Dict]] = [] rollout_times: List[Dict] = [] - eval_times: Dict[str, List[float]] = {"rollout_get_time": [], "learning_time": []} + learn_times: Dict[str, List[float]] = {"rollout_get_time": [], "learning_time": []} for _update in range(config.system.num_updates_per_eval): - with RecordTimeTo(eval_times["rollout_get_time"]): + with RecordTimeTo(learn_times["rollout_get_time"]): traj_batch, last_obs, last_dones, rollout_time = pipeline.get(block=True) - with RecordTimeTo(eval_times["learning_time"]): + with RecordTimeTo(learn_times["learning_time"]): learner_state, episode_metrics, train_metrics = learn( learner_state, traj_batch, last_obs, last_dones ) @@ -531,7 +508,7 @@ def learner( metrics.append((episode_metrics, train_metrics)) rollout_times.append(rollout_time) - unreplicated_params = flax.jax_utils.unreplicate(learner_state.params) + unreplicated_params = unreplicate(learner_state.params) for source in params_sources: source.update(unreplicated_params) @@ -540,7 +517,7 @@ def learner( episode_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) rollout_times = tree.map(lambda *x: np.mean(x), *rollout_times) - timing_dict = rollout_times | eval_times + timing_dict = rollout_times | learn_times timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) eval_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) @@ -553,18 +530,12 @@ def run_experiment(_config: DictConfig) -> float: devices = jax.devices() learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] - # PRNG keys. - key, actor_net_key, critic_net_key = jax.random.split( - jax.random.PRNGKey(config.system.seed), num=3 - ) - - # Numpy RNG. + # JAX and numpy RNGs + key = jax.random.PRNGKey(config.system.seed) np_rng = np.random.default_rng(config.system.seed) # Setup learner. - learn, apply_fns, learner_state = learner_setup( - (key, actor_net_key, critic_net_key), config, learner_devices - ) + learn, apply_fns, learner_state = learner_setup(key, config, learner_devices) # Setup evaluator. # One key per device for evaluation. @@ -583,9 +554,9 @@ def run_experiment(_config: DictConfig) -> float: # Logger setup logger = MavaLogger(config) - cfg: Dict = OmegaConf.to_container(config, resolve=True) - cfg["arch"]["devices"] = jax.devices() - pprint(cfg) + print_cfg: Dict = OmegaConf.to_container(config, resolve=True) + print_cfg["arch"]["devices"] = jax.devices() + pprint(print_cfg) # Set up checkpointer save_checkpoint = config.logger.checkpointing.save_model @@ -597,13 +568,14 @@ def run_experiment(_config: DictConfig) -> float: ) # Executor setup and launch. 
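A rough sketch of how the per-update timing lists collected in the learner loop above are reduced to scalar means before being put on the eval queue (the timing values are made up for illustration):

    import numpy as np
    from jax import tree

    rollout_times = {"env_step_time": [0.11, 0.09], "rollout_put_time": [0.02, 0.03]}
    learn_times = {"rollout_get_time": [0.01, 0.02], "learning_time": [0.30, 0.28]}

    timing_dict = rollout_times | learn_times                # merge the two dicts
    timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))
    # e.g. {"env_step_time": 0.10, "rollout_put_time": 0.025, ...}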
- unreplicated_inital_params = flax.jax_utils.unreplicate(learner_state.params) + inital_params = unreplicate(learner_state.params) - pipeline_lifetime = ThreadLifetime() - pipeline = Pipeline(config.arch.rollout_queue_size, learner_devices, pipeline_lifetime) - pipeline.start() + # the rollout queue/ the pipe between actor and learner + pipe_lifetime = ThreadLifetime() + pipe = Pipeline(config.arch.rollout_queue_size, learner_devices, pipe_lifetime) + pipe.start() - params_sources: List[ParamsSource] = [] + param_sources: List[ParamsSource] = [] actor_threads: List[threading.Thread] = [] actors_lifetime = ThreadLifetime() params_sources_lifetime = ThreadLifetime() @@ -613,25 +585,16 @@ def run_experiment(_config: DictConfig) -> float: # Loop through each executor thread for thread_id in range(config.arch.n_threads_per_executor): seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() + key, act_key = jax.random.split(key) + act_key = jax.device_put(key, devices[d_id]) - params_source = ParamsSource( - unreplicated_inital_params, devices[d_id], params_sources_lifetime - ) - params_source.start() - params_sources.append(params_source) + param_source = ParamsSource(inital_params, devices[d_id], params_sources_lifetime) + param_source.start() + param_sources.append(param_source) actor = threading.Thread( target=rollout, - args=( - jax.device_put(key, devices[d_id]), - config, - pipeline, - params_sources[-1], - apply_fns, - d_id, - seeds, - actors_lifetime, - ), + args=(act_key, config, pipe, param_source, apply_fns, d_id, seeds, actors_lifetime), name=f"Actor-{thread_id + d_idx * config.arch.n_threads_per_executor}", ) actor.start() @@ -641,11 +604,11 @@ def run_experiment(_config: DictConfig) -> float: threading.Thread( target=learner, name="Learner", - args=(learn, learner_state, config, eval_queue, pipeline, params_sources), + args=(learn, learner_state, config, eval_queue, pipe, param_sources), ).start() max_episode_return = -jnp.inf - best_params = unreplicated_inital_params.actor_params + best_params = inital_params.actor_params # This is the main loop, all it does is evaluation and logging. # Acting and learning is happening in their own threads. @@ -665,7 +628,7 @@ def run_experiment(_config: DictConfig) -> float: logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) - unreplicated_actor_params = flax.jax_utils.unreplicate(learner_state.params.actor_params) + unreplicated_actor_params = unreplicate(learner_state.params.actor_params) key, eval_key = jax.random.split(key, 2) eval_metrics = evaluator(unreplicated_actor_params, eval_key, {}) logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) @@ -692,12 +655,12 @@ def run_experiment(_config: DictConfig) -> float: for actor in actor_threads: actor.join() - pipeline_lifetime.stop() - pipeline.join() + pipe_lifetime.stop() + pipe.join() params_sources_lifetime.stop() - for params_source in params_sources: - params_source.join() + for param_source in param_sources: + param_source.join() # Measure absolute metric. 
if config.arch.absolute_metric: From de92f5a9e6f41825fabf8c0935215d5ae9f857bc Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 30 Jul 2024 16:30:55 +0200 Subject: [PATCH 094/125] feat: pass timestep instead of obs and done and fix potential race condition in pipeline --- mava/systems/ppo/sebulba/ff_ippo.py | 32 +++++++------------ mava/types.py | 4 +-- mava/utils/sebulba_utils.py | 49 ++++++++++------------------- 3 files changed, 30 insertions(+), 55 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 38cb2905b..f4905c1c6 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -153,7 +153,7 @@ def get_action_and_value( ) # send trajectories to learner with RecordTimeTo(time_dict["rollout_put_time"]): - rollout_queue.put(traj, timestep.observation, next_dones, time_dict) + rollout_queue.put(traj, timestep, time_dict) env.close() @@ -164,6 +164,8 @@ def get_learner_fn( ) -> SebulbaLearnerFn[LearnerState, PPOTransition]: """Get the learner function.""" + num_agents, num_envs = config.system.num_agents, config.arch.num_envs + # Get apply and update functions for actor and critic networks. actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns @@ -171,8 +173,6 @@ def get_learner_fn( def _update_step( learner_state: LearnerState, traj_batch: PPOTransition, - last_obs: Observation, - last_dones: chex.Array, ) -> Tuple[LearnerState, Tuple]: """A single update of the network. @@ -182,9 +182,6 @@ def _update_step( Args: learner_state (LearnerState): contains all the items needed for learning. traj_batch (PPOTransition): the batch of data to learn with. - last_obs (Observation): the final observations (for bootstrapping in GAE). - last_dones (Array): the final dones (for bootstrapping in GAE). - _ (Any): The current metrics info. """ def _calculate_gae( @@ -210,8 +207,9 @@ def _get_advantages( return advantages, advantages + traj_batch.value # Calculate advantage + last_dones = jnp.repeat(learner_state.timestep.last(), num_agents).reshape(num_envs, -1) params, opt_states, key, _, _ = learner_state - last_val = critic_apply_fn(params.critic_params, last_obs) + last_val = critic_apply_fn(params.critic_params, learner_state.timestep.observation) advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: @@ -357,15 +355,12 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, None, None) + learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep) metric = traj_batch.info return learner_state, (metric, loss_info) def learner_fn( - learner_state: LearnerState, - traj_batch: PPOTransition, - last_obs: Observation, - last_dones: chex.Array, + learner_state: LearnerState, traj_batch: PPOTransition ) -> ExperimentOutput[LearnerState]: """Learner function. @@ -379,11 +374,9 @@ def learner_fn( - opt_states (OptStates): The initial optimizer state. - key (chex.PRNGKey): The random number generator state. - env_state (LogEnvState): The environment state. - - timesteps (TimeStep): The initial timestep in the initial trajectory. + - timesteps (TimeStep): The last timestep of the rollout. 
""" - learner_state, (episode_info, loss_info) = _update_step( - learner_state, traj_batch, last_obs, last_dones - ) + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) return ExperimentOutput( learner_state=learner_state, @@ -498,12 +491,11 @@ def learner( for _update in range(config.system.num_updates_per_eval): with RecordTimeTo(learn_times["rollout_get_time"]): - traj_batch, last_obs, last_dones, rollout_time = pipeline.get(block=True) + traj_batch, timestep, rollout_time = pipeline.get(block=True) + learner_state = learner_state._replace(timestep=timestep) with RecordTimeTo(learn_times["learning_time"]): - learner_state, episode_metrics, train_metrics = learn( - learner_state, traj_batch, last_obs, last_dones - ) + learner_state, episode_metrics, train_metrics = learn(learner_state, traj_batch) metrics.append((episode_metrics, train_metrics)) rollout_times.append(rollout_time) diff --git a/mava/types.py b/mava/types.py index fe51ce293..8a191f5ab 100644 --- a/mava/types.py +++ b/mava/types.py @@ -153,9 +153,7 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]] -SebulbaLearnerFn = Callable[ - [MavaState, MavaTransition, chex.Array, chex.Array], ExperimentOutput[MavaState] -] +SebulbaLearnerFn = Callable[[MavaState, MavaTransition], ExperimentOutput[MavaState]] ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py index 9077925a8..b15edeba6 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba_utils.py @@ -20,11 +20,10 @@ import jax import jax.numpy as jnp -from chex import Array +from jumanji.types import TimeStep from omegaconf import DictConfig from mava.systems.ppo.types import Params, PPOTransition # todo: remove the ppo dependencies -from mava.types import Observation, ObservationGlobalState # Copied from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py @@ -63,8 +62,7 @@ def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: T self.lifetime = lifetime def run(self) -> None: - """ - This function ensures that trajectories on the queue are consumed in the right order. The + """This function ensures that trajectories on the queue are consumed in the right order. The start_condition and end_condition are used to ensure that only 1 thread is processing an item from the queue at one time, ensuring predictable memory usage. """ @@ -78,16 +76,8 @@ def run(self) -> None: except queue.Empty: continue - def put( - self, - traj: Sequence[PPOTransition], - next_obs: Union[Observation, ObservationGlobalState], - next_dones: Array, - time_dict: Dict, - ) -> None: - """ - Put a trajectory on the queue to be consumed by the learner. - """ + def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict) -> None: + """Put a trajectory on the queue to be consumed by the learner.""" start_condition, end_condition = (threading.Condition(), threading.Condition()) with start_condition: self.tickets_queue.put((start_condition, end_condition)) @@ -96,21 +86,18 @@ def put( # [PPOTransition()] * rollout_len --> PPOTransition[done=(rollout_len, num_envs, num_agents) sharded_traj = jax.tree.map(lambda *x: self.shard_split_playload(jnp.stack(x), 1), *traj) - # obs Tuple[(num_envs, num_agents, ...), ...] --> + # Timestep[(num_envs, num_agents, ...), ...] 
--> # [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices - sharded_next_obs = jax.tree.map(self.shard_split_playload, next_obs) - - # dones (num_envs, num_agents) --> - # [(num_envs / num_learner_devices, num_agents)] * num_learner_devices - sharded_next_dones = self.shard_split_playload(next_dones, 0) + sharded_timestep = jax.tree.map(self.shard_split_playload, timestep) - # If the queue gets full at any point we prioritize taking new episodes. + # If the queue gets full at any point we prioritize taking removing the oldest rollouts. # This also prevents the pipeline from stalling if the learner thread terminates - # before the actors finish putting the episodes in it. - if self._queue.full(): - self._queue.get() + # with a full queue blocking the actors from placing items in it. + with self._queue.mutex: + if self._queue.maxsize >= self._queue._qsize(): # queue is full + self._queue.get() # throw away the transition - self._queue.put((sharded_traj, sharded_next_obs, sharded_next_dones, time_dict)) + self._queue.put((sharded_traj, sharded_timestep, time_dict)) with end_condition: end_condition.notify() # tell we have finish @@ -121,7 +108,7 @@ def qsize(self) -> int: def get( self, block: bool = True, timeout: Union[float, None] = None - ) -> Tuple[PPOTransition, Union[Observation, ObservationGlobalState], Array, Dict]: + ) -> Tuple[PPOTransition, TimeStep, Dict]: """Get a trajectory from the pipeline.""" return self._queue.get(block, timeout) # type: ignore @@ -131,8 +118,7 @@ def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: class ParamsSource(threading.Thread): - """ - A `ParamSource` is a component that allows networks params to be passed from a + """A `ParamSource` is a component that allows networks params to be passed from a `Learner` component to `Actor` components. """ @@ -144,8 +130,7 @@ def __init__(self, init_value: Params, device: jax.Device, lifetime: ThreadLifet self.lifetime = lifetime def run(self) -> None: - """ - This function is responsible for updating the value of the `ParamSource` when a new value + """This function is responsible for updating the value of the `ParamSource` when a new value is available. """ while not self.lifetime.should_stop(): @@ -156,8 +141,7 @@ def run(self) -> None: continue def update(self, new_params: Params) -> None: - """ - Update the value of the `ParamSource` with a new value. + """Update the value of the `ParamSource` with a new value. Args: new_params: The new value to update the `ParamSource` with. @@ -182,6 +166,7 @@ def __exit__(self, *args: Any) -> None: def check_config(config: DictConfig) -> None: + """Checks that the given config does not have conflicting values.""" assert ( config.system.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." From 1465133381431d5ead3d9f1189c0d434254ca7d1 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 30 Jul 2024 16:35:24 +0200 Subject: [PATCH 095/125] fix: deadlock in pipeline --- mava/utils/sebulba_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py index b15edeba6..a25d1c117 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba_utils.py @@ -93,9 +93,8 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict # If the queue gets full at any point we prioritize taking removing the oldest rollouts. 
# This also prevents the pipeline from stalling if the learner thread terminates # with a full queue blocking the actors from placing items in it. - with self._queue.mutex: - if self._queue.maxsize >= self._queue._qsize(): # queue is full - self._queue.get() # throw away the transition + if self._queue.full(): + self._queue.get() # throw away the transition self._queue.put((sharded_traj, sharded_timestep, time_dict)) From 6689c4951157909780b63a17889f51cdc0256ee0 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sun, 11 Aug 2024 14:16:55 +0100 Subject: [PATCH 096/125] fix: wasting samples --- mava/systems/ppo/sebulba/ff_ippo.py | 12 +++++++++++- mava/utils/sebulba_utils.py | 18 +++++++----------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index f4905c1c6..f05d3cbdc 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -13,7 +13,9 @@ # limitations under the License. import copy +import queue import threading +import warnings from queue import Queue from typing import Any, Dict, List, Sequence, Tuple @@ -153,7 +155,15 @@ def get_action_and_value( ) # send trajectories to learner with RecordTimeTo(time_dict["rollout_put_time"]): - rollout_queue.put(traj, timestep, time_dict) + try: + rollout_queue.put(traj, timestep, time_dict) + except queue.Full: + warnings.warn( + "Waited too long to add to the rollout queue, killing the actor thread", + stacklevel=2, + ) + break + env.close() diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba_utils.py index a25d1c117..041843d95 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba_utils.py @@ -83,23 +83,19 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict self.tickets_queue.put((start_condition, end_condition)) start_condition.wait() # wait to be allowed to start - # [PPOTransition()] * rollout_len --> PPOTransition[done=(rollout_len, num_envs, num_agents) + # [PPOTransition()] * rollout_len --> PPOTransition[done=(rollout_len, num_envs, ...)] sharded_traj = jax.tree.map(lambda *x: self.shard_split_playload(jnp.stack(x), 1), *traj) # Timestep[(num_envs, num_agents, ...), ...] --> # [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices sharded_timestep = jax.tree.map(self.shard_split_playload, timestep) - # If the queue gets full at any point we prioritize taking removing the oldest rollouts. - # This also prevents the pipeline from stalling if the learner thread terminates - # with a full queue blocking the actors from placing items in it. - if self._queue.full(): - self._queue.get() # throw away the transition - - self._queue.put((sharded_traj, sharded_timestep, time_dict)) - - with end_condition: - end_condition.notify() # tell we have finish + # The lock has to be released even if an exception is raised. 
+ try: + self._queue.put((sharded_traj, sharded_timestep, time_dict), timeout=90) + finally: + with end_condition: + end_condition.notify() # tell we have finish def qsize(self) -> int: """Returns the number of trajectories in the pipeline.""" From c506da30201201599c5ee00fa4c04ef5e73157ba Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Sun, 11 Aug 2024 14:43:21 +0100 Subject: [PATCH 097/125] chore: loss unpacking --- mava/configs/arch/sebulba.yaml | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index 0c1c8880d..eafeba202 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -9,7 +9,7 @@ evaluation_greedy: False # Evaluate the policy greedily. If True the policy will # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. num_eval_episodes: 200 # Number of episodes to evaluate per evaluation. -num_evaluation: 10 # Number of evenly spaced evaluations to perform during training. +num_evaluation: 200 # Number of evenly spaced evaluations to perform during training. num_absolute_metric_eval_episodes: 32 # Number of episodes to evaluate the absolute metric (the final evaluation). absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index f05d3cbdc..06aa268a8 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -324,10 +324,9 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) # Pack loss info - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_total_loss, (actor_loss, entropy) = actor_loss_info + critic_total_loss, (value_loss) = critic_loss_info + total_loss = critic_total_loss + actor_total_loss loss_info = { "total_loss": total_loss, "value_loss": value_loss, From b24ac34e3ae3e38094522c915f5bd659773fc066 Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Thu, 10 Oct 2024 17:13:21 +0100 Subject: [PATCH 098/125] fix: updated to work with the latest gymnasium --- mava/systems/ppo/sebulba/ff_ippo.py | 4 ++-- mava/wrappers/gym.py | 26 ++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 06aa268a8..ed85de3bf 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -406,7 +406,7 @@ def learner_setup( # create temporory envoirnments. env = environments.make_gym_env(config, config.arch.num_envs) # Get number of agents and actions. - action_space = env.unwrapped.single_action_space + action_space = env.single_action_space config.system.num_agents = len(action_space) config.system.num_actions = int(action_space[0].n) @@ -436,7 +436,7 @@ def learner_setup( ) # Initialise observation: Select only obs for a single agent. 
- init_obs = jnp.array([env.unwrapped.single_observation_space.sample()]) + init_obs = jnp.array([env.single_observation_space.sample()]) init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) init_x = Observation(init_obs, init_action_mask) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index ee4339afd..2756b3511 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -20,9 +20,10 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import gymnasium -import jax +import gymnasium.vector.async_vector_env import numpy as np from gymnasium import spaces +from gymnasium.spaces.utils import is_space_dtype_shape_equiv from gymnasium.vector.utils import write_to_shared_memory from jumanji.types import StepType, TimeStep from numpy.typing import NDArray @@ -195,9 +196,14 @@ def modify_space(self, space: spaces.Space) -> spaces.Space: raise ValueError(f"Space {type(space)} is not currently supported.") -class GymToJumanji(gymnasium.Wrapper): +class GymToJumanji: """Converts from the Gym API to the dm_env API, using Jumanji's Timestep type.""" + def __init__(self, env: gymnasium.vector.async_vector_env): + self.env = env + self.single_action_space = env.unwrapped.single_action_space + self.single_observation_space = env.unwrapped.single_observation_space + def reset( self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None ) -> TimeStep: @@ -244,7 +250,8 @@ def _create_timestep( self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict ) -> TimeStep: obs = self._format_observation(obs, info) - extras = jax.tree.map(lambda *x: np.stack(x), *info["metrics"]) + # Filter out the masks and auxiliary data + extras = {key: value for key, value in info["metrics"].items() if key[0] != "_"} step_type = np.where(ep_done, StepType.LAST, StepType.MID) terminated = np.all(terminated, axis=1) @@ -256,6 +263,9 @@ def _create_timestep( extras=extras, ) + def close(self) -> None: + self.env.close() + # Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents @@ -321,9 +331,17 @@ def async_multiagent_worker( # CCR001 env.set_wrapper_attr(name, value) pipe.send((None, True)) elif command == "_check_spaces": + obs_mode, single_obs_space, single_action_space = data pipe.send( ( - (data[0] == observation_space, data[1] == action_space), + ( + ( + single_obs_space == observation_space + if obs_mode == "same" + else is_space_dtype_shape_equiv(single_obs_space, observation_space) + ), + single_action_space == action_space, + ), True, ) ) From 1dfb24105d0c3593e4c139e68bf7d79d91a1df2f Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Thu, 10 Oct 2024 18:32:55 +0100 Subject: [PATCH 099/125] fix: jumanji --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0c68a3ca5..98a9f9912 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -9,7 +9,7 @@ id-marl-eval @ git+https://github.com/instadeepai/marl-eval jax jaxlib jaxmarl -jumanji @ git+https://github.com/sash-a/jumanji +jumanji @ git+https://github.com/sash-a/jumanji@old_jumanji lbforaging @ git+https://github.com/LukasSchaefer/lb-foraging.git@gymnasium_integration # fixes: https://github.com/semitable/lb-foraging/issues/20 matrax @ git+https://github.com/instadeepai/matrax mujoco==3.1.3 From fd8aece0d3590695e895f8047d916ff304c6d547 Mon Sep 17 00:00:00 2001 From: Louay Ben 
Nessir Date: Thu, 10 Oct 2024 18:43:56 +0100 Subject: [PATCH 100/125] fix: removed depricated gymnasium import --- mava/utils/make_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 405cb73b8..a5010307a 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -17,7 +17,6 @@ import gymnasium import gymnasium.vector import gymnasium.wrappers -import gymnasium.wrappers.compatibility import jaxmarl import jumanji import matrax From ae5341548738963673f9b4afc97b846bf24ec72b Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Thu, 10 Oct 2024 14:21:06 +0200 Subject: [PATCH 101/125] feat: minor refactor to sebulba utils --- mava/systems/ppo/anakin/ff_ippo.py | 2 +- mava/systems/ppo/anakin/ff_mappo.py | 2 +- mava/systems/ppo/anakin/rec_ippo.py | 2 +- mava/systems/ppo/anakin/rec_mappo.py | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 26 ++----- mava/systems/ppo/types.py | 4 +- mava/systems/q_learning/anakin/rec_iql.py | 2 +- mava/systems/sac/anakin/ff_isac.py | 2 +- mava/systems/sac/anakin/ff_masac.py | 2 +- .../{total_timestep_checker.py => config.py} | 22 ++++++ mava/utils/{sebulba_utils.py => sebulba.py} | 75 +++++++++++-------- 11 files changed, 84 insertions(+), 57 deletions(-) rename mava/utils/{total_timestep_checker.py => config.py} (67%) rename mava/utils/{sebulba_utils.py => sebulba.py} (70%) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 49c969cdb..6fabdd715 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -35,13 +35,13 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index cafa42888..ad14a2968 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -34,9 +34,9 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 230756295..0c1a161fc 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -48,9 +48,9 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from 
mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index 53ae7c65d..a83897a07 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -48,9 +48,9 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index ed85de3bf..fd13bbb19 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -39,25 +39,13 @@ from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ( - ActorApply, - CriticApply, - ExperimentOutput, - Observation, - SebulbaLearnerFn, -) +from mava.types import ActorApply, CriticApply, ExperimentOutput, Observation, SebulbaLearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_sebulba_config, check_total_timesteps from mava.utils.jax_utils import merge_leading_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.sebulba_utils import ( - ParamsSource, - Pipeline, - RecordTimeTo, - ThreadLifetime, - check_config, -) -from mava.utils.total_timestep_checker import check_total_timesteps +from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -95,7 +83,7 @@ def rollout( # Define the util functions: select action function and prepare data to share it with learner. @jax.jit def get_action_and_value( - params: FrozenDict, + params: Params, observation: Observation, key: chex.PRNGKey, ) -> Tuple: @@ -147,7 +135,8 @@ def get_action_and_value( # Append data to storage reward = timestep.reward - info = timestep.extras + info = timestep.extras # todo: [metrics]? + # todo: when logging make sure timing dict has parent timing/... traj.append( PPOTransition( cached_next_dones, action, value, reward, log_prob, cached_next_obs, info @@ -547,7 +536,7 @@ def run_experiment(_config: DictConfig) -> float: # Calculate total timesteps. config = check_total_timesteps(config) - check_config(config) + check_sebulba_config(config) steps_per_rollout = ( config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval @@ -674,6 +663,7 @@ def run_experiment(_config: DictConfig) -> float: t = int(steps_per_rollout * (eval_step + 1)) logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) abs_metric_evaluator_envs.close() + # Stop the logger. 
logger.stop() diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index f129b89d3..c8145b1a7 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -20,7 +20,7 @@ from optax._src.base import OptState from typing_extensions import NamedTuple -from mava.types import Action, Done, HiddenState, State, Value +from mava.types import Action, Done, HiddenState, Observation, State, Value class Params(NamedTuple): @@ -74,7 +74,7 @@ class PPOTransition(NamedTuple): value: Value reward: chex.Array log_prob: chex.Array - obs: chex.Array + obs: Observation info: Dict diff --git a/mava/systems/q_learning/anakin/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py index 05b860d85..f37a20c5f 100644 --- a/mava/systems/q_learning/anakin/rec_iql.py +++ b/mava/systems/q_learning/anakin/rec_iql.py @@ -48,13 +48,13 @@ from mava.types import Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( switch_leading_axes, unreplicate_batch_dim, unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/systems/sac/anakin/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py index 955725e00..b767c98e3 100644 --- a/mava/systems/sac/anakin/ff_isac.py +++ b/mava/systems/sac/anakin/ff_isac.py @@ -49,9 +49,9 @@ from mava.types import MarlEnv, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/systems/sac/anakin/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py index 2df296be4..296822b3a 100644 --- a/mava/systems/sac/anakin/ff_masac.py +++ b/mava/systems/sac/anakin/ff_masac.py @@ -50,9 +50,9 @@ from mava.utils import make_env as environments from mava.utils.centralised_training import get_joint_action, get_updated_joint_actions from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/utils/total_timestep_checker.py b/mava/utils/config.py similarity index 67% rename from mava/utils/total_timestep_checker.py rename to mava/utils/config.py index e48e40923..23484311b 100644 --- a/mava/utils/total_timestep_checker.py +++ b/mava/utils/config.py @@ -18,6 +18,28 @@ from omegaconf import DictConfig +def check_sebulba_config(config: DictConfig) -> None: + """Checks that the given config does not have conflicting values.""" + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + + assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, ( + "Number of environments must be divisible by the number of learner." 
+ + "The output of each actor is equally split across the learners." + ) + + num_eval_samples = ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.system.rollout_length + ) + assert num_eval_samples % config.system.num_minibatches == 0, ( + f"Number of training samples per evaluator ({num_eval_samples})" + + f"must be divisible by num_minibatches ({config.system.num_minibatches})." + ) + + def check_total_timesteps(config: DictConfig) -> DictConfig: """Check if total_timesteps is set, if not, set it based on the other parameters""" diff --git a/mava/utils/sebulba_utils.py b/mava/utils/sebulba.py similarity index 70% rename from mava/utils/sebulba_utils.py rename to mava/utils/sebulba.py index 041843d95..eee211828 100644 --- a/mava/utils/sebulba_utils.py +++ b/mava/utils/sebulba.py @@ -20,13 +20,16 @@ import jax import jax.numpy as jnp +from colorama import Fore, Style +from jax import tree from jumanji.types import TimeStep -from omegaconf import DictConfig -from mava.systems.ppo.types import Params, PPOTransition # todo: remove the ppo dependencies +# todo: remove the ppo dependencies +from mava.systems.ppo.types import Params, PPOTransition + +QUEUE_PUT_TIMEOUT = 180 -# Copied from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py class ThreadLifetime: """Simple class for a mutable boolean that can be used to signal a thread to stop.""" @@ -40,6 +43,14 @@ def stop(self) -> None: self._stop = True +@jax.jit +def _stack_trajectory(trajectory: List[PPOTransition]) -> PPOTransition: + """Stack a list of parallel_env transitions into a single + transition of shape [rollout_len, num_envs, ...].""" + return tree.map(lambda *x: jnp.stack(x, axis=0), *trajectory) # type: ignore + + +# Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py class Pipeline(threading.Thread): """ The `Pipeline` shards trajectories into `learner_devices`, @@ -54,6 +65,7 @@ def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: T Args: max_size: The maximum number of trajectories to keep in the pipeline. learner_devices: The devices to shard trajectories across. + lifetime: A `ThreadLifetime` which is used to stop this thread. """ super().__init__(name="Pipeline") self.learner_devices = learner_devices @@ -83,19 +95,39 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict self.tickets_queue.put((start_condition, end_condition)) start_condition.wait() # wait to be allowed to start - # [PPOTransition()] * rollout_len --> PPOTransition[done=(rollout_len, num_envs, ...)] - sharded_traj = jax.tree.map(lambda *x: self.shard_split_playload(jnp.stack(x), 1), *traj) + # [Transition(num_envs)] * rollout_len --> Transition[done=(rollout_len, num_envs, ...)] + traj = _stack_trajectory(traj) + # Split trajectory on the num envs axis so each learner device gets a valid full rollout + sharded_traj = jax.tree.map(lambda x: self.shard_split_playload(x, axis=1), traj) # Timestep[(num_envs, num_agents, ...), ...] --> # [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices sharded_timestep = jax.tree.map(self.shard_split_playload, timestep) - # The lock has to be released even if an exception is raised. + # We block on the put to ensure that actors wait for the learners to catch up. This does two + # things: + # 1. It ensures that the actors don't get too far ahead of the learners, which could lead to + # off-policy data. + # 2. 
It ensures that the actors don't in a sense "waste" samples and their time by + # generating samples that the learners can't consume. + # However, we put a timeout of 180 seconds to avoid deadlocks in case the learner + # is not consuming the data. This is a safety measure and should not be hit in normal + # operation. We use a try-finally since the lock has to be released even if an exception + # is raised. try: - self._queue.put((sharded_traj, sharded_timestep, time_dict), timeout=90) + self._queue.put( + (sharded_traj, sharded_timestep, time_dict), + block=True, + timeout=QUEUE_PUT_TIMEOUT, + ) + except queue.Full: # todo: check if this is needed because we catch this exception outside + print( + f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, " + f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" + ) finally: with end_condition: - end_condition.notify() # tell we have finish + end_condition.notify() # notify that we have finished def qsize(self) -> int: """Returns the number of trajectories in the pipeline.""" @@ -107,6 +139,11 @@ def get( """Get a trajectory from the pipeline.""" return self._queue.get(block, timeout) # type: ignore + def clear(self) -> None: + """Clear the pipeline.""" + while not self._queue.empty(): + self._queue.get() + def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) return jax.device_put_sharded(split_payload, devices=self.learner_devices) @@ -158,25 +195,3 @@ def __enter__(self) -> None: def __exit__(self, *args: Any) -> None: end = time.monotonic() self.to.append(end - self.start) - - -def check_config(config: DictConfig) -> None: - """Checks that the given config does not have conflicting values.""" - assert ( - config.system.num_updates > config.arch.num_evaluation - ), "Number of updates per evaluation must be less than total number of updates." - config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation - - assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, ( - "Number of environments must be divisible by the number of learner." - + "The output of each actor is equally split across the learners." - ) - - num_eval_samples = ( - int(config.arch.num_envs / len(config.arch.learner_device_ids)) - * config.system.rollout_length - ) - assert num_eval_samples % config.system.num_minibatches == 0, ( - f"Number of training samples per evaluator ({num_eval_samples})" - + f"must be divisible by num_minibatches ({config.system.num_minibatches})." - ) From 724d2dc335a81aa44cd0b845a0c83eff1ccd9d17 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Thu, 10 Oct 2024 20:25:33 +0200 Subject: [PATCH 102/125] chore: a few minor changes to code style --- mava/configs/arch/sebulba.yaml | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 287 ++++++++++++++++------------ 2 files changed, 161 insertions(+), 128 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index eafeba202..d8f44fd3c 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -16,7 +16,7 @@ absolute_metric: True # Whether the absolute metric should be computed. 
For more # --- Sebulba devices config --- n_threads_per_executor: 2 # num of different threads/env batches per actor -executor_device_ids: [0] # ids of actor devices +actor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices rollout_queue_size : 5 # The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training. diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index fd13bbb19..311bb263f 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -16,6 +16,7 @@ import queue import threading import warnings +from collections import defaultdict from queue import Queue from typing import Any, Dict, List, Sequence, Tuple @@ -43,7 +44,7 @@ from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_sebulba_config, check_total_timesteps -from mava.utils.jax_utils import merge_leading_dims +from mava.utils.jax_utils import merge_leading_dims, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime from mava.utils.training import make_learning_rate @@ -56,7 +57,7 @@ def rollout( rollout_queue: Pipeline, params_source: ParamsSource, apply_fns: Tuple[ActorApply, CriticApply], - actor_device_id: int, + actor_device: int, seeds: List[int], thread_lifetime: ThreadLifetime, ) -> None: @@ -69,7 +70,7 @@ def rollout( params_source (ParamsSource): Source for fetching the latest network parameters from the learner. apply_fns (Tuple): Functions for running the actor and critic networks. - actor_device_id (int): Device ID for this actor thread. + actor_device (Device): Actor device to use for rollout. seeds (List[int]): Seeds for initializing the environment. thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. """ @@ -77,86 +78,85 @@ def rollout( env = environments.make_gym_env(config, config.arch.num_envs) actor_apply_fn, critic_apply_fn = apply_fns num_agents, num_envs = config.system.num_agents, config.arch.num_envs - current_actor_device = jax.devices()[actor_device_id] - move_to_device = lambda x: jax.device_put(x, device=current_actor_device) + move_to_device = lambda x: jax.device_put(x, device=actor_device) # Define the util functions: select action function and prepare data to share it with learner. @jax.jit - def get_action_and_value( + def act_fn( params: Params, observation: Observation, key: chex.PRNGKey, ) -> Tuple: """Get action and value.""" - key, subkey = jax.random.split(key) - actor_policy = actor_apply_fn(params.actor_params, observation) - action = actor_policy.sample(seed=subkey) + action = actor_policy.sample(seed=key) log_prob = actor_policy.log_prob(action) value = critic_apply_fn(params.critic_params, observation).squeeze() - return action, log_prob, value, key + return action, log_prob, value timestep = env.reset(seed=seeds) next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) - # Loop till the desired num_updates is reached. 
- while not thread_lifetime.should_stop(): - # Rollout - traj: List[PPOTransition] = [] - time_dict: Dict[str, List[float]] = { - "single_rollout_time": [], - "env_step_time": [], - "get_params_time": [], - "rollout_put_time": [], - } - - # Loop over the rollout length - with RecordTimeTo(time_dict["single_rollout_time"]): - for _ in range(config.system.rollout_length): - with RecordTimeTo(time_dict["get_params_time"]): - # Get the latest parameters from the learner - params = params_source.get() - - cached_next_obs = tree.map(move_to_device, timestep.observation) - cached_next_dones = move_to_device(next_dones) - - # Get action and value - action, log_prob, value, key = get_action_and_value(params, cached_next_obs, key) - - # Step the environment - cpu_action = jax.device_get(action) - - with RecordTimeTo(time_dict["env_step_time"]): - # (num_env, num_agents) --> (num_agents, num_env) - timestep = env.step(cpu_action.swapaxes(0, 1)) - - next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) - - # Append data to storage - reward = timestep.reward - info = timestep.extras # todo: [metrics]? - # todo: when logging make sure timing dict has parent timing/... - traj.append( - PPOTransition( - cached_next_dones, action, value, reward, log_prob, cached_next_obs, info + with jax.default_device(actor_device): + # Loop till the desired num_updates is reached. + while not thread_lifetime.should_stop(): + # Rollout + traj: List[PPOTransition] = [] + actor_timings: Dict[str, List[float]] = defaultdict(list) + # Loop over the rollout length + with RecordTimeTo(actor_timings["rollout_time"]): + for _ in range(config.system.rollout_length): + with RecordTimeTo(actor_timings["get_params_time"]): + # Get the latest parameters from the learner + params = params_source.get() + + cached_next_obs = tree.map(move_to_device, timestep.observation) + cached_next_dones = move_to_device(next_dones) + + # Get action and value + with RecordTimeTo(actor_timings["compute_action_time"]): + key, act_key = jax.random.split(key) + action, log_prob, value = act_fn(params, cached_next_obs, act_key) + cpu_action = jax.device_get(action) + + # Step environment + with RecordTimeTo(actor_timings["env_step_time"]): + # (num_env, num_agents) --> (num_agents, num_env) + timestep = env.step(cpu_action.swapaxes(0, 1)) + + next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + + # Append data to storage + reward = timestep.reward + info = timestep.extras # todo: [metrics]? + # todo: when logging make sure timing dict has parent timing/... 
+ traj.append( + PPOTransition( + cached_next_dones, + action, + value, + reward, + log_prob, + cached_next_obs, + info, + ) ) - ) - # send trajectories to learner - with RecordTimeTo(time_dict["rollout_put_time"]): - try: - rollout_queue.put(traj, timestep, time_dict) - except queue.Full: - warnings.warn( - "Waited too long to add to the rollout queue, killing the actor thread", - stacklevel=2, - ) - break + # send trajectories to learner + with RecordTimeTo(actor_timings["rollout_put_time"]): + try: + rollout_queue.put(traj, timestep, actor_timings) + except queue.Full: + warnings.warn( + "Waited too long to add to the rollout queue, killing the actor thread", + stacklevel=2, + ) + break env.close() -def get_learner_fn( +def get_learner_step_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, @@ -385,6 +385,54 @@ def learner_fn( return learner_fn +def learner_thread( + learn_fn: SebulbaLearnerFn[LearnerState, PPOTransition], + learner_state: LearnerState, + config: DictConfig, + eval_queue: Queue, + pipeline: Pipeline, + params_sources: Sequence[ParamsSource], +) -> None: + for _ in range(config.arch.num_evaluation): + # Create the lists to store metrics and timings for this learning iteration. + metrics: List[Tuple[Dict, Dict]] = [] + rollout_times: List[Dict] = [] + learn_times: Dict[str, List[float]] = defaultdict(list) + + with RecordTimeTo(learn_times["learner_time_per_eval"]): + for _ in range(config.system.num_updates_per_eval): + # Get the trajectory batch from the pipeline + # This is blocking so it will wait until the pipeline has data. + with RecordTimeTo(learn_times["rollout_get_time"]): + traj_batch, timestep, rollout_time = pipeline.get(block=True) + + # Replace the timestep in the learner state with the latest timestep + # This means the learner has access to the entire trajectory as well as + # an additional timestep which it can use to bootstrap. + learner_state = learner_state._replace(timestep=timestep) + # Update the networks + with RecordTimeTo(learn_times["learning_time"]): + learner_state, episode_metrics, train_metrics = learn_fn( + learner_state, traj_batch + ) + + metrics.append((episode_metrics, train_metrics)) + rollout_times.append(rollout_time) + + # Update all the params sources so all actors can get the latest params + unreplicated_params = unreplicate(learner_state.params) + for source in params_sources: + source.update(unreplicated_params) + + # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation + episode_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) + rollout_times = tree.map(lambda *x: np.mean(x), *rollout_times) + timing_dict = rollout_times | learn_times + timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) + + eval_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) + + def learner_setup( key: chex.PRNGKey, config: DictConfig, learner_devices: List ) -> Tuple[ @@ -444,7 +492,7 @@ def learner_setup( apply_fns = (actor_network.apply, critic_network.apply) update_fns = (actor_optim.update, critic_optim.update) - learn = get_learner_fn(apply_fns, update_fns, config) + learn = get_learner_step_fn(apply_fns, update_fns, config) learn = jax.pmap(learn, axis_name="learner_devices", devices=learner_devices) # Load model from checkpoint if specified. 
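Note on the handoff above: the learner thread blocks on pipeline.get while the actor threads feed it through a bounded, blocking put. The snippet below is a minimal stand-alone sketch of that producer-consumer pattern, using a plain queue.Queue rather than Mava's Pipeline class; the queue size, loop counts and timeout are illustrative only. It shows why a small maxsize (cf. arch.rollout_queue_size) bounds how far the actors can run ahead of the learner, and therefore how off-policy the training data can become.

import queue
import threading

rollout_queue: queue.Queue = queue.Queue(maxsize=5)  # illustrative, cf. arch.rollout_queue_size

def actor() -> None:
    for step in range(20):
        traj = f"rollout-{step}"  # stand-in for a list of PPOTransition
        # Blocks once 5 unconsumed rollouts are queued, keeping data close to on-policy.
        rollout_queue.put(traj, timeout=180)

def learner() -> None:
    for _ in range(20):
        traj = rollout_queue.get(block=True)  # waits here until an actor produces data
        # ... one learn step on `traj` would run here ...

threads = [threading.Thread(target=actor), threading.Thread(target=learner)]
for t in threads:
    t.start()
for t in threads:
    t.join()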
@@ -474,51 +522,16 @@ def learner_setup( return learn, apply_fns, init_learner_state -def learner( - learn: SebulbaLearnerFn[LearnerState, PPOTransition], - learner_state: LearnerState, - config: DictConfig, - eval_queue: Queue, - pipeline: Pipeline, - params_sources: Sequence[ParamsSource], -) -> None: - for _eval_step in range(config.arch.num_evaluation): - metrics: List[Tuple[Dict, Dict]] = [] - rollout_times: List[Dict] = [] - learn_times: Dict[str, List[float]] = {"rollout_get_time": [], "learning_time": []} - - for _update in range(config.system.num_updates_per_eval): - with RecordTimeTo(learn_times["rollout_get_time"]): - traj_batch, timestep, rollout_time = pipeline.get(block=True) - - learner_state = learner_state._replace(timestep=timestep) - with RecordTimeTo(learn_times["learning_time"]): - learner_state, episode_metrics, train_metrics = learn(learner_state, traj_batch) - - metrics.append((episode_metrics, train_metrics)) - rollout_times.append(rollout_time) - - unreplicated_params = unreplicate(learner_state.params) - - for source in params_sources: - source.update(unreplicated_params) - - # Pass to the evaluator - episode_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) - - rollout_times = tree.map(lambda *x: np.mean(x), *rollout_times) - timing_dict = rollout_times | learn_times - timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) - - eval_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) - - def run_experiment(_config: DictConfig) -> float: """Runs experiment.""" config = copy.deepcopy(_config) + local_devices = jax.local_devices() devices = jax.devices() + err = "Local and global devices must be the same, we dont support multihost yet" + assert len(local_devices) == len(devices), err learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] + actor_devices = [local_devices[device_id] for device_id in config.arch.actor_device_ids] # JAX and numpy RNGs key = jax.random.PRNGKey(config.system.seed) @@ -565,36 +578,45 @@ def run_experiment(_config: DictConfig) -> float: pipe = Pipeline(config.arch.rollout_queue_size, learner_devices, pipe_lifetime) pipe.start() - param_sources: List[ParamsSource] = [] + params_sources: List[ParamsSource] = [] actor_threads: List[threading.Thread] = [] - actors_lifetime = ThreadLifetime() + actor_lifetime = ThreadLifetime() params_sources_lifetime = ThreadLifetime() # Create the actor threads - for d_idx, d_id in enumerate(config.arch.executor_device_ids): - # Loop through each executor thread + for actor_device in actor_devices: + # Create 1 params source per device + params_source = ParamsSource(inital_params, actor_device, params_sources_lifetime) + params_source.start() + params_sources.append(params_source) + # Create multiple rollout threads per actor device for thread_id in range(config.arch.n_threads_per_executor): - seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() key, act_key = jax.random.split(key) - act_key = jax.device_put(key, devices[d_id]) - - param_source = ParamsSource(inital_params, devices[d_id], params_sources_lifetime) - param_source.start() - param_sources.append(param_source) + seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() + act_key = jax.device_put(key, actor_device) actor = threading.Thread( target=rollout, - args=(act_key, config, pipe, param_source, apply_fns, d_id, seeds, actors_lifetime), - name=f"Actor-{thread_id + d_idx * 
config.arch.n_threads_per_executor}", + args=( + act_key, + config, + pipe, + params_source, + apply_fns, + actor_device, + seeds, + actor_lifetime, + ), + name=f"Actor-{actor_device}-{thread_id}", ) actor.start() actor_threads.append(actor) eval_queue: Queue = Queue() threading.Thread( - target=learner, + target=learner_thread, name="Learner", - args=(learn, learner_state, config, eval_queue, pipe, param_sources), + args=(learn, learner_state, config, eval_queue, pipe, params_sources), ).start() max_episode_return = -jnp.inf @@ -605,17 +627,21 @@ def run_experiment(_config: DictConfig) -> float: # This loop waits for the learner to finish an update before evaluation and logging. for eval_step in range(config.arch.num_evaluation): # Get the next set of params and metrics from the learner - episode_metrics, train_metrics, learner_state, times_dict = eval_queue.get() + episode_metrics, train_metrics, learner_state, time_metrics = eval_queue.get() t = int(steps_per_rollout * (eval_step + 1)) - times_dict["timestep"] = t - logger.log(times_dict, t, eval_step, LogEvent.MISC) + time_metrics["timestep"] = t + logger.log(time_metrics, t, eval_step, LogEvent.MISC) episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / times_dict["single_rollout_time"] + episode_metrics["steps_per_second"] = steps_per_rollout / time_metrics["rollout_time"] if ep_completed: logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval + train_metrics["learner_steps_per_second"] = ( + config.system.num_updates_per_eval + ) / time_metrics["learner_time_per_eval"] logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) unreplicated_actor_params = unreplicate(learner_state.params.actor_params) @@ -625,11 +651,10 @@ def run_experiment(_config: DictConfig) -> float: episode_return = jnp.mean(eval_metrics["episode_return"]) - if save_checkpoint: - # Save checkpoint of learner state + if save_checkpoint: # Save a checkpoint of the learner state checkpointer.save( timestep=steps_per_rollout * (eval_step + 1), - unreplicated_learner_state=learner_state, + unreplicated_learner_state=unreplicate_n_dims(learner_state), episode_return=episode_return, ) @@ -640,20 +665,28 @@ def run_experiment(_config: DictConfig) -> float: evaluator_envs.close() eval_performance = float(jnp.mean(eval_metrics[config.env.eval_metric])) - # Make sure all of the Threads are closed. - actors_lifetime.stop() + print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") + # Make sure all of the Threads are stopped. + actor_lifetime.stop() for actor in actor_threads: + # We clear the pipeline before stopping each actor thread to avoid deadlock + pipe.clear() actor.join() + print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") pipe_lifetime.stop() pipe.join() + print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}") params_sources_lifetime.stop() - for param_source in param_sources: - param_source.join() + for params_source in params_sources: + params_source.join() + + print(f"{Fore.MAGENTA}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}") # Measure absolute metric. 
if config.arch.absolute_metric: + print(f"{Fore.BLUE}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}") abs_metric_evaluator, abs_metric_evaluator_envs = get_eval_fn( environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True ) From 47b8e036f57722d7a2b98d4d0801bdd186a77c1f Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 11 Oct 2024 09:51:59 +0200 Subject: [PATCH 103/125] fix: update configs to match latest mava --- mava/configs/default/ff_ippo_sebulba.yaml | 11 +++++++++++ mava/configs/default_ff_ippo_sebulba.yaml | 7 ------- mava/configs/env/lbf_gym.yaml | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 4 +++- 4 files changed, 15 insertions(+), 9 deletions(-) create mode 100644 mava/configs/default/ff_ippo_sebulba.yaml delete mode 100644 mava/configs/default_ff_ippo_sebulba.yaml diff --git a/mava/configs/default/ff_ippo_sebulba.yaml b/mava/configs/default/ff_ippo_sebulba.yaml new file mode 100644 index 000000000..babd113ee --- /dev/null +++ b/mava/configs/default/ff_ippo_sebulba.yaml @@ -0,0 +1,11 @@ +defaults: + - logger: logger + - arch: sebulba + - system: ppo/ff_ippo + - network: mlp # [mlp, continuous_mlp, cnn] + - env: lbf_gym # [rware_gym, lbf_gym] + - _self_ + +hydra: + searchpath: + - file://mava/configs diff --git a/mava/configs/default_ff_ippo_sebulba.yaml b/mava/configs/default_ff_ippo_sebulba.yaml deleted file mode 100644 index 3a7386969..000000000 --- a/mava/configs/default_ff_ippo_sebulba.yaml +++ /dev/null @@ -1,7 +0,0 @@ -defaults: - - logger: ff_ippo - - arch: sebulba - - system: ppo/ff_ippo - - network: mlp - - env: lbf_gym - - _self_ diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index b0d783a7e..b6c380c9e 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -1,7 +1,7 @@ # ---Environment Configs--- defaults: - _self_ - - scenario: gym-lbf-2s-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] + - scenario: gym-lbf-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] env_name: LevelBasedForaging # Used for logging purposes. 
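A quick way to sanity-check the new default config is to compose it with Hydra's Python API. The snippet below is only a sketch: it assumes it is run from the repository root so that the file://mava/configs searchpath resolves, and the override values are arbitrary examples. The same overrides can equally be passed on the command line to mava/systems/ppo/sebulba/ff_ippo.py.

from hydra import compose, initialize

# Compose mava/configs/default/ff_ippo_sebulba.yaml and override the env group,
# mirroring what the @hydra.main entry point does at launch time.
with initialize(config_path="mava/configs/default", version_base="1.2"):
    cfg = compose(
        config_name="ff_ippo_sebulba.yaml",
        overrides=["env=lbf_gym", "arch.num_envs=32"],
    )

print(cfg.env.env_name, cfg.arch.num_envs)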
diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 311bb263f..1ce40ac8c 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -704,7 +704,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../../configs", config_name="default_ff_ippo_sebulba.yaml", version_base="1.2" + config_path="../../../configs/default/", + config_name="ff_ippo_sebulba.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" From 8be803782724c33b466012072397762b24d0a6ac Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 11 Oct 2024 13:04:59 +0000 Subject: [PATCH 104/125] fix: reshape with multiple learners and system name --- mava/systems/ppo/sebulba/ff_ippo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 1ce40ac8c..8db82fdea 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -164,6 +164,8 @@ def get_learner_step_fn( """Get the learner function.""" num_agents, num_envs = config.system.num_agents, config.arch.num_envs + num_learner_envs = int(num_envs // len(config.arch.learner_device_ids)) + # Get apply and update functions for actor and critic networks. actor_apply_fn, critic_apply_fn = apply_fns @@ -206,7 +208,7 @@ def _get_advantages( return advantages, advantages + traj_batch.value # Calculate advantage - last_dones = jnp.repeat(learner_state.timestep.last(), num_agents).reshape(num_envs, -1) + last_dones = jnp.repeat(learner_state.timestep.last(), num_agents).reshape(num_learner_envs, -1) params, opt_states, key, _, _ = learner_state last_val = critic_apply_fn(params.critic_params, learner_state.timestep.observation) advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) @@ -327,9 +329,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) # Shuffle minibatches - batch_size = config.system.rollout_length * ( - config.arch.num_envs // len(config.arch.learner_device_ids) - ) + batch_size = config.system.rollout_length * num_learner_envs permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) batch = tree.map(lambda x: merge_leading_dims(x, 2), batch) @@ -712,6 +712,7 @@ def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. OmegaConf.set_struct(cfg, False) + cfg.logger.system_name = "ff_ippo_sebulba" # Run experiment. 
eval_performance = run_experiment(cfg) From 47486364921372f3a29b8cc8dd71df5de8137246 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 11 Oct 2024 16:27:07 +0200 Subject: [PATCH 105/125] fix: safer pipeline.clear() --- mava/utils/sebulba.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index eee211828..b9d95c7f5 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -142,7 +142,10 @@ def get( def clear(self) -> None: """Clear the pipeline.""" while not self._queue.empty(): - self._queue.get() + try: + self._queue.get(block=False) + except queue.Empty: + break def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) From 5593bde87a3aafb2f3cc7344ef87aa446f9637f1 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Mon, 14 Oct 2024 13:59:28 +0000 Subject: [PATCH 106/125] feat: avoid unecessary host-device transfers --- mava/systems/ppo/sebulba/ff_ippo.py | 113 +++++++++++++--------------- mava/utils/sebulba.py | 12 ++- mava/wrappers/gym.py | 37 ++++++++- 3 files changed, 98 insertions(+), 64 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 8db82fdea..326c94f35 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -74,13 +74,11 @@ def rollout( seeds (List[int]): Seeds for initializing the environment. thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. """ - # setup env = environments.make_gym_env(config, config.arch.num_envs) actor_apply_fn, critic_apply_fn = apply_fns num_agents, num_envs = config.system.num_agents, config.arch.num_envs move_to_device = lambda x: jax.device_put(x, device=actor_device) - # Define the util functions: select action function and prepare data to share it with learner. @jax.jit def act_fn( params: Params, @@ -96,62 +94,57 @@ def act_fn( return action, log_prob, value timestep = env.reset(seed=seeds) - next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) - - with jax.default_device(actor_device): - # Loop till the desired num_updates is reached. - while not thread_lifetime.should_stop(): - # Rollout - traj: List[PPOTransition] = [] - actor_timings: Dict[str, List[float]] = defaultdict(list) - # Loop over the rollout length - with RecordTimeTo(actor_timings["rollout_time"]): - for _ in range(config.system.rollout_length): - with RecordTimeTo(actor_timings["get_params_time"]): - # Get the latest parameters from the learner - params = params_source.get() - - cached_next_obs = tree.map(move_to_device, timestep.observation) - cached_next_dones = move_to_device(next_dones) - - # Get action and value - with RecordTimeTo(actor_timings["compute_action_time"]): - key, act_key = jax.random.split(key) - action, log_prob, value = act_fn(params, cached_next_obs, act_key) - cpu_action = jax.device_get(action) - - # Step environment - with RecordTimeTo(actor_timings["env_step_time"]): - # (num_env, num_agents) --> (num_agents, num_env) - timestep = env.step(cpu_action.swapaxes(0, 1)) - - next_dones = jnp.repeat(timestep.last(), num_agents).reshape(num_envs, -1) - - # Append data to storage - reward = timestep.reward - info = timestep.extras # todo: [metrics]? - # todo: when logging make sure timing dict has parent timing/... 
- traj.append( - PPOTransition( - cached_next_dones, - action, - value, - reward, - log_prob, - cached_next_obs, - info, - ) - ) - # send trajectories to learner - with RecordTimeTo(actor_timings["rollout_put_time"]): - try: - rollout_queue.put(traj, timestep, actor_timings) - except queue.Full: - warnings.warn( - "Waited too long to add to the rollout queue, killing the actor thread", - stacklevel=2, + next_dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + + # with jax.default_device(actor_device): + # Loop till the desired num_updates is reached. + while not thread_lifetime.should_stop(): + # Rollout + traj: List[PPOTransition] = [] + actor_timings: Dict[str, List[float]] = defaultdict(list) + with RecordTimeTo(actor_timings["rollout_time"]): + for _ in range(config.system.rollout_length): + with RecordTimeTo(actor_timings["get_params_time"]): + # Get the latest parameters from the learner + params = params_source.get() + + cached_next_obs = tree.map(move_to_device, timestep.observation) + cached_next_dones = move_to_device(next_dones) + + # Get action and value + with RecordTimeTo(actor_timings["compute_action_time"]): + key, act_key = jax.random.split(key) + action, log_prob, value = act_fn(params, cached_next_obs, act_key) + cpu_action = jax.device_get(action) + + # Step environment + with RecordTimeTo(actor_timings["env_step_time"]): + timestep = env.step(cpu_action.swapaxes(0, 1)) + + # todo: just for fixing transfer guard, real issue is the TimeStep.last() - need to make sebulba timestep type + next_dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + + # Append data to storage + # todo: when logging make sure timing dict has parent timing/... + traj.append( + PPOTransition( + cached_next_dones, + action, + value, + timestep.reward, + log_prob, + cached_next_obs, + timestep.extras, ) - break + ) + # send trajectories to learner + with RecordTimeTo(actor_timings["rollout_put_time"]): + try: + rollout_queue.put(traj, timestep, actor_timings) + except queue.Full: + err = "Waited too long to add to the rollout queue, killing the actor thread" + warnings.warn(err, stacklevel=2) + break env.close() @@ -619,7 +612,7 @@ def run_experiment(_config: DictConfig) -> float: args=(learn, learner_state, config, eval_queue, pipe, params_sources), ).start() - max_episode_return = -jnp.inf + max_episode_return = -np.inf best_params = inital_params.actor_params # This is the main loop, all it does is evaluation and logging. @@ -649,7 +642,7 @@ def run_experiment(_config: DictConfig) -> float: eval_metrics = evaluator(unreplicated_actor_params, eval_key, {}) logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) - episode_return = jnp.mean(eval_metrics["episode_return"]) + episode_return = np.mean(eval_metrics["episode_return"]) if save_checkpoint: # Save a checkpoint of the learner state checkpointer.save( @@ -663,7 +656,7 @@ def run_experiment(_config: DictConfig) -> float: max_episode_return = episode_return evaluator_envs.close() - eval_performance = float(jnp.mean(eval_metrics[config.env.eval_metric])) + eval_performance = float(np.mean(eval_metrics[config.env.eval_metric])) print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") # Make sure all of the Threads are stopped. diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index b9d95c7f5..22753de0c 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -13,6 +13,7 @@ # limitations under the License. 
+from functools import partial import queue import threading import time @@ -20,6 +21,7 @@ import jax import jax.numpy as jnp +import numpy as np from colorama import Fore, Style from jax import tree from jumanji.types import TimeStep @@ -68,6 +70,7 @@ def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: T lifetime: A `ThreadLifetime` which is used to stop this thread. """ super().__init__(name="Pipeline") + self.learner_devices = learner_devices self.tickets_queue: queue.Queue = queue.Queue() self._queue: queue.Queue = queue.Queue(maxsize=max_size) @@ -148,9 +151,14 @@ def clear(self) -> None: break def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: - split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) - return jax.device_put_sharded(split_payload, devices=self.learner_devices) + return self.shard_payload(self.split_payload(payload, axis)) + + @partial(jax.jit, static_argnums=(0, 2)) + def split_payload(self, payload: Any, axis: int = 0): + return jnp.split(payload, len(self.learner_devices), axis=axis) + def shard_payload(self, payload: Any): + return jax.device_put_sharded(payload, devices=self.learner_devices) class ParamsSource(threading.Thread): """A `ParamSource` is a component that allows networks params to be passed from a diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 2756b3511..0b2dff78d 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -17,23 +17,56 @@ import warnings from multiprocessing import Queue from multiprocessing.connection import Connection -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union, NamedTuple, TYPE_CHECKING +from dataclasses import field import gymnasium import gymnasium.vector.async_vector_env import numpy as np from gymnasium import spaces from gymnasium.spaces.utils import is_space_dtype_shape_equiv from gymnasium.vector.utils import write_to_shared_memory -from jumanji.types import StepType, TimeStep from numpy.typing import NDArray from mava.types import Observation, ObservationGlobalState +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + # Filter out the warnings warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") +# needed to avoid host -> device transfers when calling TimeStep.last() +class StepType: + """Coppy of Jumanji's step type but with numpy arrays""" + + FIRST = 0 + MID = 1 + LAST = 2 + + +@dataclass +class TimeStep: + step_type: StepType + reward: NDArray + discount: NDArray + observation: Observation + extras: Dict = field(default_factory=dict) + + + def first(self) -> bool: + return self.step_type == StepType.FIRST + + def mid(self) -> bool: + return self.step_type == StepType.MID + + def last(self) -> bool: + return self.step_type == StepType.LAST + + class GymWrapper(gymnasium.Wrapper): """Base wrapper for multi-agent gym environments. This wrapper works out of the box for RobotWarehouse. 
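For readers unfamiliar with jax.device_put_sharded, the stand-alone sketch below shows what the split_payload/shard_payload pair above does to one field of a trajectory: split along the env axis so that every learner device receives full-length rollouts for its own slice of environments. The shapes are dummy values, and an even split is assumed, which is what check_sebulba_config enforces in the real system.

import jax
import jax.numpy as jnp

learner_devices = jax.devices()  # in Mava these come from arch.learner_device_ids

# Dummy (rollout_len, num_envs, num_agents) array standing in for e.g. the rewards.
payload = jnp.ones((128, 16, 4))

# Split on axis 1 (the env axis) and place one shard on each learner device.
shards = jnp.split(payload, len(learner_devices), axis=1)
sharded = jax.device_put_sharded(shards, devices=learner_devices)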
From 133ea1ad1cf00a4c1f58809835111d95a3f4ee02 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Mon, 14 Oct 2024 16:02:52 +0000 Subject: [PATCH 107/125] chore: remove some more device transfers --- mava/systems/ppo/sebulba/ff_ippo.py | 4 +--- mava/wrappers/episode_metrics.py | 5 +++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 326c94f35..cca138205 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -121,11 +121,9 @@ def act_fn( with RecordTimeTo(actor_timings["env_step_time"]): timestep = env.step(cpu_action.swapaxes(0, 1)) - # todo: just for fixing transfer guard, real issue is the TimeStep.last() - need to make sebulba timestep type next_dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) # Append data to storage - # todo: when logging make sure timing dict has parent timing/... traj.append( PPOTransition( cached_next_dones, @@ -623,7 +621,7 @@ def run_experiment(_config: DictConfig) -> float: episode_metrics, train_metrics, learner_state, time_metrics = eval_queue.get() t = int(steps_per_rollout * (eval_step + 1)) - time_metrics["timestep"] = t + time_metrics |= {"timestep": t, "pipline_size": pipe.qsize()} logger.log(time_metrics, t, eval_step, LogEvent.MISC) episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) diff --git a/mava/wrappers/episode_metrics.py b/mava/wrappers/episode_metrics.py index e9e130819..63d65e35e 100644 --- a/mava/wrappers/episode_metrics.py +++ b/mava/wrappers/episode_metrics.py @@ -20,6 +20,7 @@ from jax import tree from jumanji.types import TimeStep from jumanji.wrappers import Wrapper +import numpy as np from mava.types import MarlEnv, State @@ -120,12 +121,12 @@ def get_final_step_metrics(metrics: Dict[str, chex.Array]) -> Tuple[Dict[str, ch expects arrays for computing summary statistics on the episode metrics. """ is_final_ep = metrics.pop("is_terminal_step") - has_final_ep_step = bool(jnp.any(is_final_ep)) + has_final_ep_step = bool(np.any(is_final_ep)) final_metrics: Dict[str, chex.Array] # If it didn't make it to the final step, return zeros. if not has_final_ep_step: - final_metrics = tree.map(jnp.zeros_like, metrics) + final_metrics = tree.map(np.zeros_like, metrics) else: final_metrics = tree.map(lambda x: x[is_final_ep], metrics) From 9260e9b52080d434da599a7d3536f2832ccb8a1c Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Mon, 14 Oct 2024 19:38:11 +0000 Subject: [PATCH 108/125] chore: better graceful exit --- mava/systems/ppo/sebulba/ff_ippo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index cca138205..75409944c 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -659,10 +659,12 @@ def run_experiment(_config: DictConfig) -> float: print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") # Make sure all of the Threads are stopped. 
actor_lifetime.stop() + # We clear the pipeline before stopping the actor threads to avoid deadlock + pipe.clear() + print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared: {pipe.qsize()}{Style.RESET_ALL}") for actor in actor_threads: - # We clear the pipeline before stopping each actor thread to avoid deadlock - pipe.clear() actor.join() + print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}") print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") pipe_lifetime.stop() From d61dcfb4decc6790f2d8383cf80dea9601fef45c Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 15 Oct 2024 12:58:17 +0000 Subject: [PATCH 109/125] fix: create envs in main thread to avoid deadlocks --- mava/systems/ppo/sebulba/ff_ippo.py | 62 ++++++++++++++++++----------- mava/utils/logger.py | 1 + 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 75409944c..5208bc312 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -53,6 +53,7 @@ def rollout( key: chex.PRNGKey, + env, config: DictConfig, rollout_queue: Pipeline, params_source: ParamsSource, @@ -74,7 +75,8 @@ def rollout( seeds (List[int]): Seeds for initializing the environment. thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. """ - env = environments.make_gym_env(config, config.arch.num_envs) + name = threading.current_thread().name + print(f"{Fore.BLUE}{Style.BRIGHT}Thread {name} started{Style.RESET_ALL}") actor_apply_fn, critic_apply_fn = apply_fns num_agents, num_envs = config.system.num_agents, config.arch.num_envs move_to_device = lambda x: jax.device_put(x, device=actor_device) @@ -96,7 +98,6 @@ def act_fn( timestep = env.reset(seed=seeds) next_dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) - # with jax.default_device(actor_device): # Loop till the desired num_updates is reached. 
while not thread_lifetime.should_stop(): # Rollout @@ -104,6 +105,10 @@ def act_fn( actor_timings: Dict[str, List[float]] = defaultdict(list) with RecordTimeTo(actor_timings["rollout_time"]): for _ in range(config.system.rollout_length): + # if thread_lifetime.should_stop(): + # env.close() + # return + with RecordTimeTo(actor_timings["get_params_time"]): # Get the latest parameters from the learner params = params_source.get() @@ -135,6 +140,7 @@ def act_fn( timestep.extras, ) ) + # send trajectories to learner with RecordTimeTo(actor_timings["rollout_put_time"]): try: @@ -574,8 +580,17 @@ def run_experiment(_config: DictConfig) -> float: actor_lifetime = ThreadLifetime() params_sources_lifetime = ThreadLifetime() + # Unfortunately we have to do this here, because creating envs inside the actor threads causes deadlocks + envs = [[] for i in range(len(actor_devices))] + print(f"{Fore.BLUE}{Style.BRIGHT}Starting up environments, this may take a while...{Style.RESET_ALL}") + for i in range(len(actor_devices)): + for _ in range(config.arch.n_threads_per_executor): + env = environments.make_gym_env(config, config.arch.num_envs) + envs[i].append(env) + print(f"{Fore.BLUE}{Style.BRIGHT}All environments created{Style.RESET_ALL}") + # Create the actor threads - for actor_device in actor_devices: + for dev_idx, actor_device in enumerate(actor_devices): # Create 1 params source per device params_source = ParamsSource(inital_params, actor_device, params_sources_lifetime) params_source.start() @@ -590,6 +605,7 @@ def run_experiment(_config: DictConfig) -> float: target=rollout, args=( act_key, + envs[dev_idx][thread_id], config, pipe, params_source, @@ -656,26 +672,6 @@ def run_experiment(_config: DictConfig) -> float: evaluator_envs.close() eval_performance = float(np.mean(eval_metrics[config.env.eval_metric])) - print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") - # Make sure all of the Threads are stopped. - actor_lifetime.stop() - # We clear the pipeline before stopping the actor threads to avoid deadlock - pipe.clear() - print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared: {pipe.qsize()}{Style.RESET_ALL}") - for actor in actor_threads: - actor.join() - print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}") - - print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") - pipe_lifetime.stop() - pipe.join() - - print(f"{Fore.MAGENTA}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}") - params_sources_lifetime.stop() - for params_source in params_sources: - params_source.join() - - print(f"{Fore.MAGENTA}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}") # Measure absolute metric. if config.arch.absolute_metric: @@ -692,6 +688,26 @@ def run_experiment(_config: DictConfig) -> float: # Stop the logger. 
logger.stop() + # Ask actors to stop before running the evaluator + actor_lifetime.stop() + # We clear the pipeline before stopping the actor threads to avoid deadlock + pipe.clear() + print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared: {pipe.qsize()}{Style.RESET_ALL}") + + print(f"{Fore.RED}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") + for actor in actor_threads: + actor.join() + print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}") + + print(f"{Fore.RED}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") + pipe_lifetime.stop() + pipe.join() + + print(f"{Fore.RED}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}") + params_sources_lifetime.stop() + for params_source in params_sources: + params_source.join() + print(f"{Fore.RED}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}") return eval_performance diff --git a/mava/utils/logger.py b/mava/utils/logger.py index d7af26402..bd090604b 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -178,6 +178,7 @@ def log_stat(self, key: str, value: float, step: int, eval_step: int, event: Log if not self.detailed_logging and not is_main_metric: return + value = value.item() if isinstance(value, (jax.Array, np.ndarray)) else value self.logger[f"{event.value}/{key}"].log(value, step=step) def stop(self) -> None: From 105d796a454a99a4a5d0ab2cbc67f16b33944a25 Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Tue, 15 Oct 2024 19:20:50 +0100 Subject: [PATCH 110/125] chore: use orginal rware and lbf --- mava/systems/ppo/sebulba/ff_ippo.py | 12 +++++++----- mava/utils/make_env.py | 3 +-- mava/utils/sebulba.py | 4 ++-- mava/wrappers/__init__.py | 1 - mava/wrappers/episode_metrics.py | 2 +- mava/wrappers/gym.py | 25 +++++-------------------- requirements/requirements.txt | 4 ++-- 7 files changed, 18 insertions(+), 33 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 5208bc312..2daaf30e7 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -146,7 +146,7 @@ def act_fn( try: rollout_queue.put(traj, timestep, actor_timings) except queue.Full: - err = "Waited too long to add to the rollout queue, killing the actor thread" + err = "Waited too long to add to the rollout queue, killing the actor thread" warnings.warn(err, stacklevel=2) break @@ -162,7 +162,6 @@ def get_learner_step_fn( num_agents, num_envs = config.system.num_agents, config.arch.num_envs num_learner_envs = int(num_envs // len(config.arch.learner_device_ids)) - # Get apply and update functions for actor and critic networks. 
actor_apply_fn, critic_apply_fn = apply_fns @@ -205,7 +204,9 @@ def _get_advantages( return advantages, advantages + traj_batch.value # Calculate advantage - last_dones = jnp.repeat(learner_state.timestep.last(), num_agents).reshape(num_learner_envs, -1) + last_dones = jnp.repeat(learner_state.timestep.last(), num_agents).reshape( + num_learner_envs, -1 + ) params, opt_states, key, _, _ = learner_state last_val = critic_apply_fn(params.critic_params, learner_state.timestep.observation) advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) @@ -582,7 +583,9 @@ def run_experiment(_config: DictConfig) -> float: # Unfortunately we have to do this here, because creating envs inside the actor threads causes deadlocks envs = [[] for i in range(len(actor_devices))] - print(f"{Fore.BLUE}{Style.BRIGHT}Starting up environments, this may take a while...{Style.RESET_ALL}") + print( + f"{Fore.BLUE}{Style.BRIGHT}Starting up environments, this may take a while...{Style.RESET_ALL}" + ) for i in range(len(actor_devices)): for _ in range(config.arch.n_threads_per_executor): env = environments.make_gym_env(config, config.arch.num_envs) @@ -672,7 +675,6 @@ def run_experiment(_config: DictConfig) -> float: evaluator_envs.close() eval_performance = float(np.mean(eval_metrics[config.env.eval_metric])) - # Measure absolute metric. if config.arch.absolute_metric: print(f"{Fore.BLUE}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}") diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index a5010307a..1d71ddce0 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -46,7 +46,6 @@ ConnectorWrapper, GigastepWrapper, GymAgentIDWrapper, - GymLBFWrapper, GymRecordEpisodeMetrics, GymToJumanji, GymWrapper, @@ -78,7 +77,7 @@ _gym_registry = { "RobotWarehouse": (gym_Warehouse, GymWrapper), - "LevelBasedForaging": (gym_ForagingEnv, GymLBFWrapper), + "LevelBasedForaging": (gym_ForagingEnv, GymWrapper), } diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index 22753de0c..cead3b6ba 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -13,15 +13,14 @@ # limitations under the License. -from functools import partial import queue import threading import time +from functools import partial from typing import Any, Dict, List, Sequence, Tuple, Union import jax import jax.numpy as jnp -import numpy as np from colorama import Fore, Style from jax import tree from jumanji.types import TimeStep @@ -160,6 +159,7 @@ def split_payload(self, payload: Any, axis: int = 0): def shard_payload(self, payload: Any): return jax.device_put_sharded(payload, devices=self.learner_devices) + class ParamsSource(threading.Thread): """A `ParamSource` is a component that allows networks params to be passed from a `Learner` component to `Actor` components. 
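The shutdown sequence introduced in the patches above (stop the actor lifetime, clear the pipe, join the actor threads, then stop and join the pipeline and params sources) relies on each worker loop polling a shared stop flag. Below is a minimal sketch of that lifetime pattern with a toy worker; it is not Mava's ThreadLifetime or rollout code, and the sleep simply stands in for one rollout iteration.

import threading
import time

class Lifetime:
    """Toy version of the mutable stop flag polled by the actor loops."""

    def __init__(self) -> None:
        self._stop = False

    def should_stop(self) -> bool:
        return self._stop

    def stop(self) -> None:
        self._stop = True

def worker(lifetime: Lifetime) -> None:
    while not lifetime.should_stop():
        time.sleep(0.01)  # stand-in for one rollout iteration

lifetime = Lifetime()
thread = threading.Thread(target=worker, args=(lifetime,), name="Actor-0")
thread.start()
lifetime.stop()  # ask the worker to exit after its current iteration
thread.join()  # returns once the loop observes the flag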
diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index a7b56c5da..f8cf8a64c 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -18,7 +18,6 @@ from mava.wrappers.gigastep import GigastepWrapper from mava.wrappers.gym import ( GymAgentIDWrapper, - GymLBFWrapper, GymRecordEpisodeMetrics, GymToJumanji, GymWrapper, diff --git a/mava/wrappers/episode_metrics.py b/mava/wrappers/episode_metrics.py index 63d65e35e..f4c34002e 100644 --- a/mava/wrappers/episode_metrics.py +++ b/mava/wrappers/episode_metrics.py @@ -17,10 +17,10 @@ import chex import jax import jax.numpy as jnp +import numpy as np from jax import tree from jumanji.types import TimeStep from jumanji.wrappers import Wrapper -import numpy as np from mava.types import MarlEnv, State diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 0b2dff78d..39870b211 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -15,11 +15,11 @@ import sys import traceback import warnings +from dataclasses import field from multiprocessing import Queue from multiprocessing.connection import Connection -from typing import Any, Callable, Dict, Optional, Tuple, Union, NamedTuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union -from dataclasses import field import gymnasium import gymnasium.vector.async_vector_env import numpy as np @@ -56,7 +56,6 @@ class TimeStep: observation: Observation extras: Dict = field(default_factory=dict) - def first(self) -> bool: return self.step_type == StepType.FIRST @@ -69,8 +68,7 @@ def last(self) -> bool: class GymWrapper(gymnasium.Wrapper): """Base wrapper for multi-agent gym environments. - This wrapper works out of the box for RobotWarehouse. - See `GymLBFWrapper` for how it can be modified to work for other environments. + This wrapper works out of the box for RobotWarehouse and level based foraging. 
""" def __init__( @@ -131,18 +129,6 @@ def get_global_obs(self, obs: NDArray) -> NDArray: return np.tile(global_obs, (self.num_agents, 1)) -class GymLBFWrapper(GymWrapper): - """Wrapper for the gym level based foraging environment.""" - - def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: - agents_view, reward, terminated, truncated, info = super().step(actions) - - truncated = np.repeat(truncated, self.num_agents) - terminated = np.repeat(terminated, self.num_agents) - - return agents_view, reward, terminated, truncated, info - - class GymRecordEpisodeMetrics(gymnasium.Wrapper): """Record the episode returns and lengths.""" @@ -247,7 +233,7 @@ def reset( ep_done = np.zeros(num_envs, dtype=float) rewards = np.zeros((num_envs, num_agents), dtype=float) - teminated = np.zeros((num_envs, num_agents), dtype=float) + teminated = np.zeros(num_envs, dtype=float) timestep = self._create_timestep(obs, ep_done, teminated, rewards, info) @@ -256,7 +242,7 @@ def reset( def step(self, action: list) -> TimeStep: obs, rewards, terminated, truncated, info = self.env.step(action) - ep_done = np.logical_or(terminated, truncated).all(axis=1) + ep_done = np.logical_or(terminated, truncated) timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) @@ -286,7 +272,6 @@ def _create_timestep( # Filter out the masks and auxiliary data extras = {key: value for key, value in info["metrics"].items() if key[0] != "_"} step_type = np.where(ep_done, StepType.LAST, StepType.MID) - terminated = np.all(terminated, axis=1) return TimeStep( step_type=step_type, diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 71432102f..61f7fe68a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,7 +11,7 @@ jax==0.4.30 jaxlib==0.4.30 jaxmarl jumanji @ git+https://github.com/sash-a/jumanji@old_jumanji # Includes a few extra MARL envs -lbforaging @ git+https://github.com/LukasSchaefer/lb-foraging.git@gymnasium_integration # fixes: https://github.com/semitable/lb-foraging/issues/20 +lbforaging matrax @ git+https://github.com/instadeepai/matrax mujoco==3.1.3 mujoco-mjx==3.1.3 @@ -20,7 +20,7 @@ numpy==1.26.4 omegaconf optax protobuf~=3.20 -rware @ git+https://github.com/RuanJohn/robotic-warehouse.git # compatibility with latest gymnasium +rware scipy==1.12.0 tensorboard_logger tensorflow_probability From f292bf303d42e66eb28775bbf6f4a9d52f6f338c Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Wed, 16 Oct 2024 12:48:27 +0200 Subject: [PATCH 111/125] fix: possible off by one fix --- mava/systems/ppo/sebulba/ff_ippo.py | 51 ++++++++++++++--------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 2daaf30e7..b0a74f716 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -81,6 +81,8 @@ def rollout( num_agents, num_envs = config.system.num_agents, config.arch.num_envs move_to_device = lambda x: jax.device_put(x, device=actor_device) + key = move_to_device(key) + @jax.jit def act_fn( params: Params, @@ -96,7 +98,7 @@ def act_fn( return action, log_prob, value timestep = env.reset(seed=seeds) - next_dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) # Loop till the desired num_updates is reached. 
while not thread_lifetime.should_stop(): @@ -105,38 +107,33 @@ def act_fn( actor_timings: Dict[str, List[float]] = defaultdict(list) with RecordTimeTo(actor_timings["rollout_time"]): for _ in range(config.system.rollout_length): - # if thread_lifetime.should_stop(): - # env.close() - # return - with RecordTimeTo(actor_timings["get_params_time"]): # Get the latest parameters from the learner params = params_source.get() - cached_next_obs = tree.map(move_to_device, timestep.observation) - cached_next_dones = move_to_device(next_dones) + obs_tpu = tree.map(move_to_device, timestep.observation) # Get action and value with RecordTimeTo(actor_timings["compute_action_time"]): key, act_key = jax.random.split(key) - action, log_prob, value = act_fn(params, cached_next_obs, act_key) + action, log_prob, value = act_fn(params, obs_tpu, act_key) cpu_action = jax.device_get(action) # Step environment with RecordTimeTo(actor_timings["env_step_time"]): timestep = env.step(cpu_action.swapaxes(0, 1)) - next_dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) # Append data to storage traj.append( PPOTransition( - cached_next_dones, + dones, action, value, timestep.reward, log_prob, - cached_next_obs, + obs_tpu, timestep.extras, ) ) @@ -182,21 +179,24 @@ def _update_step( """ def _calculate_gae( - traj_batch: PPOTransition, last_val: chex.Array, last_done: chex.Array + traj_batch: PPOTransition, last_val: chex.Array ) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: PPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry + """Calculate the GAE.""" + + gamma, gae_lambda = config.system.gamma, config.system.gae_lambda + + def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: + """Calculate the GAE for a single transition.""" + gae, next_value = gae_and_next_value done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae + + delta = reward + gamma * next_value * (1 - done) - value + gae = delta + gamma * gae_lambda * (1 - done) * gae + return (gae, value), gae _, advantages = jax.lax.scan( _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), + (jnp.zeros_like(last_val), last_val), traj_batch, reverse=True, unroll=16, @@ -204,12 +204,9 @@ def _get_advantages( return advantages, advantages + traj_batch.value # Calculate advantage - last_dones = jnp.repeat(learner_state.timestep.last(), num_agents).reshape( - num_learner_envs, -1 - ) - params, opt_states, key, _, _ = learner_state - last_val = critic_apply_fn(params.critic_params, learner_state.timestep.observation) - advantages, targets = _calculate_gae(traj_batch, last_val, last_dones) + params, opt_states, key, _, final_timestep = learner_state + last_val = critic_apply_fn(params.critic_params, final_timestep.observation) + advantages, targets = _calculate_gae(traj_batch, last_val) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" From d42d7328bea97c1fd81faf17a1ef296b78385b2e Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Wed, 16 Oct 2024 16:26:05 +0200 Subject: [PATCH 112/125] fix: change to using gym.make to create envs and 
fix StepType --- mava/configs/default/ff_ippo_sebulba.yaml | 2 +- mava/configs/env/lbf_gym.yaml | 6 ++++-- mava/configs/env/rware_gym.yaml | 4 +++- .../env/scenario/gym-lbf-10x10-3p-3f.yaml | 18 ------------------ .../env/scenario/gym-lbf-15x15-3p-5f.yaml | 18 ------------------ .../env/scenario/gym-lbf-15x15-4p-3f.yaml | 18 ------------------ .../env/scenario/gym-lbf-15x15-4p-5f.yaml | 18 ------------------ .../env/scenario/gym-lbf-2s-10x10-3p-3f.yaml | 18 ------------------ .../scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml | 18 ------------------ .../env/scenario/gym-lbf-8x8-2p-2f-coop.yaml | 18 ------------------ .../env/scenario/gym-rware-small-4ag.yaml | 18 ------------------ .../env/scenario/gym-rware-tiny-2ag.yaml | 18 ------------------ .../env/scenario/gym-rware-tiny-4ag-easy.yaml | 18 ------------------ .../env/scenario/gym-rware-tiny-4ag.yaml | 18 ------------------ mava/utils/make_env.py | 12 ++++++------ mava/wrappers/gym.py | 9 ++++++--- 16 files changed, 20 insertions(+), 211 deletions(-) delete mode 100644 mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml delete mode 100644 mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml delete mode 100644 mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml delete mode 100644 mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml delete mode 100644 mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml delete mode 100644 mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml delete mode 100644 mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml delete mode 100644 mava/configs/env/scenario/gym-rware-small-4ag.yaml delete mode 100644 mava/configs/env/scenario/gym-rware-tiny-2ag.yaml delete mode 100644 mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml delete mode 100644 mava/configs/env/scenario/gym-rware-tiny-4ag.yaml diff --git a/mava/configs/default/ff_ippo_sebulba.yaml b/mava/configs/default/ff_ippo_sebulba.yaml index babd113ee..7669049b1 100644 --- a/mava/configs/default/ff_ippo_sebulba.yaml +++ b/mava/configs/default/ff_ippo_sebulba.yaml @@ -3,7 +3,7 @@ defaults: - arch: sebulba - system: ppo/ff_ippo - network: mlp # [mlp, continuous_mlp, cnn] - - env: lbf_gym # [rware_gym, lbf_gym] + - env: rware_gym # [rware_gym, lbf_gym] - _self_ hydra: diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index b6c380c9e..39d624daa 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -1,16 +1,18 @@ # ---Environment Configs--- defaults: - _self_ - - scenario: gym-lbf-8x8-2p-2f-coop # [gym-lbf-2s-8x8-2p-2f-coop, gym-lbf-8x8-2p-2f-coop, gym-lbf-2s-10x10-3p-3f, gym-lbf-10x10-3p-3f, gym-lbf-15x15-3p-5f, gym-lbf-15x15-4p-3f, gym-lbf-15x15-4p-5f] env_name: LevelBasedForaging # Used for logging purposes. +scenario: + name: lbforaging + task_name: Foraging-8x8-2p-1f-v3 # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. eval_metric: episode_return # Whether the add agents IDs to the observations returned by the environment. -add_agent_id : False +add_agent_id: False # Whether or not to log the winrate of this environment. 
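For reference, a hedged sketch of how a scenario block like the one above is resolved into an environment by `make_gym_env` later in this patch. It assumes the package named before the colon (here `lbforaging`) registers its tasks when imported, which is what the `"module:task_name"` form of `gym.make` relies on.

```python
# Sketch only: mirrors the gym.make call added in mava/utils/make_env.py below.
import gymnasium as gym  # the patch imports gymnasium as gym

name, task_name = "lbforaging", "Foraging-8x8-2p-1f-v3"  # values from the scenario block above
registered_name = f"{name}:{task_name}"                  # "module:env_id" form
env = gym.make(registered_name, disable_env_checker=False)
```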
log_win_rate: False diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml index 87bd3a473..da8c73402 100644 --- a/mava/configs/env/rware_gym.yaml +++ b/mava/configs/env/rware_gym.yaml @@ -1,9 +1,11 @@ # ---Environment Configs--- defaults: - _self_ - - scenario: gym-rware-tiny-2ag # [gym-rware-tiny-2ag, gym-rware-tiny-4ag, gym-rware-tiny-4ag-easy, gym-rware-small-4ag] env_name: RobotWarehouse # Used for logging purposes. +scenario: + name: rware + task_name: rware-tiny-2ag-v2 # [rware-tiny-2ag-v2, rware-tiny-4ag-v2, rware-tiny-4ag-easy-v2, rware-small-4ag-v2] # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. diff --git a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml deleted file mode 100644 index a2150115b..000000000 --- a/mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the 10x10-3p-3f scenario with the VectorObserver set as default -name: LevelBasedForaging -task_name: 10x10-3p-3f - -task_config: - field_size: [10,10] - sight: 10 - players: 3 - max_num_food: 3 - max_player_level: 2 - force_coop: False - max_episode_steps: 100 - min_player_level : 1 - min_food_level : null - max_food_level : null - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml deleted file mode 100644 index 70031bad0..000000000 --- a/mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the 15x15-3p-5f scenario with the VectorObserver set as default -name: LevelBasedForaging -task_name: 15x15-3p-5f - -task_config: - field_size: [15, 15] - sight: 15 - players: 3 - max_num_food: 5 - max_player_level: 2 - force_coop: False - max_episode_steps: 100 - min_player_level : 1 - min_food_level : null - max_food_level : null - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml deleted file mode 100644 index b1fe6e4be..000000000 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the 15x15-4p-3f scenario with the VectorObserver set as default -name: LevelBasedForaging -task_name: 15x15-4p-3f - -task_config: - field_size: [15, 15] - sight: 15 - players: 4 - max_num_food: 3 - max_player_level: 2 - force_coop: False - max_episode_steps: 100 - min_player_level : 1 - min_food_level : null - max_food_level : null - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml b/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml deleted file mode 100644 index 9ce0100f5..000000000 --- a/mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the 15x15-4p-5f scenario with the VectorObserver set as default -name: LevelBasedForaging -task_name: 15x15-4p-5f - -task_config: - field_size: [15, 15] - sight: 15 - players: 4 - max_num_food: 5 - max_player_level: 2 - force_coop: False - max_episode_steps: 100 - min_player_level : 1 - min_food_level : null - max_food_level : null - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git 
a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml b/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml deleted file mode 100644 index fea817887..000000000 --- a/mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the 2s10x10-3p-3f scenario with the VectorObserver set as default -name: LevelBasedForaging -task_name: 2s-10x10-3p-3f - -task_config: - field_size: [10, 10] - sight: 2 - players: 3 - max_num_food: 3 - max_player_level: 2 - force_coop: False - max_episode_steps: 100 - min_player_level : 1 - min_food_level : null - max_food_level : null - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml deleted file mode 100644 index b0cacb95c..000000000 --- a/mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the 2s-8x8-2p-2f-coop scenario with the VectorObserver set as default. -name: LevelBasedForaging -task_name: 2s-8x8-2p-2f-coop - -task_config: - field_size: [8, 8] # size of the grid to generate. - sight: 2 # field of view of an agent. - players: 2 # number of agents on the grid. - max_num_food: 2 # number of food in the environment. - max_player_level: 2 # maximum level of the agents (inclusive). - force_coop: True # force cooperation between agents. - max_episode_steps: 100 # max number of steps per episode. - min_player_level : 1 # minimum level of the agents (inclusive). - min_food_level : null - max_food_level : null - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml b/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml deleted file mode 100644 index 3b9cee314..000000000 --- a/mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the 8x8-2p-2f-coop scenario with the VectorObserver set as default -name: LevelBasedForaging -task_name: 8x8-2p-2f-coop - -task_config: - field_size: [8, 8] - sight: 8 - players: 2 - max_num_food: 2 - max_player_level: 2 - force_coop: True - max_episode_steps: 100 - min_player_level : 1 - min_food_level : null - max_food_level : null - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-small-4ag.yaml b/mava/configs/env/scenario/gym-rware-small-4ag.yaml deleted file mode 100644 index 39f8efa4e..000000000 --- a/mava/configs/env/scenario/gym-rware-small-4ag.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the small-4ag environment -name: RobotWarehouse -task_name: small-4ag - -task_config: - column_height: 8 - shelf_rows: 2 - shelf_columns: 3 - n_agents: 4 - sensor_range: 1 - request_queue_size: 4 - msg_bits : 0 - max_inactivity_steps : null - max_steps : 500 - reward_type : 0 - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml deleted file mode 100644 index 95ef11fc2..000000000 --- a/mava/configs/env/scenario/gym-rware-tiny-2ag.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the tiny-2ag environment -name: RobotWarehouse -task_name: tiny-2ag - -task_config: - column_height: 8 - shelf_rows: 1 - shelf_columns: 3 - n_agents: 2 - sensor_range: 1 - request_queue_size: 2 - msg_bits : 0 - max_inactivity_steps : null - max_steps : 500 - reward_type : 
0 - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml deleted file mode 100644 index 7753b73ec..000000000 --- a/mava/configs/env/scenario/gym-rware-tiny-4ag-easy.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the tiny-4ag-easy environment -name: RobotWarehouse -task_name: tiny-4ag-easy - -task_config: - column_height: 8 - shelf_rows: 1 - shelf_columns: 3 - n_agents: 4 - sensor_range: 1 - request_queue_size: 8 - msg_bits : 0 - max_inactivity_steps : null - max_steps : 500 - reward_type : 0 - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml b/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml deleted file mode 100644 index c28cf92c5..000000000 --- a/mava/configs/env/scenario/gym-rware-tiny-4ag.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The config of the tiny_4ag environment -name: RobotWarehouse -task_name: tiny-4ag - -task_config: - column_height: 8 - shelf_rows: 1 - shelf_columns: 3 - n_agents: 4 - sensor_range: 1 - request_queue_size: 4 - msg_bits : 0 - max_inactivity_steps : null - max_steps : 500 - reward_type : 0 - -env_kwargs: - {} # there are no scenario specific env_kwargs for this env diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 1d71ddce0..1c9e4dbd0 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -15,6 +15,7 @@ from typing import Dict, Tuple, Type import gymnasium +import gymnasium as gym import gymnasium.vector import gymnasium.wrappers import jaxmarl @@ -34,9 +35,7 @@ from jumanji.environments.routing.robot_warehouse.generator import ( RandomGenerator as RwareRandomGenerator, ) -from lbforaging.foraging import ForagingEnv as gym_ForagingEnv from omegaconf import DictConfig -from rware.warehouse import Warehouse as gym_Warehouse from mava.types import MarlEnv from mava.wrappers import ( @@ -76,8 +75,8 @@ _gigastep_registry = {"Gigastep": GigastepWrapper} _gym_registry = { - "RobotWarehouse": (gym_Warehouse, GymWrapper), - "LevelBasedForaging": (gym_ForagingEnv, GymWrapper), + "RobotWarehouse": GymWrapper, + "LevelBasedForaging": GymWrapper, } @@ -243,10 +242,11 @@ def make_gym_env( Returns: Async environments. 
""" - env_maker, wrapper = _gym_registry[config.env.scenario.name] + wrapper = _gym_registry[config.env.env_name] def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnasium.Env: - env = env_maker(**config.env.scenario.task_config) + registered_name = f"{config.env.scenario.name}:{config.env.scenario.task_name}" + env = gym.make(registered_name, disable_env_checker=False) wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) if config.env.add_agent_id: wrapped_env = GymAgentIDWrapper(wrapped_env) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 39870b211..a27b246ce 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -16,6 +16,7 @@ import traceback import warnings from dataclasses import field +from enum import IntEnum from multiprocessing import Queue from multiprocessing.connection import Connection from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union @@ -40,7 +41,7 @@ # needed to avoid host -> device transfers when calling TimeStep.last() -class StepType: +class StepType(IntEnum): """Coppy of Jumanji's step type but with numpy arrays""" FIRST = 0 @@ -53,7 +54,7 @@ class TimeStep: step_type: StepType reward: NDArray discount: NDArray - observation: Observation + observation: Union[Observation, ObservationGlobalState] extras: Dict = field(default_factory=dict) def first(self) -> bool: @@ -94,7 +95,9 @@ def __init__( def reset( self, seed: Optional[int] = None, options: Optional[dict] = None ) -> Tuple[NDArray, Dict]: - if seed is not None: + # todo: maybe we should just remove this? I think the hasattr could be slow and the + # `OrderEnforcingWrapper` blocks the seed call :/ + if seed is not None and hasattr(self.env, "seed"): self.env.seed(seed) agents_view, info = self._env.reset() From d4359c1cf6ac91415f8f3ae64a89959b4c317139 Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Thu, 17 Oct 2024 13:52:44 +0100 Subject: [PATCH 113/125] feat: learner env accumulation --- mava/configs/arch/sebulba.yaml | 1 + mava/systems/ppo/sebulba/ff_ippo.py | 31 +++++++++++++++++++++-------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml index d8f44fd3c..278b0592d 100644 --- a/mava/configs/arch/sebulba.yaml +++ b/mava/configs/arch/sebulba.yaml @@ -18,6 +18,7 @@ absolute_metric: True # Whether the absolute metric should be computed. For more n_threads_per_executor: 2 # num of different threads/env batches per actor actor_device_ids: [0] # ids of actor devices learner_device_ids: [0] # ids of learner devices +n_learner_accumulate: 1 # Number of envoirnments to accumulate before updating the parameters. This determines the num_envs for learning updates which equals (num_envs * n_learner_accumulate) / len(learner_device_ids). rollout_queue_size : 5 # The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training. 
# Too large of a value with too many actors will lead to all of the updates getting wasted in old episodes diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index b0a74f716..a0026d95c 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -396,23 +396,38 @@ def learner_thread( with RecordTimeTo(learn_times["learner_time_per_eval"]): for _ in range(config.system.num_updates_per_eval): - # Get the trajectory batch from the pipeline - # This is blocking so it will wait until the pipeline has data. - with RecordTimeTo(learn_times["rollout_get_time"]): - traj_batch, timestep, rollout_time = pipeline.get(block=True) + # Accumulate the batches, timesteps, and rollout times + accumulated_traj_batches = [] + accumulated_timesteps = [] + + for _ in range(config.arch.n_learner_accumulate): + # Get the trajectory batch from the pipeline + # This is blocking so it will wait until the pipeline has data. + with RecordTimeTo(learn_times["rollout_get_time"]): + traj_batch, timestep, rollout_time = pipeline.get(block=True) + + # Store the retrieved data + accumulated_traj_batches.append(traj_batch) + accumulated_timesteps.append(timestep) + rollout_times.append(rollout_time) + + # Concatenate accumulated timesteps and trajectory batches on the num_envs axis + combined_traj_batch = jax.tree.map(lambda *x: jnp.concat(x, axis=2), *accumulated_traj_batches) + combined_timesteps = jax.tree.map(lambda *x: jnp.concat(x, axis=1), *accumulated_timesteps) + # Replace the timestep in the learner state with the latest timestep # This means the learner has access to the entire trajectory as well as # an additional timestep which it can use to bootstrap. - learner_state = learner_state._replace(timestep=timestep) + learner_state = learner_state._replace(timestep=combined_timesteps) # Update the networks with RecordTimeTo(learn_times["learning_time"]): learner_state, episode_metrics, train_metrics = learn_fn( - learner_state, traj_batch + learner_state, combined_traj_batch ) - + metrics.append((episode_metrics, train_metrics)) - rollout_times.append(rollout_time) + # Update all the params sources so all actors can get the latest params unreplicated_params = unreplicate(learner_state.params) From 7c784788ba6e7f59f27f8361a91c52de43bd03ed Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Thu, 17 Oct 2024 14:07:17 +0000 Subject: [PATCH 114/125] feat: jit evaluation on cpu --- mava/evaluator.py | 2 ++ mava/systems/ppo/sebulba/ff_ippo.py | 19 ++++++------------- mava/wrappers/gym.py | 6 ++---- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/mava/evaluator.py b/mava/evaluator.py index a306157ed..99d4eb8d4 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -239,6 +239,8 @@ def get_sebulba_eval_fn( episode_loops = math.ceil(eval_episodes / n_parallel_envs) env = env_maker(config, n_parallel_envs) + act_fn = jax.jit(act_fn, device=jax.devices('cpu')[0]) # cpu so that we don't block actors/learners + # Warnings if num eval episodes is not divisible by num parallel envs. 
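A self-contained sketch of the evaluator change above, with a toy policy standing in for Mava's `act_fn`: compiling action selection onto the CPU keeps evaluation from competing with the actor and learner devices.

```python
# Toy greedy policy; the jit(device=...) pattern is the same one used in get_sebulba_eval_fn.
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

def act_fn(params: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for the evaluator's policy: pick the greedy action from toy logits.
    return jnp.argmax(obs @ params, axis=-1)

act_fn = jax.jit(act_fn, device=cpu)                   # runs on host, devices stay free
actions = act_fn(jnp.ones((4, 2)), jnp.ones((8, 4)))   # shapes are illustrative
```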
if eval_episodes % n_parallel_envs != 0: warnings.warn( diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index b0a74f716..1f5aad316 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -50,7 +50,6 @@ from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics - def rollout( key: chex.PRNGKey, env, @@ -81,8 +80,6 @@ def rollout( num_agents, num_envs = config.system.num_agents, config.arch.num_envs move_to_device = lambda x: jax.device_put(x, device=actor_device) - key = move_to_device(key) - @jax.jit def act_fn( params: Params, @@ -579,6 +576,7 @@ def run_experiment(_config: DictConfig) -> float: params_sources_lifetime = ThreadLifetime() # Unfortunately we have to do this here, because creating envs inside the actor threads causes deadlocks + # todo: see what happens if we do this in the thread creating loop envs = [[] for i in range(len(actor_devices))] print( f"{Fore.BLUE}{Style.BRIGHT}Starting up environments, this may take a while...{Style.RESET_ALL}" @@ -633,7 +631,7 @@ def run_experiment(_config: DictConfig) -> float: # Acting and learning is happening in their own threads. # This loop waits for the learner to finish an update before evaluation and logging. for eval_step in range(config.arch.num_evaluation): - # Get the next set of params and metrics from the learner + # Sync with the learner - the get() is blocking so it keeps eval and learning in step. episode_metrics, train_metrics, learner_state, time_metrics = eval_queue.get() t = int(steps_per_rollout * (eval_step + 1)) @@ -653,7 +651,7 @@ def run_experiment(_config: DictConfig) -> float: unreplicated_actor_params = unreplicate(learner_state.params.actor_params) key, eval_key = jax.random.split(key, 2) - eval_metrics = evaluator(unreplicated_actor_params, eval_key, {}) + eval_metrics = evaluator(jax.device_get(unreplicated_actor_params), eval_key, {}) logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) episode_return = np.mean(eval_metrics["episode_return"]) @@ -685,23 +683,18 @@ def run_experiment(_config: DictConfig) -> float: logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) abs_metric_evaluator_envs.close() - # Stop the logger. + # Stop all the threads. logger.stop() - # Ask actors to stop before running the evaluator actor_lifetime.stop() - # We clear the pipeline before stopping the actor threads to avoid deadlock - pipe.clear() - print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared: {pipe.qsize()}{Style.RESET_ALL}") - + pipe.clear() # We clear the pipeline before stopping the actor threads to avoid deadlock + print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared{Style.RESET_ALL}") print(f"{Fore.RED}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") for actor in actor_threads: actor.join() print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}") - print(f"{Fore.RED}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") pipe_lifetime.stop() pipe.join() - print(f"{Fore.RED}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}") params_sources_lifetime.stop() for params_source in params_sources: diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index a27b246ce..048294893 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -95,10 +95,8 @@ def __init__( def reset( self, seed: Optional[int] = None, options: Optional[dict] = None ) -> Tuple[NDArray, Dict]: - # todo: maybe we should just remove this? 
I think the hasattr could be slow and the - # `OrderEnforcingWrapper` blocks the seed call :/ - if seed is not None and hasattr(self.env, "seed"): - self.env.seed(seed) + if seed is not None: + self.env.unwrapped.seed(seed) agents_view, info = self._env.reset() From c252ffeffa7169b378638cdd64604de29966e5e5 Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Thu, 17 Oct 2024 15:13:48 +0100 Subject: [PATCH 115/125] fix: timestep calculation with accumulation --- mava/systems/ppo/sebulba/ff_ippo.py | 2 +- mava/utils/config.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 95566efea..639ff1fe0 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -559,7 +559,7 @@ def run_experiment(_config: DictConfig) -> float: check_sebulba_config(config) steps_per_rollout = ( - config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval + config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval * config.arch.n_learner_accumulate ) # Logger setup diff --git a/mava/utils/config.py b/mava/utils/config.py index 23484311b..34a35f091 100644 --- a/mava/utils/config.py +++ b/mava/utils/config.py @@ -46,9 +46,11 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: if config.arch.architecture_name == "anakin": n_devices = len(jax.devices()) update_batch_size = config.system.update_batch_size + n_accumulate = 1 # We dont accumulate envs in anakin else: n_devices = 1 # We only use a single device's output when updating. update_batch_size = 1 + n_accumulate = config.arch.n_learner_accumulate if config.system.total_timesteps is None: config.system.num_updates = int(config.system.num_updates) @@ -58,6 +60,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: * config.system.rollout_length * update_batch_size * config.arch.num_envs + * n_accumulate ) else: config.system.total_timesteps = int(config.system.total_timesteps) @@ -67,6 +70,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: // update_batch_size // config.arch.num_envs // n_devices + // n_accumulate ) print( f"{Fore.RED}{Style.BRIGHT} Changing the number of updates " From fd7a0255d45b53691b486e39f1f59ace058a6bf7 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Thu, 17 Oct 2024 20:56:56 +0000 Subject: [PATCH 116/125] feat: shardmap almost working --- mava/systems/ppo/sebulba/ff_ippo.py | 25 +++++++++++++++++++---- mava/utils/sebulba.py | 31 ++++++++++++++--------------- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 639ff1fe0..e47a91c87 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -26,6 +26,10 @@ import jax import jax.debug import jax.numpy as jnp +from jax.sharding import Mesh, PartitionSpec as P +from jax.sharding import NamedSharding +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map import numpy as np import optax from colorama import Fore, Style @@ -409,8 +413,8 @@ def learner_thread( rollout_times.append(rollout_time) # Concatenate accumulated timesteps and trajectory batches on the num_envs axis - combined_traj_batch = jax.tree.map(lambda *x: jnp.concat(x, axis=2), *accumulated_traj_batches) - combined_timesteps = jax.tree.map(lambda *x: jnp.concat(x, axis=1), *accumulated_timesteps) + combined_traj_batch = jax.tree.map(lambda *x: 
jnp.concat(x, axis=0), *accumulated_traj_batches) + combined_timesteps = jax.tree.map(lambda *x: jnp.concat(x, axis=0), *accumulated_timesteps) # Replace the timestep in the learner state with the latest timestep @@ -454,6 +458,9 @@ def learner_setup( config.system.num_agents = len(action_space) config.system.num_actions = int(action_space[0].n) + devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) + mesh = Mesh(devices, axis_names=("learner_devices",)) + # PRNG keys. key, actor_key, critic_key = jax.random.split(key, 3) @@ -500,7 +507,13 @@ def learner_setup( update_fns = (actor_optim.update, critic_optim.update) learn = get_learner_step_fn(apply_fns, update_fns, config) - learn = jax.pmap(learn, axis_name="learner_devices", devices=learner_devices) + learn = jax.jit( + shard_map(learn, + mesh=mesh, + in_specs=P("learner_devices"), + out_specs=P("learner_devices")) + ) + # learn = jax.pmap(learn, axis_name="learner_devices", devices=learner_devices) # Load model from checkpoint if specified. if config.logger.checkpointing.load_model: @@ -581,8 +594,12 @@ def run_experiment(_config: DictConfig) -> float: inital_params = unreplicate(learner_state.params) # the rollout queue/ the pipe between actor and learner + # todo: return this from/pass into: learner setup + devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) + mesh = Mesh(devices, axis_names=("learner_devices",)) + sharding = NamedSharding(mesh, P("learner_devices")) pipe_lifetime = ThreadLifetime() - pipe = Pipeline(config.arch.rollout_queue_size, learner_devices, pipe_lifetime) + pipe = Pipeline(config.arch.rollout_queue_size, sharding, pipe_lifetime) pipe.start() params_sources: List[ParamsSource] = [] diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index cead3b6ba..e2c07cf79 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -21,6 +21,7 @@ import jax import jax.numpy as jnp +from jax.sharding import Sharding from colorama import Fore, Style from jax import tree from jumanji.types import TimeStep @@ -28,7 +29,7 @@ # todo: remove the ppo dependencies from mava.systems.ppo.types import Params, PPOTransition -QUEUE_PUT_TIMEOUT = 180 +QUEUE_PUT_TIMEOUT = 100 class ThreadLifetime: @@ -48,29 +49,29 @@ def stop(self) -> None: def _stack_trajectory(trajectory: List[PPOTransition]) -> PPOTransition: """Stack a list of parallel_env transitions into a single transition of shape [rollout_len, num_envs, ...].""" - return tree.map(lambda *x: jnp.stack(x, axis=0), *trajectory) # type: ignore + return tree.map(lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), *trajectory) # type: ignore # Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py class Pipeline(threading.Thread): """ - The `Pipeline` shards trajectories into `learner_devices`, + The `Pipeline` shards trajectories into learner devices, ensuring trajectories are consumed in the right order to avoid being off-policy and limit the max number of samples in device memory at one time to avoid OOM issues. """ - def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: ThreadLifetime): + def __init__(self, max_size: int, learner_sharding: Sharding, lifetime: ThreadLifetime): """ Initializes the pipeline with a maximum size and the devices to shard trajectories across. Args: max_size: The maximum number of trajectories to keep in the pipeline. - learner_devices: The devices to shard trajectories across. 
+ learner_sharding: The sharding used for the learner's update function. lifetime: A `ThreadLifetime` which is used to stop this thread. """ super().__init__(name="Pipeline") - self.learner_devices = learner_devices + self.sharding = learner_sharding self.tickets_queue: queue.Queue = queue.Queue() self._queue: queue.Queue = queue.Queue(maxsize=max_size) self.lifetime = lifetime @@ -97,22 +98,17 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict self.tickets_queue.put((start_condition, end_condition)) start_condition.wait() # wait to be allowed to start - # [Transition(num_envs)] * rollout_len --> Transition[done=(rollout_len, num_envs, ...)] + # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] traj = _stack_trajectory(traj) - # Split trajectory on the num envs axis so each learner device gets a valid full rollout - sharded_traj = jax.tree.map(lambda x: self.shard_split_playload(x, axis=1), traj) + sharded_traj, sharded_timestep = jax.device_put((traj, timestep), device=self.sharding, donate=True) - # Timestep[(num_envs, num_agents, ...), ...] --> - # [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices - sharded_timestep = jax.tree.map(self.shard_split_playload, timestep) - - # We block on the put to ensure that actors wait for the learners to catch up. This does two - # things: + # We block on the put to ensure that actors wait for the learners to catch up. + # This does two things: # 1. It ensures that the actors don't get too far ahead of the learners, which could lead to # off-policy data. # 2. It ensures that the actors don't in a sense "waste" samples and their time by # generating samples that the learners can't consume. - # However, we put a timeout of 180 seconds to avoid deadlocks in case the learner + # However, we put a timeout of 100 seconds to avoid deadlocks in case the learner # is not consuming the data. This is a safety measure and should not be hit in normal # operation. We use a try-finally since the lock has to be released even if an exception # is raised. @@ -149,6 +145,9 @@ def clear(self) -> None: except queue.Empty: break + def shard(self, payload: Any): + ... 
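An illustrative, self-contained sketch of the placement primitive `Pipeline.put` now relies on: `jax.device_put` with a `NamedSharding` built from `PartitionSpec("learner_devices")` splits the leading `num_envs` axis across the learner devices, while an empty `PartitionSpec` keeps a full copy on every device. It assumes `num_envs` divides evenly by the number of devices.

```python
# Sketch of sharded vs replicated placement (toy shapes, not Mava's real pytrees).
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("learner_devices",))

replicated = NamedSharding(mesh, P())                    # full copy per device (params)
split_envs = NamedSharding(mesh, P("learner_devices"))   # shard axis 0 (rollout data)

params = jax.device_put(jnp.ones((128,)), replicated)
traj = jax.device_put(jnp.zeros((8, 16, 4)), split_envs)  # (num_envs, rollout_len, ...)
```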
+ def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: return self.shard_payload(self.split_payload(payload, axis)) From 4013a22fc41b46b7e8e417e62f7cdb4a0e1b68c6 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Fri, 18 Oct 2024 14:17:50 +0000 Subject: [PATCH 117/125] feat: shard_map working --- mava/systems/ppo/sebulba/ff_ippo.py | 44 +++++++++++++++++------------ mava/utils/sebulba.py | 28 ++++-------------- 2 files changed, 32 insertions(+), 40 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index e47a91c87..c6e34a7db 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -109,8 +109,7 @@ def act_fn( with RecordTimeTo(actor_timings["rollout_time"]): for _ in range(config.system.rollout_length): with RecordTimeTo(actor_timings["get_params_time"]): - # Get the latest parameters from the learner - params = params_source.get() + params = params_source.get() # Get the latest parameters from the learner obs_tpu = tree.map(move_to_device, timestep.observation) @@ -320,6 +319,7 @@ def _critic_loss_fn( "actor_loss": actor_loss, "entropy": entropy, } + # todo: don't return ent key, only pass in return (new_params, new_opt_state, entropy_key), loss_info params, opt_states, traj_batch, advantages, targets, key = update_state @@ -353,6 +353,7 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) + # todo: shardmap decorator here? def learner_fn( learner_state: LearnerState, traj_batch: PPOTransition ) -> ExperimentOutput[LearnerState]: @@ -370,6 +371,9 @@ def learner_fn( - env_state (LogEnvState): The environment state. - timesteps (TimeStep): The last timestep of the rollout. """ + # This function is shard mapped on the batch axis, but `_update_step` needs + # the first axis to be time + traj_batch = tree.map(lambda x: x.swapaxes(0, 1), traj_batch) learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) return ExperimentOutput( @@ -431,9 +435,8 @@ def learner_thread( # Update all the params sources so all actors can get the latest params - unreplicated_params = unreplicate(learner_state.params) for source in params_sources: - source.update(unreplicated_params) + source.update(learner_state.params) # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation episode_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) @@ -460,6 +463,10 @@ def learner_setup( devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) mesh = Mesh(devices, axis_names=("learner_devices",)) + model_spec = P() + data_spec = P("learner_devices",) + model_sharding = NamedSharding(mesh, model_spec) # todo: return these + data_sharding = NamedSharding(mesh, data_spec) # PRNG keys. 
key, actor_key, critic_key = jax.random.split(key, 3) @@ -506,12 +513,15 @@ def learner_setup( apply_fns = (actor_network.apply, critic_network.apply) update_fns = (actor_optim.update, critic_optim.update) + learn_state_spec = LearnerState(model_spec, model_spec, model_spec, None, data_spec) learn = get_learner_step_fn(apply_fns, update_fns, config) learn = jax.jit( - shard_map(learn, - mesh=mesh, - in_specs=P("learner_devices"), - out_specs=P("learner_devices")) + shard_map( + learn, + mesh=mesh, + in_specs=(learn_state_spec, data_spec), + out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec), + ) ) # learn = jax.pmap(learn, axis_name="learner_devices", devices=learner_devices) @@ -529,13 +539,11 @@ def learner_setup( # Define params to be replicated across devices and batches. key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, step_keys) # Duplicate learner across Learner devices. - replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) + params, opt_states, step_keys = jax.device_put((params, opt_states, step_keys), model_sharding) # Initialise learner state. - params, opt_states, step_keys = replicate_learner init_learner_state = LearnerState(params, opt_states, step_keys, None, None) env.close() @@ -591,7 +599,7 @@ def run_experiment(_config: DictConfig) -> float: ) # Executor setup and launch. - inital_params = unreplicate(learner_state.params) + inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate # the rollout queue/ the pipe between actor and learner # todo: return this from/pass into: learner setup @@ -657,7 +665,7 @@ def run_experiment(_config: DictConfig) -> float: ).start() max_episode_return = -np.inf - best_params = inital_params.actor_params + best_params_cpu = jax.device_get(inital_params.actor_params) # This is the main loop, all it does is evaluation and logging. # Acting and learning is happening in their own threads. 
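A hedged, self-contained sketch of the `shard_map` pattern applied to `learn` above: parameters enter replicated (empty `PartitionSpec`) while the batch axis is split over the `"learner_devices"` mesh axis, and a `pmean` inside keeps the replicas identical after the update. The toy SGD step below stands in for the real IPPO update.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("learner_devices",))

def sgd_step(params: jnp.ndarray, batch: jnp.ndarray) -> jnp.ndarray:
    grads = jax.grad(lambda p, x: jnp.mean((x * p) ** 2))(params, batch)
    grads = jax.lax.pmean(grads, axis_name="learner_devices")  # average grads across shards
    return params - 0.1 * grads

sharded_step = jax.jit(
    shard_map(sgd_step, mesh=mesh, in_specs=(P(), P("learner_devices")), out_specs=P())
)

params = jnp.ones(())    # replicated "model"
batch = jnp.arange(8.0)  # leading axis sharded across learner devices
params = sharded_step(params, batch)
```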
@@ -681,9 +689,9 @@ def run_experiment(_config: DictConfig) -> float: ) / time_metrics["learner_time_per_eval"] logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) - unreplicated_actor_params = unreplicate(learner_state.params.actor_params) + learner_state_cpu = jax.device_get(learner_state) key, eval_key = jax.random.split(key, 2) - eval_metrics = evaluator(jax.device_get(unreplicated_actor_params), eval_key, {}) + eval_metrics = evaluator(learner_state_cpu.params.actor_params, eval_key, {}) logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) episode_return = np.mean(eval_metrics["episode_return"]) @@ -691,12 +699,12 @@ def run_experiment(_config: DictConfig) -> float: if save_checkpoint: # Save a checkpoint of the learner state checkpointer.save( timestep=steps_per_rollout * (eval_step + 1), - unreplicated_learner_state=unreplicate_n_dims(learner_state), + unreplicated_learner_state=learner_state_cpu, episode_return=episode_return, ) if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(unreplicated_actor_params) + best_params_cpu = copy.deepcopy(learner_state_cpu.params.actor_params) max_episode_return = episode_return evaluator_envs.close() @@ -709,7 +717,7 @@ def run_experiment(_config: DictConfig) -> float: environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True ) key, eval_key = jax.random.split(key, 2) - eval_metrics = abs_metric_evaluator(best_params, eval_key, {}) + eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {}) t = int(steps_per_rollout * (eval_step + 1)) logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index e2c07cf79..4b1b9f758 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -102,16 +102,13 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict traj = _stack_trajectory(traj) sharded_traj, sharded_timestep = jax.device_put((traj, timestep), device=self.sharding, donate=True) - # We block on the put to ensure that actors wait for the learners to catch up. - # This does two things: - # 1. It ensures that the actors don't get too far ahead of the learners, which could lead to - # off-policy data. - # 2. It ensures that the actors don't in a sense "waste" samples and their time by - # generating samples that the learners can't consume. + # We block on the `put` to ensure that actors wait for the learners to catch up. + # This ensures two things: + # The actors don't get too far ahead of the learners, which could lead to off-policy data. + # The actors don't "waste" samples by generating samples that the learners can't consume. # However, we put a timeout of 100 seconds to avoid deadlocks in case the learner - # is not consuming the data. This is a safety measure and should not be hit in normal - # operation. We use a try-finally since the lock has to be released even if an exception - # is raised. + # is not consuming the data. This is a safety measure and should not occur in normal + # operation. We use a try-finally so the lock is released even if an exception is raised. try: self._queue.put( (sharded_traj, sharded_timestep, time_dict), @@ -145,19 +142,6 @@ def clear(self) -> None: except queue.Empty: break - def shard(self, payload: Any): - ... 
- - def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: - return self.shard_payload(self.split_payload(payload, axis)) - - @partial(jax.jit, static_argnums=(0, 2)) - def split_payload(self, payload: Any, axis: int = 0): - return jnp.split(payload, len(self.learner_devices), axis=axis) - - def shard_payload(self, payload: Any): - return jax.device_put_sharded(payload, devices=self.learner_devices) - class ParamsSource(threading.Thread): """A `ParamSource` is a component that allows networks params to be passed from a From 0e559d99e7deb4c3e1b56745f3cabc447516d103 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Sat, 19 Oct 2024 13:56:55 +0200 Subject: [PATCH 118/125] fix: key use in actor loss --- mava/systems/ppo/sebulba/ff_ippo.py | 32 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index c6e34a7db..a139fb77c 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -21,21 +21,19 @@ from typing import Any, Dict, List, Sequence, Tuple import chex -import flax import hydra import jax import jax.debug import jax.numpy as jnp -from jax.sharding import Mesh, PartitionSpec as P -from jax.sharding import NamedSharding -from jax.experimental import mesh_utils -from jax.experimental.shard_map import shard_map import numpy as np import optax from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict -from flax.jax_utils import unreplicate from jax import tree +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -44,19 +42,20 @@ from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, Observation, SebulbaLearnerFn +from mava.types import ActorApply, CriticApply, ExperimentOutput, MarlEnv, Observation, SebulbaLearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_sebulba_config, check_total_timesteps -from mava.utils.jax_utils import merge_leading_dims, unreplicate_n_dims +from mava.utils.jax_utils import merge_leading_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics + def rollout( key: chex.PRNGKey, - env, + env: MarlEnv, config: DictConfig, rollout_queue: Pipeline, params_source: ParamsSource, @@ -319,8 +318,7 @@ def _critic_loss_fn( "actor_loss": actor_loss, "entropy": entropy, } - # todo: don't return ent key, only pass in - return (new_params, new_opt_state, entropy_key), loss_info + return (new_params, new_opt_state, key), loss_info params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) @@ -335,7 +333,7 @@ def _critic_loss_fn( shuffled_batch, ) # Update minibatches - (params, opt_states, entropy_key), loss_info = jax.lax.scan( + (params, opt_states, _), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), 
minibatches ) @@ -430,9 +428,9 @@ def learner_thread( learner_state, episode_metrics, train_metrics = learn_fn( learner_state, combined_traj_batch ) - + metrics.append((episode_metrics, train_metrics)) - + # Update all the params sources so all actors can get the latest params for source in params_sources: @@ -517,9 +515,9 @@ def learner_setup( learn = get_learner_step_fn(apply_fns, update_fns, config) learn = jax.jit( shard_map( - learn, - mesh=mesh, - in_specs=(learn_state_spec, data_spec), + learn, + mesh=mesh, + in_specs=(learn_state_spec, data_spec), out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec), ) ) From 0a6bd49beb37d9e79896faef6b5abbaba2612c0e Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Sat, 19 Oct 2024 13:58:01 +0200 Subject: [PATCH 119/125] fix: align gym config with other configs --- mava/configs/env/lbf_gym.yaml | 9 +++++---- mava/configs/env/rware_gym.yaml | 9 +++++---- mava/utils/make_env.py | 3 ++- mava/utils/sebulba.py | 5 ++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index 39d624daa..a7fa1be89 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -11,10 +11,11 @@ scenario: # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. eval_metric: episode_return -# Whether the add agents IDs to the observations returned by the environment. -add_agent_id: False - -# Whether or not to log the winrate of this environment. +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml index da8c73402..d3d6a49b2 100644 --- a/mava/configs/env/rware_gym.yaml +++ b/mava/configs/env/rware_gym.yaml @@ -11,10 +11,11 @@ scenario: # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. eval_metric: episode_return -# Whether the add agents IDs to the observations returned by the environment. -add_agent_id : False - -# Whether or not to log the winrate of this environment. +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 1c9e4dbd0..8b9c85afd 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -243,12 +243,13 @@ def make_gym_env( Async environments. 
""" wrapper = _gym_registry[config.env.env_name] + config.system.add_agent_id = config.system.add_agent_id & (~config.env.implicit_agent_id) def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnasium.Env: registered_name = f"{config.env.scenario.name}:{config.env.scenario.task_name}" env = gym.make(registered_name, disable_env_checker=False) wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) - if config.env.add_agent_id: + if config.system.add_agent_id: wrapped_env = GymAgentIDWrapper(wrapped_env) wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index 4b1b9f758..4083155d5 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -16,14 +16,13 @@ import queue import threading import time -from functools import partial from typing import Any, Dict, List, Sequence, Tuple, Union import jax import jax.numpy as jnp -from jax.sharding import Sharding from colorama import Fore, Style from jax import tree +from jax.sharding import Sharding from jumanji.types import TimeStep # todo: remove the ppo dependencies @@ -102,7 +101,7 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict traj = _stack_trajectory(traj) sharded_traj, sharded_timestep = jax.device_put((traj, timestep), device=self.sharding, donate=True) - # We block on the `put` to ensure that actors wait for the learners to catch up. + # We block on the `put` to ensure that actors wait for the learners to catch up. # This ensures two things: # The actors don't get too far ahead of the learners, which could lead to off-policy data. # The actors don't "waste" samples by generating samples that the learners can't consume. From 641a548905455874959e9e84a100449d7f24a064 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Sat, 19 Oct 2024 14:54:08 +0200 Subject: [PATCH 120/125] feat: better env creation and safer sharding --- mava/systems/ppo/sebulba/ff_ippo.py | 93 ++++++++++++++--------------- mava/utils/jax_utils.py | 3 +- mava/utils/sebulba.py | 12 ++-- mava/wrappers/jaxmarl.py | 1 - 4 files changed, 52 insertions(+), 57 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index a139fb77c..2312fb023 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -26,13 +26,14 @@ import jax.debug import jax.numpy as jnp import numpy as np +from numpy.typing import NDArray import optax from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict from jax import tree from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding +from jax.sharding import Mesh, NamedSharding, Sharding from jax.sharding import PartitionSpec as P from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -42,11 +43,18 @@ from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, MarlEnv, Observation, SebulbaLearnerFn +from mava.types import ( + ActorApply, + CriticApply, + ExperimentOutput, + MarlEnv, + Observation, + SebulbaLearnerFn, +) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_sebulba_config, check_total_timesteps -from 
mava.utils.jax_utils import merge_leading_dims +from mava.utils.jax_utils import merge_leading_dims, switch_leading_axes from mava.utils.logger import LogEvent, MavaLogger from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime from mava.utils.training import make_learning_rate @@ -351,7 +359,6 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - # todo: shardmap decorator here? def learner_fn( learner_state: LearnerState, traj_batch: PPOTransition ) -> ExperimentOutput[LearnerState]: @@ -371,7 +378,7 @@ def learner_fn( """ # This function is shard mapped on the batch axis, but `_update_step` needs # the first axis to be time - traj_batch = tree.map(lambda x: x.swapaxes(0, 1), traj_batch) + traj_batch = tree.map(switch_leading_axes, traj_batch) learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) return ExperimentOutput( @@ -403,6 +410,7 @@ def learner_thread( accumulated_traj_batches = [] accumulated_timesteps = [] + # Possibly get many rollouts for 1 learn step - allows learning with large batches for _ in range(config.arch.n_learner_accumulate): # Get the trajectory batch from the pipeline # This is blocking so it will wait until the pipeline has data. @@ -414,43 +422,42 @@ def learner_thread( accumulated_timesteps.append(timestep) rollout_times.append(rollout_time) - # Concatenate accumulated timesteps and trajectory batches on the num_envs axis - combined_traj_batch = jax.tree.map(lambda *x: jnp.concat(x, axis=0), *accumulated_traj_batches) - combined_timesteps = jax.tree.map(lambda *x: jnp.concat(x, axis=0), *accumulated_timesteps) - + # Concatenate the accumulated timesteps and trajectory batches on the num_envs axis + traj_batches = tree.map(lambda *x: jnp.concat(x, axis=0), *accumulated_traj_batches) + timesteps = tree.map(lambda *x: jnp.concat(x, axis=0), *accumulated_timesteps) # Replace the timestep in the learner state with the latest timestep # This means the learner has access to the entire trajectory as well as # an additional timestep which it can use to bootstrap. 
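For reference, a toy sketch (not Mava's exact code) of how the extra bootstrap timestep mentioned above is consumed: the critic's value of the final observation seeds a reverse `jax.lax.scan` that accumulates GAE over the trajectory, mirroring `_calculate_gae` earlier in this file. Shapes are illustrative and no networks are involved.

```python
import jax
import jax.numpy as jnp

def calculate_gae(rewards, values, dones, last_val, gamma=0.99, gae_lambda=0.95):
    # rewards, values, dones: [rollout_len, num_envs]; last_val: [num_envs]
    def _step(carry, xs):
        gae, next_value = carry
        done, value, reward = xs
        delta = reward + gamma * next_value * (1.0 - done) - value
        gae = delta + gamma * gae_lambda * (1.0 - done) * gae
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        _step, (jnp.zeros_like(last_val), last_val), (dones, values, rewards), reverse=True
    )
    return advantages, advantages + values  # (advantages, value targets)

T, N = 8, 4  # toy rollout_length and num_envs
adv, targets = calculate_gae(jnp.ones((T, N)), jnp.zeros((T, N)), jnp.zeros((T, N)), jnp.zeros(N))
```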
- learner_state = learner_state._replace(timestep=combined_timesteps) + learner_state = learner_state._replace(timestep=timesteps) # Update the networks with RecordTimeTo(learn_times["learning_time"]): - learner_state, episode_metrics, train_metrics = learn_fn( - learner_state, combined_traj_batch - ) - - metrics.append((episode_metrics, train_metrics)) + learner_state, ep_metrics, train_metrics = learn_fn(learner_state, traj_batches) + metrics.append((ep_metrics, train_metrics)) # Update all the params sources so all actors can get the latest params for source in params_sources: source.update(learner_state.params) # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation - episode_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) - rollout_times = tree.map(lambda *x: np.mean(x), *rollout_times) + ep_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) + rollout_times: Dict[str, NDArray] = tree.map(lambda *x: np.mean(x), *rollout_times) timing_dict = rollout_times | learn_times timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) - eval_queue.put((episode_metrics, train_metrics, learner_state, timing_dict)) + eval_queue.put((ep_metrics, train_metrics, learner_state, timing_dict)) def learner_setup( key: chex.PRNGKey, config: DictConfig, learner_devices: List ) -> Tuple[ - SebulbaLearnerFn[LearnerState, PPOTransition], Tuple[ActorApply, CriticApply], LearnerState + SebulbaLearnerFn[LearnerState, PPOTransition], + Tuple[ActorApply, CriticApply], + LearnerState, + Sharding, ]: - """Initialise learner_fn, network, optimiser, environment and states.""" + """Initialise learner_fn, network and learner state.""" # create temporory envoirnments. env = environments.make_gym_env(config, config.arch.num_envs) @@ -462,9 +469,8 @@ def learner_setup( devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) mesh = Mesh(devices, axis_names=("learner_devices",)) model_spec = P() - data_spec = P("learner_devices",) - model_sharding = NamedSharding(mesh, model_spec) # todo: return these - data_sharding = NamedSharding(mesh, data_spec) + data_spec = P("learner_devices") + learner_sharding = NamedSharding(mesh, model_spec) # PRNG keys. key, actor_key, critic_key = jax.random.split(key, 3) @@ -511,6 +517,7 @@ def learner_setup( apply_fns = (actor_network.apply, critic_network.apply) update_fns = (actor_optim.update, critic_optim.update) + # defines how the learner state is sharded: params, opt and key = replicated, timestep = sharded learn_state_spec = LearnerState(model_spec, model_spec, model_spec, None, data_spec) learn = get_learner_step_fn(apply_fns, update_fns, config) learn = jax.jit( @@ -521,7 +528,6 @@ def learner_setup( out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec), ) ) - # learn = jax.pmap(learn, axis_name="learner_devices", devices=learner_devices) # Load model from checkpoint if specified. if config.logger.checkpointing.load_model: @@ -539,13 +545,15 @@ def learner_setup( opt_states = OptStates(actor_opt_state, critic_opt_state) # Duplicate learner across Learner devices. - params, opt_states, step_keys = jax.device_put((params, opt_states, step_keys), model_sharding) + params, opt_states, step_keys = jax.device_put( + (params, opt_states, step_keys), learner_sharding + ) # Initialise learner state. 
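A small sketch of the `tree.map(lambda *x: np.asarray(x), *metrics)` aggregation used in `learner_thread` above, with made-up metric names: a list of per-update metric dicts is transposed into one dict of stacked arrays before being handed to the evaluator queue.

```python
import numpy as np
from jax import tree

metrics = [
    {"episode_return": 1.0, "episode_length": 10},
    {"episode_return": 2.0, "episode_length": 12},
    {"episode_return": 3.0, "episode_length": 9},
]
stacked = tree.map(lambda *x: np.asarray(x), *metrics)
# stacked == {"episode_return": array([1., 2., 3.]), "episode_length": array([10, 12, 9])}
```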
- init_learner_state = LearnerState(params, opt_states, step_keys, None, None) + init_learner_state = LearnerState(params, opt_states, step_keys, None, None) # type: ignore env.close() - return learn, apply_fns, init_learner_state + return learn, apply_fns, init_learner_state, learner_sharding # type: ignore def run_experiment(_config: DictConfig) -> float: @@ -564,7 +572,7 @@ def run_experiment(_config: DictConfig) -> float: np_rng = np.random.default_rng(config.system.seed) # Setup learner. - learn, apply_fns, learner_state = learner_setup(key, config, learner_devices) + learn, apply_fns, learner_state, learner_sharding = learner_setup(key, config, learner_devices) # Setup evaluator. # One key per device for evaluation. @@ -578,7 +586,10 @@ def run_experiment(_config: DictConfig) -> float: check_sebulba_config(config) steps_per_rollout = ( - config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval * config.arch.n_learner_accumulate + config.system.rollout_length + * config.arch.num_envs + * config.system.num_updates_per_eval + * config.arch.n_learner_accumulate ) # Logger setup @@ -600,12 +611,8 @@ def run_experiment(_config: DictConfig) -> float: inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate # the rollout queue/ the pipe between actor and learner - # todo: return this from/pass into: learner setup - devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) - mesh = Mesh(devices, axis_names=("learner_devices",)) - sharding = NamedSharding(mesh, P("learner_devices")) pipe_lifetime = ThreadLifetime() - pipe = Pipeline(config.arch.rollout_queue_size, sharding, pipe_lifetime) + pipe = Pipeline(config.arch.rollout_queue_size, learner_sharding, pipe_lifetime) pipe.start() params_sources: List[ParamsSource] = [] @@ -613,20 +620,9 @@ def run_experiment(_config: DictConfig) -> float: actor_lifetime = ThreadLifetime() params_sources_lifetime = ThreadLifetime() - # Unfortunately we have to do this here, because creating envs inside the actor threads causes deadlocks - # todo: see what happens if we do this in the thread creating loop - envs = [[] for i in range(len(actor_devices))] - print( - f"{Fore.BLUE}{Style.BRIGHT}Starting up environments, this may take a while...{Style.RESET_ALL}" - ) - for i in range(len(actor_devices)): - for _ in range(config.arch.n_threads_per_executor): - env = environments.make_gym_env(config, config.arch.num_envs) - envs[i].append(env) - print(f"{Fore.BLUE}{Style.BRIGHT}All environments created{Style.RESET_ALL}") - # Create the actor threads - for dev_idx, actor_device in enumerate(actor_devices): + print(f"{Fore.BLUE}{Style.BRIGHT}Starting up actor threads...{Style.RESET_ALL}") + for actor_device in actor_devices: # Create 1 params source per device params_source = ParamsSource(inital_params, actor_device, params_sources_lifetime) params_source.start() @@ -641,7 +637,8 @@ def run_experiment(_config: DictConfig) -> float: target=rollout, args=( act_key, - envs[dev_idx][thread_id], + # We have to do this here, creating envs inside actor threads causes deadlocks + environments.make_gym_env(config, config.arch.num_envs), config, pipe, params_source, diff --git a/mava/utils/jax_utils.py b/mava/utils/jax_utils.py index 3c03455f2..c89c6a4a4 100644 --- a/mava/utils/jax_utils.py +++ b/mava/utils/jax_utils.py @@ -71,5 +71,4 @@ def unreplicate_batch_dim(x: Any) -> Any: def switch_leading_axes(arr: chex.Array) -> chex.Array: """Switches the first two axes, generally used for BT 
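As a sanity check on the steps_per_rollout expression above, a worked example with assumed config values (not taken from any real config file):

```python
# Environment steps gathered between two evaluations under the assumed config.
rollout_length = 128
num_envs = 16
num_updates_per_eval = 10
n_learner_accumulate = 2

steps_per_rollout = (
    rollout_length * num_envs * num_updates_per_eval * n_learner_accumulate
)
assert steps_per_rollout == 40_960
```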
-> TB.""" - arr = tree.map(lambda x: jax.numpy.swapaxes(x, 0, 1), arr) - return arr + return tree.map(lambda x: x.swapaxes(0, 1), arr) diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index 4083155d5..8fffe4d48 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -25,7 +25,7 @@ from jax.sharding import Sharding from jumanji.types import TimeStep -# todo: remove the ppo dependencies +# todo: remove the ppo dependencies when we make sebulba for other systems from mava.systems.ppo.types import Params, PPOTransition QUEUE_PUT_TIMEOUT = 100 @@ -99,22 +99,22 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] traj = _stack_trajectory(traj) - sharded_traj, sharded_timestep = jax.device_put((traj, timestep), device=self.sharding, donate=True) + traj, timestep = jax.device_put((traj, timestep), device=self.sharding, donate=True) # We block on the `put` to ensure that actors wait for the learners to catch up. # This ensures two things: # The actors don't get too far ahead of the learners, which could lead to off-policy data. # The actors don't "waste" samples by generating samples that the learners can't consume. # However, we put a timeout of 100 seconds to avoid deadlocks in case the learner - # is not consuming the data. This is a safety measure and should not occur in normal - # operation. We use a try-finally so the lock is released even if an exception is raised. + # is not consuming the data. This is a safety measure and should not normally occur. + # We use a try-finally so the lock is released even if an exception is raised. try: self._queue.put( - (sharded_traj, sharded_timestep, time_dict), + (traj, timestep, time_dict), block=True, timeout=QUEUE_PUT_TIMEOUT, ) - except queue.Full: # todo: check if this is needed because we catch this exception outside + except queue.Full: print( f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, " f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" diff --git a/mava/wrappers/jaxmarl.py b/mava/wrappers/jaxmarl.py index 72608f85f..f6ad51558 100644 --- a/mava/wrappers/jaxmarl.py +++ b/mava/wrappers/jaxmarl.py @@ -214,7 +214,6 @@ def reset( def step( self, state: JaxMarlState, action: Array ) -> Tuple[JaxMarlState, TimeStep[Union[Observation, ObservationGlobalState]]]: - # todo: how do you know if it's a truncation with only dones? 
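The pipeline's put above combines a lock, a bounded queue, and a timeout so producers block without deadlocking. A stripped-down sketch of that pattern, independent of Mava's Pipeline class; the class name is illustrative and the 100-second timeout simply mirrors the constant in the diff:

```python
import queue
import threading
from typing import Any

QUEUE_PUT_TIMEOUT = 100  # seconds

class BoundedPipe:
    """Producer blocks on a full queue, but times out instead of deadlocking."""

    def __init__(self, max_size: int):
        self._queue: queue.Queue = queue.Queue(maxsize=max_size)
        self._lock = threading.Lock()

    def put(self, item: Any) -> None:
        # `with` gives the same guarantee as a try/finally around the lock:
        # it is released even if the put raises.
        with self._lock:
            try:
                self._queue.put(item, block=True, timeout=QUEUE_PUT_TIMEOUT)
            except queue.Full:
                print("Pipe is full and the producer timed out; a deadlock may be occurring.")

    def get(self, timeout: float = 1.0) -> Any:
        return self._queue.get(block=True, timeout=timeout)
```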
key, step_key = jax.random.split(state.key) obs, env_state, reward, done, _ = self._env.step( step_key, state.state, unbatchify(action, self.agents) From c0c88bc2b782d05a7b1b2d2fbdfe552fec9d14f9 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Sat, 19 Oct 2024 15:09:28 +0200 Subject: [PATCH 121/125] chore: minor env typing fixes --- mava/systems/ppo/sebulba/ff_ippo.py | 7 ++++--- mava/wrappers/gym.py | 20 +++++++++----------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 2312fb023..35a5d86ec 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -47,7 +47,6 @@ ActorApply, CriticApply, ExperimentOutput, - MarlEnv, Observation, SebulbaLearnerFn, ) @@ -59,11 +58,12 @@ from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics +from mava.wrappers.gym import GymToJumanji def rollout( key: chex.PRNGKey, - env: MarlEnv, + env: GymToJumanji, config: DictConfig, rollout_queue: Pipeline, params_source: ParamsSource, @@ -101,7 +101,8 @@ def act_fn( actor_policy = actor_apply_fn(params.actor_params, observation) action = actor_policy.sample(seed=key) log_prob = actor_policy.log_prob(action) - + # It may be faster to calculate the values in the learner as + # then we won't need to pass critic params to actors. value = critic_apply_fn(params.critic_params, observation).squeeze() return action, log_prob, value diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 048294893..fa42e5e82 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -29,7 +29,7 @@ from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray -from mava.types import Observation, ObservationGlobalState +from mava.types import MarlEnv, Observation, ObservationGlobalState if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 from dataclasses import dataclass @@ -217,19 +217,17 @@ def modify_space(self, space: spaces.Space) -> spaces.Space: class GymToJumanji: - """Converts from the Gym API to the dm_env API, using Jumanji's Timestep type.""" + """Converts from the Gym API to the dm_env API.""" - def __init__(self, env: gymnasium.vector.async_vector_env): + def __init__(self, env: gymnasium.vector.VectorEnv): self.env = env self.single_action_space = env.unwrapped.single_action_space self.single_observation_space = env.unwrapped.single_observation_space - def reset( - self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None - ) -> TimeStep: - obs, info = self.env.reset(seed=seed, options=options) + def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None) -> TimeStep: + obs, info = self.env.reset(seed=seed, options=options) # type: ignore - num_agents = len(self.env.single_action_space) + num_agents = len(self.env.single_action_space) # type: ignore num_envs = self.env.num_envs ep_done = np.zeros(num_envs, dtype=float) @@ -269,16 +267,16 @@ def _format_observation( def _create_timestep( self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict ) -> TimeStep: - obs = self._format_observation(obs, info) + observation = self._format_observation(obs, info) # Filter out the masks and auxiliary data extras = {key: value for key, value in info["metrics"].items() if key[0] != "_"} step_type = np.where(ep_done, StepType.LAST, StepType.MID) 
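The timestep construction above hinges on two vectorised bookkeeping rules: an episode ends on either termination or truncation, but only termination zeroes the discount. A small numpy illustration, with integer stand-ins for Jumanji's StepType members:

```python
import numpy as np

MID, LAST = 1, 2  # stand-ins for StepType.MID / StepType.LAST

terminated = np.array([False, True, False])
truncated = np.array([False, False, True])
ep_done = np.logical_or(terminated, truncated)

step_type = np.where(ep_done, LAST, MID)   # [MID, LAST, LAST]
discount = 1.0 - terminated.astype(float)  # truncation still bootstraps: [1., 0., 1.]
```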
return TimeStep( - step_type=step_type, + step_type=step_type, # type: ignore reward=rewards, discount=1.0 - terminated, - observation=obs, + observation=observation, extras=extras, ) From 6b2d01c2fc854be4342c4049d88e7b79397894cd Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Mon, 21 Oct 2024 11:11:09 +0100 Subject: [PATCH 122/125] fix: start actors simultaneously to avoid deadlocks --- mava/systems/ppo/sebulba/ff_ippo.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 35a5d86ec..971088a97 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -650,8 +650,11 @@ def run_experiment(_config: DictConfig) -> float: ), name=f"Actor-{actor_device}-{thread_id}", ) - actor.start() actor_threads.append(actor) + + # Start the actors simultaneously + for actor in actor_threads: + actor.start() eval_queue: Queue = Queue() threading.Thread( From a13ab65cd4cb4aaa5d54643c1b01c989800023b9 Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Wed, 23 Oct 2024 14:05:53 +0100 Subject: [PATCH 123/125] feat: support for smac --- mava/configs/default/ff_ippo_sebulba.yaml | 2 +- mava/configs/env/lbf_gym.yaml | 3 +++ mava/configs/env/rware_gym.yaml | 3 +++ mava/configs/env/smac_gym.yaml | 25 +++++++++++++++++++++++ mava/utils/make_env.py | 4 +++- mava/utils/sebulba.py | 2 +- mava/wrappers/__init__.py | 1 + mava/wrappers/gym.py | 15 +++++++++++++- requirements/requirements.txt | 1 + 9 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 mava/configs/env/smac_gym.yaml diff --git a/mava/configs/default/ff_ippo_sebulba.yaml b/mava/configs/default/ff_ippo_sebulba.yaml index 7669049b1..cc2b4acae 100644 --- a/mava/configs/default/ff_ippo_sebulba.yaml +++ b/mava/configs/default/ff_ippo_sebulba.yaml @@ -3,7 +3,7 @@ defaults: - arch: sebulba - system: ppo/ff_ippo - network: mlp # [mlp, continuous_mlp, cnn] - - env: rware_gym # [rware_gym, lbf_gym] + - env: smac_gym # [rware_gym, lbf_gym, smac_gym] - _self_ hydra: diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index a7fa1be89..7ae03d010 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -20,3 +20,6 @@ log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. use_shared_rewards: True + +kwargs: + max_episode_steps: 100 \ No newline at end of file diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml index d3d6a49b2..0fcd41a2b 100644 --- a/mava/configs/env/rware_gym.yaml +++ b/mava/configs/env/rware_gym.yaml @@ -20,3 +20,6 @@ log_win_rate: False # Weather or not to sum the returned rewards over all of the agents. use_shared_rewards: True + +kwargs: + max_episode_steps: 500 \ No newline at end of file diff --git a/mava/configs/env/smac_gym.yaml b/mava/configs/env/smac_gym.yaml new file mode 100644 index 000000000..a4d8b7031 --- /dev/null +++ b/mava/configs/env/smac_gym.yaml @@ -0,0 +1,25 @@ +# ---Environment Configs--- +defaults: + - _self_ + +env_name: Starcraft # Used for logging purposes. +scenario: + name: smaclite + task_name: smaclite/2s3z-v0 # smaclite/ + ['10m_vs_11m-v0', '27m_vs_30m-v0', '3s5z_vs_3s6z-v0', '2s3z-v0', '3s5z-v0', 'MMM-v0', 'MMM2-v0', '2c_vs_64zg-v0', 'bane_vs_bane-v0', 'corridor-v0', '2s_vs_1sc-v0', '3s_vs_5z-v0'] + +# Defines the metric that will be used to evaluate the performance of the agent. 
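The deadlock fix in the commit above boils down to constructing every actor thread before starting any of them. A minimal sketch with a dummy worker in place of the rollout function:

```python
import threading
import time

def rollout_worker(name: str) -> None:
    time.sleep(0.1)  # stand-in for building envs and running rollouts
    print(f"{name} done")

actor_threads = [
    threading.Thread(target=rollout_worker, args=(f"Actor-{i}",), name=f"Actor-{i}")
    for i in range(4)
]

# Start only after every thread object exists, so no running thread races
# against set-up work that a later thread's construction still has to do.
for t in actor_threads:
    t.start()
for t in actor_threads:
    t.join()
```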
+# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: False + +# Weather or not to sum the returned rewards over all of the agents. +use_shared_rewards: True + +kwargs: + max_episode_steps: 500 \ No newline at end of file diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 8b9c85afd..32a85155c 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -54,6 +54,7 @@ RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, + SmacWrapper, async_multiagent_worker, ) from mava.wrappers.jaxmarl import JaxMarlWrapper @@ -77,6 +78,7 @@ _gym_registry = { "RobotWarehouse": GymWrapper, "LevelBasedForaging": GymWrapper, + "Starcraft": SmacWrapper, } @@ -247,7 +249,7 @@ def make_gym_env( def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnasium.Env: registered_name = f"{config.env.scenario.name}:{config.env.scenario.task_name}" - env = gym.make(registered_name, disable_env_checker=False) + env = gym.make(registered_name, disable_env_checker=False, **config.env.kwargs) wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) if config.system.add_agent_id: wrapped_env = GymAgentIDWrapper(wrapped_env) diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index 8fffe4d48..cab1ddd0e 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -99,7 +99,7 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] traj = _stack_trajectory(traj) - traj, timestep = jax.device_put((traj, timestep), device=self.sharding, donate=True) + traj, timestep = jax.device_put((traj, timestep), device=self.sharding) # We block on the `put` to ensure that actors wait for the learners to catch up. 
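The make_env change above forwards scenario kwargs straight to gym.make and picks the wrapper from a name-keyed registry. A hedged sketch of that factory shape; the wrapper stand-ins and the example scenario id are placeholders and assume the relevant package (e.g. smaclite) is installed:

```python
from typing import Callable, Dict

import gymnasium as gym

# Stand-ins for GymWrapper / SmacWrapper; real wrappers would subclass gym.Wrapper.
_gym_registry: Dict[str, Callable[[gym.Env], gym.Env]] = {
    "RobotWarehouse": lambda env: env,
    "Starcraft": lambda env: env,
}

def make_single_env(env_name: str, scenario: str, task: str, **kwargs) -> gym.Env:
    wrap = _gym_registry[env_name]
    # "module:env-id" lets gymnasium import the registering module on demand.
    env = gym.make(f"{scenario}:{task}", disable_env_checker=False, **kwargs)
    return wrap(env)

# e.g. make_single_env("Starcraft", "smaclite", "smaclite/2s3z-v0", max_episode_steps=500)
```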
# This ensures two things: diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index f8cf8a64c..f7e89d756 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -21,6 +21,7 @@ GymRecordEpisodeMetrics, GymToJumanji, GymWrapper, + SmacWrapper, async_multiagent_worker, ) from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index fa42e5e82..aa64e2755 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -106,7 +106,7 @@ def reset( return np.array(agents_view), info - def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + def step(self, actions: Tuple) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: agents_view, reward, terminated, truncated, info = self._env.step(actions) info = {"actions_mask": self.get_actions_mask(info)} @@ -128,7 +128,20 @@ def get_actions_mask(self, info: Dict) -> NDArray: def get_global_obs(self, obs: NDArray) -> NDArray: global_obs = np.concatenate(obs, axis=0) return np.tile(global_obs, (self.num_agents, 1)) + +class SmacWrapper(GymWrapper): + """A wrapper that converts actions step to integers.""" + + def step(self, actions: Tuple) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + # Convert actions to integers before passing them to the environment + actions = [int(action) for action in actions] + + agents_view, reward, terminated, truncated, info = super().step(actions) + return agents_view, reward, terminated, truncated, info + + def get_actions_mask(self, info: Dict) -> NDArray: + return np.array(self._env.unwrapped.get_avail_actions()) class GymRecordEpisodeMetrics(gymnasium.Wrapper): """Record the episode returns and lengths.""" diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 61f7fe68a..5522b2e82 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -25,3 +25,4 @@ scipy==1.12.0 tensorboard_logger tensorflow_probability type_enforced # needed because gigastep is missing this dependency +smaclite @ git+https://github.com/uoe-agents/smaclite.git \ No newline at end of file From bc55375a399c6a8ba2ac702a31791214e4026cd6 Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Wed, 23 Oct 2024 14:43:55 +0100 Subject: [PATCH 124/125] chore: pre-commits --- mava/configs/env/lbf_gym.yaml | 2 +- mava/configs/env/rware_gym.yaml | 2 +- mava/configs/env/smac_gym.yaml | 2 +- mava/evaluator.py | 4 +++- mava/systems/ppo/sebulba/ff_ippo.py | 17 ++++++++--------- mava/utils/config.py | 2 +- mava/utils/make_env.py | 2 +- mava/wrappers/gym.py | 14 ++++++++------ requirements/requirements.txt | 2 +- 9 files changed, 25 insertions(+), 22 deletions(-) diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml index 7ae03d010..f001e0913 100644 --- a/mava/configs/env/lbf_gym.yaml +++ b/mava/configs/env/lbf_gym.yaml @@ -22,4 +22,4 @@ log_win_rate: False use_shared_rewards: True kwargs: - max_episode_steps: 100 \ No newline at end of file + max_episode_steps: 100 diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml index 0fcd41a2b..facf7f8d7 100644 --- a/mava/configs/env/rware_gym.yaml +++ b/mava/configs/env/rware_gym.yaml @@ -22,4 +22,4 @@ log_win_rate: False use_shared_rewards: True kwargs: - max_episode_steps: 500 \ No newline at end of file + max_episode_steps: 500 diff --git a/mava/configs/env/smac_gym.yaml b/mava/configs/env/smac_gym.yaml index a4d8b7031..1f2f48c89 100644 --- a/mava/configs/env/smac_gym.yaml +++ 
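The SmacWrapper above exists mainly to turn the numpy action scalars coming out of the jitted policy into plain Python ints before they reach the underlying env. A generic version of that conversion, assuming nothing about smaclite beyond the standard gymnasium step API:

```python
from typing import Dict, List, Tuple

import gymnasium
from numpy.typing import NDArray

class IntActionWrapper(gymnasium.Wrapper):
    """Cast per-agent numpy action scalars to Python ints before stepping."""

    def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
        actions = [int(a) for a in actions]  # np.int32 / 0-d arrays -> int
        obs, reward, terminated, truncated, info = self.env.step(actions)
        return obs, reward, terminated, truncated, info
```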
b/mava/configs/env/smac_gym.yaml @@ -22,4 +22,4 @@ log_win_rate: False use_shared_rewards: True kwargs: - max_episode_steps: 500 \ No newline at end of file + max_episode_steps: 500 diff --git a/mava/evaluator.py b/mava/evaluator.py index 99d4eb8d4..8e4dd5dee 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -239,7 +239,9 @@ def get_sebulba_eval_fn( episode_loops = math.ceil(eval_episodes / n_parallel_envs) env = env_maker(config, n_parallel_envs) - act_fn = jax.jit(act_fn, device=jax.devices('cpu')[0]) # cpu so that we don't block actors/learners + act_fn = jax.jit( + act_fn, device=jax.devices("cpu")[0] + ) # cpu so that we don't block actors/learners # Warnings if num eval episodes is not divisible by num parallel envs. if eval_episodes % n_parallel_envs != 0: diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 971088a97..2ab554f69 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -26,15 +26,14 @@ import jax.debug import jax.numpy as jnp import numpy as np -from numpy.typing import NDArray import optax from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict from jax import tree from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding, Sharding -from jax.sharding import PartitionSpec as P +from jax.sharding import Mesh, NamedSharding, PartitionSpec, Sharding +from numpy.typing import NDArray from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -165,7 +164,7 @@ def get_learner_step_fn( ) -> SebulbaLearnerFn[LearnerState, PPOTransition]: """Get the learner function.""" - num_agents, num_envs = config.system.num_agents, config.arch.num_envs + num_envs = config.arch.num_envs num_learner_envs = int(num_envs // len(config.arch.learner_device_ids)) # Get apply and update functions for actor and critic networks. @@ -469,8 +468,8 @@ def learner_setup( devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) mesh = Mesh(devices, axis_names=("learner_devices",)) - model_spec = P() - data_spec = P("learner_devices") + model_spec = PartitionSpec() + data_spec = PartitionSpec("learner_devices") learner_sharding = NamedSharding(mesh, model_spec) # PRNG keys. @@ -651,8 +650,8 @@ def run_experiment(_config: DictConfig) -> float: name=f"Actor-{actor_device}-{thread_id}", ) actor_threads.append(actor) - - # Start the actors simultaneously + + # Start the actors simultaneously for actor in actor_threads: actor.start() @@ -704,7 +703,7 @@ def run_experiment(_config: DictConfig) -> float: if config.arch.absolute_metric and max_episode_return <= episode_return: best_params_cpu = copy.deepcopy(learner_state_cpu.params.actor_params) - max_episode_return = episode_return + max_episode_return = float(episode_return) evaluator_envs.close() eval_performance = float(np.mean(eval_metrics[config.env.eval_metric])) diff --git a/mava/utils/config.py b/mava/utils/config.py index 34a35f091..c82e3a315 100644 --- a/mava/utils/config.py +++ b/mava/utils/config.py @@ -46,7 +46,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: if config.arch.architecture_name == "anakin": n_devices = len(jax.devices()) update_batch_size = config.system.update_batch_size - n_accumulate = 1 # We dont accumulate envs in anakin + n_accumulate = 1 # We dont accumulate envs in anakin else: n_devices = 1 # We only use a single device's output when updating. 
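The evaluator change above pins the jitted act function to a CPU device so evaluation never competes with actor or learner devices. A toy sketch of the same idea; the act function here is a placeholder, and jax.jit's device argument mirrors the usage in the diff:

```python
import jax
import jax.numpy as jnp

def act_fn(params, obs):
    # Placeholder policy: a single linear layer followed by a greedy argmax.
    return jnp.argmax(obs @ params, axis=-1)

cpu = jax.devices("cpu")[0]
act_fn_cpu = jax.jit(act_fn, device=cpu)

params = jnp.ones((4, 3))
obs = jnp.zeros((2, 4))
actions = act_fn_cpu(params, obs)  # runs on the CPU device, off the accelerators
```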
update_batch_size = 1 diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 32a85155c..1206d3886 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -53,8 +53,8 @@ MatraxWrapper, RecordEpisodeMetrics, RwareWrapper, - SmaxWrapper, SmacWrapper, + SmaxWrapper, async_multiagent_worker, ) from mava.wrappers.jaxmarl import JaxMarlWrapper diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index aa64e2755..020abf158 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -19,7 +19,7 @@ from enum import IntEnum from multiprocessing import Queue from multiprocessing.connection import Connection -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import gymnasium import gymnasium.vector.async_vector_env @@ -29,7 +29,7 @@ from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray -from mava.types import MarlEnv, Observation, ObservationGlobalState +from mava.types import Observation, ObservationGlobalState if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 from dataclasses import dataclass @@ -106,7 +106,7 @@ def reset( return np.array(agents_view), info - def step(self, actions: Tuple) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: agents_view, reward, terminated, truncated, info = self._env.step(actions) info = {"actions_mask": self.get_actions_mask(info)} @@ -128,21 +128,23 @@ def get_actions_mask(self, info: Dict) -> NDArray: def get_global_obs(self, obs: NDArray) -> NDArray: global_obs = np.concatenate(obs, axis=0) return np.tile(global_obs, (self.num_agents, 1)) - + + class SmacWrapper(GymWrapper): """A wrapper that converts actions step to integers.""" - def step(self, actions: Tuple) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: # Convert actions to integers before passing them to the environment actions = [int(action) for action in actions] agents_view, reward, terminated, truncated, info = super().step(actions) return agents_view, reward, terminated, truncated, info - + def get_actions_mask(self, info: Dict) -> NDArray: return np.array(self._env.unwrapped.get_avail_actions()) + class GymRecordEpisodeMetrics(gymnasium.Wrapper): """Record the episode returns and lengths.""" diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5522b2e82..13ff3a050 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -22,7 +22,7 @@ optax protobuf~=3.20 rware scipy==1.12.0 +smaclite @ git+https://github.com/uoe-agents/smaclite.git tensorboard_logger tensorflow_probability type_enforced # needed because gigastep is missing this dependency -smaclite @ git+https://github.com/uoe-agents/smaclite.git \ No newline at end of file From c6d460f73d9ed00cd635f2a45f99b9f946825249 Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Sun, 27 Oct 2024 16:09:04 +0100 Subject: [PATCH 125/125] fix: random segfault --- mava/systems/ppo/sebulba/ff_ippo.py | 3 ++- mava/utils/sebulba.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 2ab554f69..1869ba092 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -437,8 +437,9 @@ def learner_thread( 
metrics.append((ep_metrics, train_metrics)) # Update all the params sources so all actors can get the latest params + params = jax.block_until_ready(learner_state.params) for source in params_sources: - source.update(learner_state.params) + source.update(params) # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation ep_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index cab1ddd0e..0e2e6261d 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -161,7 +161,7 @@ def run(self) -> None: while not self.lifetime.should_stop(): try: waiting = self.new_value.get(block=True, timeout=1) - self.value = jax.device_put(jax.block_until_ready(waiting), self.device) + self.value = jax.device_put(waiting, self.device) except queue.Empty: continue
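The segfault fix above waits for the learner's parameters to finish computing before any ParamsSource copies them to another device. A minimal sketch of that hand-off, with a plain Queue standing in for Mava's ParamsSource and toy params in place of the learner state:

```python
import queue

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((8, 8))}        # stand-in for learner_state.params
params = jax.block_until_ready(params)  # ensure the arrays have finished computing

new_value: queue.Queue = queue.Queue()
new_value.put(params)                   # publish once, after the wait

# Consumer side (one per actor device): copy the already-ready params over.
device = jax.devices()[0]
latest = jax.device_put(new_value.get(block=True, timeout=1), device)
```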