
Add MAT #1107

Open · wants to merge 53 commits into base: develop
Conversation

RuanJohn (Collaborator) commented Oct 22, 2024

What?

Adds the Multi-Agent Transformer (MAT) to Mava.

Extra

Transformer-based systems, or any other system with only one set of parameters, do not require us to make a Params NamedTuple with a single element. I had to update the checkpointer so that it no longer checks whether restored params are in a FrozenDict; all other systems can still checkpoint and reload with this change. The check is also no longer needed since we pin to a Flax version higher than 0.6.11.
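The relaxed restore logic can be sketched as follows. This is illustrative only: the real checkpointer lives in Mava's utilities, and `restore_params` is a hypothetical helper name, not the actual API.

```python
# On Flax versions above 0.6.11, restored checkpoints come back as plain
# dicts, so the old FrozenDict unwrap step can simply be dropped.
def restore_params(restored_pytree):
    # Old behaviour (no longer needed):
    #   if isinstance(restored_pytree, FrozenDict):
    #       restored_pytree = restored_pytree.unfreeze()
    # New behaviour: use the restored pytree directly for every system,
    # including single-parameter-set systems like MAT.
    return restored_pytree
```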

sash-a (Contributor) left a comment

A couple of small things.

Review threads: mava/networks/attention.py · mava/networks/mat_network.py · mava/systems/mat/anakin/mat.py · test/integration_test.py
sash-a previously approved these changes Oct 28, 2024

sash-a (Contributor) left a comment

🤖 🤖 🤖 🤖

WiemKhlifi (Contributor) left a comment

Some nits, otherwise the rest is good to go 🛥️ 🥅

@@ -1,13 +1,13 @@
# --- Anakin config ---

# --- Training ---
num_envs: 16 # Number of vectorised environments per device.
num_envs: 64 # Number of vectorised environments per device.

Spamming these just as a reminder to revert before merging, if you're not aiming to update them 👀

Suggested change
num_envs: 64 # Number of vectorised environments per device.
num_envs: 16 # Number of vectorised environments per device.
num_evaluation: 200 # Number of evenly spaced evaluations to perform during training.

@@ -1,7 +1,7 @@
# ---Environment Configs---
defaults:
- _self_
- scenario: tiny-2ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]
- scenario: small-4ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]

Spam: revert before merge

Suggested change
- scenario: small-4ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]
- scenario: tiny-2ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]

add_agent_id: True

# --- RL hyperparameters ---
actor_lr: 0.0005 # Learning rate for actor network

Just raising this for the future: maybe we should change actor_lr to network_lr (or something similar) for both MAT and Sable, since it's not really an actor network 💭

from flax import linen as nn
from flax.linen.initializers import orthogonal

# TODO: Use einops for all the reshapes and matrix multiplications

Is this still a TODO, or will it be ignored?

)
from mava.systems.mat.types import MATNetworkConfig
from mava.types import MavaObservation
from mava.utils.network_utils import _CONTINUOUS, _DISCRETE

NIT:

Suggested change
from mava.utils.network_utils import _CONTINUOUS, _DISCRETE
from mava.utils.network_utils import CONTINUOUS, DISCRETE

OmaymaMahjoub (Contributor) left a comment

Thanks Ruan for all the work on adding MAT! I left some comments, but most of them are nits and minor suggestions.

- arch: anakin
- system: mat/mat
- network: transformer
- env: rware

Suggested change
- env: rware
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]


_run_system(system_path, cfg)


@pytest.mark.parametrize("env_name", discrete_envs)
def test_discrete_env(fast_config: dict, env_name: str) -> None:
"""Test all discrete envs on random systems."""

If you can, add MAT to the random choice here.

@@ -0,0 +1,6 @@
# --- Network params ---
n_block: 1 # Transformer blocks
n_embd: 64 # Transformer embedding dimension

For the sake of unification, maybe we can rename this to embed_dim, similar to Sable.

Comment on lines +79 to +90
class SwiGLU(nn.Module):
ffn_dim: int
embed_dim: int

def setup(self) -> None:
self.W_1 = self.param("W_1", nn.initializers.zeros, (self.embed_dim, self.ffn_dim))
self.W_G = self.param("W_G", nn.initializers.zeros, (self.embed_dim, self.ffn_dim))
self.W_2 = self.param("W_2", nn.initializers.zeros, (self.ffn_dim, self.embed_dim))

def __call__(self, x: chex.Array) -> chex.Array:
return (jax.nn.swish(x @ self.W_G) * (x @ self.W_1)) @ self.W_2


Suggested change
class SwiGLU(nn.Module):
ffn_dim: int
embed_dim: int
def setup(self) -> None:
self.W_1 = self.param("W_1", nn.initializers.zeros, (self.embed_dim, self.ffn_dim))
self.W_G = self.param("W_G", nn.initializers.zeros, (self.embed_dim, self.ffn_dim))
self.W_2 = self.param("W_2", nn.initializers.zeros, (self.ffn_dim, self.embed_dim))
def __call__(self, x: chex.Array) -> chex.Array:
return (jax.nn.swish(x @ self.W_G) * (x @ self.W_1)) @ self.W_2
class SwiGLU(nn.Module):
"""SwiGLU module for Sable's Network.
Implements the SwiGLU feedforward neural network module, which is a variation
of the standard feedforward layer using the Swish activation function combined
with a Gated Linear Unit (GLU).
"""
hidden_dim: int
input_dim: int
def setup(self) -> None:
# Initialize the weights for the SwiGLU layer
self.W_linear = self.param(
"W_linear", nn.initializers.zeros, (self.input_dim, self.hidden_dim)
)
self.W_gate = self.param("W_gate", nn.initializers.zeros, (self.input_dim, self.hidden_dim))
self.W_output = self.param(
"W_output", nn.initializers.zeros, (self.hidden_dim, self.input_dim)
)
def __call__(self, x: chex.Array) -> chex.Array:
"""Applies the SwiGLU mechanism to the input tensor `x`."""
# Apply Swish activation to the gated branch and multiply with the linear branch
gated_output = jax.nn.swish(x @ self.W_gate) * (x @ self.W_linear)
# Transform the result back to the input dimension
return gated_output @ self.W_output


I added documentation and updated variable naming in SwiGLU. I put this suggestion here since the MAT PR will be the first to be merged; also, adding it to torsos.py is better than putting it in the outside utils folder (I will update the Sable PR based on that).
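As a sanity check on the SwiGLU computation in the suggestion above, here is a plain-NumPy sketch of the same math (shapes and variable names are illustrative, not taken from the PR):

```python
import numpy as np

def swish(x):
    # swish(x) = x * sigmoid(x), matching jax.nn.swish
    return x / (1.0 + np.exp(-x))

def swiglu(x, W_gate, W_linear, W_output):
    # Swish-activated gate branch, elementwise-multiplied with the linear
    # branch, then projected back to the input dimension.
    return (swish(x @ W_gate) * (x @ W_linear)) @ W_output

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))      # (batch, input_dim)
W_g = rng.standard_normal((8, 16))   # (input_dim, hidden_dim)
W_l = rng.standard_normal((8, 16))   # (input_dim, hidden_dim)
W_o = rng.standard_normal((16, 8))   # (hidden_dim, input_dim)
out = swiglu(x, W_g, W_l, W_o)       # shape (4, 8): back to input_dim
```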


def get_learner_fn(
env: MarlEnv,
apply_fns: Tuple[ActorApply, CriticApply],

If you can, create ExecutionApply and TrainApply types instead of using the actor-critic ones.


eval_keys = jax.random.split(key_e, n_devices)

def eval_act_fn(

If we can, follow the PPO systems here, where we call a maker function from the evaluator instead of creating eval_act_fn inline.
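The maker-function pattern being suggested might look roughly like this (names and signatures here are illustrative, not Mava's actual API):

```python
from typing import Any, Callable, Dict, Tuple

def make_eval_act_fn(apply_fn: Callable) -> Callable:
    """Hypothetical factory: builds an eval act function from a network apply fn."""

    def eval_act_fn(params, timestep, key, actor_state) -> Tuple[Any, Dict]:
        # Select an action with the network; return empty extras/actor state info.
        action = apply_fn(params, timestep, key)
        return action, {}

    return eval_act_fn
```

The evaluator can then call `make_eval_act_fn(...)` itself, so each system only supplies its apply function rather than duplicating the inline definition.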


# Evaluate.
eval_metrics = evaluator(trained_params, eval_keys, {})
jax.block_until_ready(eval_metrics)

Do we need these block_until_ready calls? We never added them to other on-policy systems.
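For context on the question: JAX dispatches computations asynchronously, and `jax.block_until_ready` forces the host to wait for the device result, which mainly matters for accurate timing of evaluation. A generic illustration (not code from this PR):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
y = jnp.square(x).sum()        # returns immediately; may still be computing on device
y = jax.block_until_ready(y)   # host now waits until y is fully materialised
```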


# Define network and optimiser.
actor_network = MultiAgentTransformer(
obs_dim=init_x.agents_view.shape[-1],

We don't need obs_dim as an input.

raise ValueError("Invalid action space type")

# Define network and optimiser.
actor_network = MultiAgentTransformer(

Can we rename this to mat_network? (Very optional.)

4 participants