Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set the action head automatically #1104

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions mava/advanced_usage/ff_ippo_store_experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.make_env import make
from mava.utils.network_utils import get_action_head
from mava.wrappers.episode_metrics import get_final_step_metrics

StoreExpLearnerFn = Callable[[MavaState], Tuple[ExperimentOutput[MavaState], PPOTransition]]
Expand Down Expand Up @@ -351,17 +352,17 @@ def learner_setup(
n_devices = len(jax.devices())

# Get number of actions and agents.
num_actions = int(env.action_spec().num_values[0])
num_agents = env.action_spec().shape[0]
config.system.num_agents = num_agents
num_actions = env.action_dim
config.system.num_agents = env.num_agents
config.system.num_actions = num_actions

# PRNG keys.
key, key_p = keys

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(config.network.action_head, action_dim=num_actions)
action_head, _ = get_action_head(env)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

actor_network = Actor(torso=actor_torso, action_head=actor_action_head)
Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/ff_ippo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/ff_ippo
- network: mlp # [mlp, continuous_mlp, cnn]
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/ff_isac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defaults:
- logger: logger
- arch: anakin
- system: sac/ff_isac
- network: continuous_mlp # [continuous_mlp]
- network: mlp
- env: mabrax # [mabrax]

hydra:
Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/ff_mappo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/ff_mappo
- network: mlp # [mlp, continuous_mlp, cnn]
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/ff_masac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defaults:
- logger: logger
- arch: anakin
- system: sac/ff_masac
- network: continuous_mlp # [continuous_mlp]
- network: mlp
- env: mabrax # [mabrax]

hydra:
Expand Down
3 changes: 0 additions & 3 deletions mava/configs/network/cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.CNNTorso
Expand Down
17 changes: 0 additions & 17 deletions mava/configs/network/continuous_mlp.yaml

This file was deleted.

3 changes: 0 additions & 3 deletions mava/configs/network/mlp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.MLPTorso
Expand Down
3 changes: 0 additions & 3 deletions mava/configs/network/rcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.CNNTorso
Expand Down
3 changes: 0 additions & 3 deletions mava/configs/network/rnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.MLPTorso
Expand Down
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -361,9 +362,8 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
)
action_head, _ = get_action_head(env)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

actor_network = Actor(torso=actor_torso, action_head=actor_action_head)
Expand Down
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/ff_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -345,9 +346,8 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
)
action_head, _ = get_action_head(env)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

actor_network = Actor(torso=actor_torso, action_head=actor_action_head)
Expand Down
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/rec_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -456,9 +457,8 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
)
action_head, _ = get_action_head(env)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)

Expand Down
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/rec_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -451,9 +452,8 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
)
action_head, _ = get_action_head(env)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)

Expand Down
6 changes: 3 additions & 3 deletions mava/systems/sac/anakin/ff_isac.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.wrappers import episode_metrics

Expand Down Expand Up @@ -110,9 +111,8 @@ def replicate(x: Any) -> Any:

# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
cfg.network.action_head, action_dim=action_dim, independent_std=False
)
action_head, _ = get_action_head(env)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
actor_network = Actor(actor_torso, actor_action_head)
actor_params = actor_network.init(actor_key, obs_single_batched)

Expand Down
6 changes: 3 additions & 3 deletions mava/systems/sac/anakin/ff_masac.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.wrappers import episode_metrics

Expand Down Expand Up @@ -113,9 +114,8 @@ def replicate(x: Any) -> Any:

# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
cfg.network.action_head, action_dim=action_dim, independent_std=False
)
action_head, _ = get_action_head(env)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
actor_network = Actor(actor_torso, actor_action_head)
actor_params = actor_network.init(actor_key, obs_single_batched)

Expand Down
30 changes: 30 additions & 0 deletions mava/utils/network_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Tuple

from jumanji.specs import DiscreteArray, MultiDiscreteArray

from mava.types import MarlEnv

_DISCRETE = "discrete"
_CONTINUOUS = "continuous"

WiemKhlifi marked this conversation as resolved.
Show resolved Hide resolved

def get_action_head(env: MarlEnv) -> Tuple[Dict[str, str], str]:
"""Returns the appropriate action head config based on the environment action_spec."""
if isinstance(env.action_spec(), (DiscreteArray, MultiDiscreteArray)):
return {"_target_": "mava.networks.heads.DiscreteActionHead"}, _DISCRETE

return {"_target_": "mava.networks.heads.ContinuousActionHead"}, _CONTINUOUS
2 changes: 1 addition & 1 deletion test/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_continuous_env(fast_config: dict, env_name: str) -> None:
system_path = random.choice(ppo_systems + sac_systems)
_, _, system_name = system_path.split(".")

overrides = [f"env={env_name}", "network=continuous_mlp"]
overrides = [f"env={env_name}"]
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=overrides)
cfg = _get_fast_config(cfg, fast_config)
Expand Down
Loading