Add MAT #1107
base: develop
Conversation
Couple small things
Co-authored-by: Sasha Abramowitz <reallysasha@gmail.com>
🤖 🤖 🤖 🤖
Some nits, else the rest is good to go 🛥️ 🥅
@@ -1,13 +1,13 @@
 # --- Anakin config ---

 # --- Training ---
-num_envs: 16 # Number of vectorised environments per device.
+num_envs: 64 # Number of vectorised environments per device.
Spamming just to flag these to revert before merging, if you're not aiming to update them 👀
Suggested change:
-num_envs: 64 # Number of vectorised environments per device.
+num_envs: 16 # Number of vectorised environments per device.
 num_evaluation: 200 # Number of evenly spaced evaluations to perform during training.
@@ -1,7 +1,7 @@
 # ---Environment Configs---
 defaults:
   - _self_
-  - scenario: tiny-2ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]
+  - scenario: small-4ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]
Spam: revert before merge
Suggested change:
-  - scenario: small-4ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]
+  - scenario: tiny-2ag # [tiny-2ag, tiny-4ag, tiny-4ag-easy, small-4ag]
add_agent_id: True

# --- RL hyperparameters ---
actor_lr: 0.0005 # Learning rate for actor network
Just raising this for the future: maybe we should change actor_lr to network_lr (or something else) for both MAT and Sable, since it's not really an actor network 💭
from flax import linen as nn
from flax.linen.initializers import orthogonal

# TODO: Use einops for all the reshapes and matrix multiplications
Is this still a TODO, or will it be ignored?
)
from mava.systems.mat.types import MATNetworkConfig
from mava.types import MavaObservation
from mava.utils.network_utils import _CONTINUOUS, _DISCRETE
NIT, suggested change:
-from mava.utils.network_utils import _CONTINUOUS, _DISCRETE
+from mava.utils.network_utils import CONTINUOUS, DISCRETE
Thanks Ruan for all the work on adding MAT! I left some comments, but most of them are nits and minor suggestions.
- arch: anakin
- system: mat/mat
- network: transformer
- env: rware
Suggested change:
-- env: rware
+- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
    _run_system(system_path, cfg)


@pytest.mark.parametrize("env_name", discrete_envs)
def test_discrete_env(fast_config: dict, env_name: str) -> None:
    """Test all discrete envs on random systems."""
If you can, add MAT to the random choice here.
@@ -0,0 +1,6 @@
+# --- Network params ---
+n_block: 1 # Transformer blocks
+n_embd: 64 # Transformer embedding dimension
For the sake of unification, maybe we can rename this to embed_dim, similar to Sable.
class SwiGLU(nn.Module):
    ffn_dim: int
    embed_dim: int

    def setup(self) -> None:
        self.W_1 = self.param("W_1", nn.initializers.zeros, (self.embed_dim, self.ffn_dim))
        self.W_G = self.param("W_G", nn.initializers.zeros, (self.embed_dim, self.ffn_dim))
        self.W_2 = self.param("W_2", nn.initializers.zeros, (self.ffn_dim, self.embed_dim))

    def __call__(self, x: chex.Array) -> chex.Array:
        return (jax.nn.swish(x @ self.W_G) * (x @ self.W_1)) @ self.W_2
Suggested change:

class SwiGLU(nn.Module):
    """SwiGLU module for Sable's Network.

    Implements the SwiGLU feedforward neural network module, which is a variation
    of the standard feedforward layer using the Swish activation function combined
    with a Gated Linear Unit (GLU).
    """

    hidden_dim: int
    input_dim: int

    def setup(self) -> None:
        # Initialize the weights for the SwiGLU layer
        self.W_linear = self.param(
            "W_linear", nn.initializers.zeros, (self.input_dim, self.hidden_dim)
        )
        self.W_gate = self.param(
            "W_gate", nn.initializers.zeros, (self.input_dim, self.hidden_dim)
        )
        self.W_output = self.param(
            "W_output", nn.initializers.zeros, (self.hidden_dim, self.input_dim)
        )

    def __call__(self, x: chex.Array) -> chex.Array:
        """Applies the SwiGLU mechanism to the input tensor `x`."""
        # Apply Swish activation to the gated branch and multiply with the linear branch
        gated_output = jax.nn.swish(x @ self.W_gate) * (x @ self.W_linear)
        # Transform the result back to the input dimension
        return gated_output @ self.W_output
I added documentation and updated the variable naming in SwiGLU. I'm adding this suggestion here since the MAT PR will be merged first; also, putting it in torsos.py is better than in the outside utils folder (I will update the Sable PR based on that).
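For reference, both versions compute the same gated feedforward pass. A minimal NumPy sketch of that computation (the weight names follow the suggested renaming; the swish here is the plain x·sigmoid(x) that jax.nn.swish computes by default):

```python
import numpy as np

def swish(x):
    # swish(x) = x * sigmoid(x), matching jax.nn.swish with default beta.
    return x / (1.0 + np.exp(-x))

def swiglu_forward(x, W_gate, W_linear, W_output):
    # The swish-gated branch modulates the linear branch elementwise,
    # then the result is projected back down to the input dimension.
    return (swish(x @ W_gate) * (x @ W_linear)) @ W_output

rng = np.random.default_rng(0)
input_dim, hidden_dim = 4, 8
x = rng.normal(size=(2, input_dim))
W_gate = rng.normal(size=(input_dim, hidden_dim))
W_linear = rng.normal(size=(input_dim, hidden_dim))
W_output = rng.normal(size=(hidden_dim, input_dim))

out = swiglu_forward(x, W_gate, W_linear, W_output)
print(out.shape)  # the output projection returns to the input width: (2, 4)
```

Note the zero initialisation in the module above means the layer outputs zeros until training moves the weights; this sketch uses random weights only to show the shapes.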
def get_learner_fn(
    env: MarlEnv,
    apply_fns: Tuple[ActorApply, CriticApply],
If you can, create ExecutionApply and TrainApply types instead of using the actor-critic ones.
eval_keys = jax.random.split(key_e, n_devices)

def eval_act_fn(
If we can, follow the PPO systems here, where we call a maker function from the evaluator instead of creating the act function here.
# Evaluate.
eval_metrics = evaluator(trained_params, eval_keys, {})
jax.block_until_ready(eval_metrics)
Do we need these block_until_ready calls? We never added them to the other on-policy systems.
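For context on the question: JAX dispatches work asynchronously, so block_until_ready mainly matters when timing or synchronising around a computation rather than for correctness. A small illustrative sketch (not Mava code; the function and sizes are made up):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def heavy(x):
    # Repeated matmuls as stand-in work; an identity input keeps values finite.
    for _ in range(10):
        x = x @ x
    return x

x = jnp.eye(64)
jax.block_until_ready(heavy(x))  # warm-up so compilation is excluded from timing

start = time.perf_counter()
y = heavy(x)  # returns almost immediately: dispatch is asynchronous
dispatched = time.perf_counter() - start

jax.block_until_ready(y)  # waits until the computation has actually finished
total = time.perf_counter() - start
print(dispatched <= total)  # True: blocking includes the real compute time
```

So the call after the evaluator is defensible if the surrounding code measures wall-clock time around evaluation; otherwise it could be dropped for consistency with the other on-policy systems.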
# Define network and optimiser.
actor_network = MultiAgentTransformer(
    obs_dim=init_x.agents_view.shape[-1],
We don't need obs_dim as an input.
raise ValueError("Invalid action space type")

# Define network and optimiser.
actor_network = MultiAgentTransformer(
Can we rename this to mat_network? (Very optional.)
What?
Adds the Multi-agent Transformer to Mava.

Extra
The transformer-based systems, or any other system with only one set of parameters, do not require us to make a Params NamedTuple with only one element. I had to update the checkpointer to no longer check whether the restored params are in a FrozenDict. All other systems can still checkpoint and reload with this change. We also no longer need the check since we pin to a version of Flax higher than 0.6.11.
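To illustrate the point about single-parameter systems (a hypothetical sketch; the type and field names here are illustrative, not Mava's actual types):

```python
from typing import NamedTuple

class Params(NamedTuple):
    # Actor-critic systems genuinely carry two parameter sets,
    # so a NamedTuple container is natural there.
    actor_params: dict
    critic_params: dict

# A transformer-based system has a single set of parameters. Wrapping it in a
# one-element NamedTuple adds indirection without information, so the
# checkpointer can save and restore the raw params pytree directly.
class WrappedParams(NamedTuple):
    params: dict

raw_params = {"encoder": {"w": [1.0, 2.0]}, "decoder": {"w": [3.0]}}
wrapped = WrappedParams(raw_params)

print(wrapped.params is raw_params)  # True: the wrapper holds the same pytree
```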