[Model] LSTM #120

Merged 6 commits on Aug 2, 2024
README.md (1 addition, 0 deletions)

@@ -264,6 +264,7 @@ agent group. Here is a table of the models implemented in BenchMARL
 |------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
 | [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
 | [GRU](benchmarl/models/gru.py) | Yes | Yes | Yes |
+| [LSTM](benchmarl/models/lstm.py) | Yes | Yes | Yes |
 | [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
 | [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
 | [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes |
benchmarl/conf/model/layers/lstm.yaml (new file, 15 additions)

@@ -0,0 +1,15 @@
+
+name: lstm
+
+hidden_size: 128
+n_layers: 1
+bias: True
+dropout: 0
+compile: False
+
+mlp_num_cells: [256, 256]
+mlp_layer_class: torch.nn.Linear
+mlp_activation_class: torch.nn.Tanh
+mlp_activation_kwargs: null
+mlp_norm_class: null
+mlp_norm_kwargs: null
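These defaults mirror the existing `gru.yaml` layer config: the recurrent-cell parameters plus the MLP head applied on top. A minimal sketch of loading it in Python, assuming `LstmConfig` exposes the same `get_from_yaml()` loader as the other BenchMARL model configs:

```python
from benchmarl.models import LstmConfig

# Load the defaults from benchmarl/conf/model/layers/lstm.yaml;
# get_from_yaml() is the loader pattern the other model configs use.
model_config = LstmConfig.get_from_yaml()

# Config fields mirror the YAML keys, so they can be overridden in code.
model_config.hidden_size = 256
```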
benchmarl/models/__init__.py (4 additions, 0 deletions)

@@ -9,6 +9,7 @@
 from .deepsets import Deepsets, DeepsetsConfig
 from .gnn import Gnn, GnnConfig
 from .gru import Gru, GruConfig
+from .lstm import Lstm, LstmConfig
 from .mlp import Mlp, MlpConfig

 classes = [
@@ -22,6 +23,8 @@
     "DeepsetsConfig",
     "Gru",
     "GruConfig",
+    "Lstm",
+    "LstmConfig",
 ]

 model_config_registry = {
@@ -30,4 +33,5 @@
     "cnn": CnnConfig,
     "deepsets": DeepsetsConfig,
     "gru": GruConfig,
+    "lstm": LstmConfig,
 }
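The registry entry is what makes the new model selectable by name. A sketch of the lookup, assuming the usual BenchMARL flow where a `model=layers/lstm` hydra override resolves through this same mapping:

```python
from benchmarl.models import model_config_registry

# "lstm" comes from the `name` field in the YAML config above.
config_cls = model_config_registry["lstm"]  # -> LstmConfig
model_config = config_cls.get_from_yaml()
```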
benchmarl/models/gru.py (7 additions, 14 deletions)

@@ -4,12 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 #

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-#
-
 from __future__ import annotations

 from dataclasses import dataclass, MISSING
@@ -167,7 +161,7 @@ def forward(
         h_0=None,
     ):
         # Input and output always have the multiagent dimension
-        # Hidden state only has it when not centralised
+        # Hidden states always have it, except when centralised and share_params
         # is_init never has it

         assert is_init is not None, "We need to pass is_init"
@@ -202,7 +196,7 @@ def forward(
             is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1)

         if h_0 is None:
-            if self.centralised:
+            if self.centralised and self.share_params:
                 shape = (
                     batch,
                     self.n_layers,
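The hunk is folded after `self.n_layers,`, but the intent of the fix is visible: allocating a hidden state without an agent dimension is only safe when the agents actually share one parameter set. A hedged sketch of the two shapes involved (sizes illustrative; the dim order of the truncated tuple is an assumption):

```python
# Illustrative sizes only; the exact shape tuple is folded in the diff.
batch, n_layers, n_agents, hidden_size = 4, 1, 2, 128

# centralised + share_params: one shared recurrent state, no agent dim
shared_shape = (batch, n_layers, hidden_size)

# otherwise: one state per agent, with the agent dim at -3
# (consistent with the h_0 in_dims change in the next hunk)
per_agent_shape = (batch, n_agents, n_layers, hidden_size)
```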
@@ -243,8 +237,8 @@ def run_net(self, input, is_init, h_0):
         if self.centralised:
             output, h_n = self.vmap_func_module(
                 self._empty_gru,
-                (0, None, None, None),
-                (-2, -2),
+                (0, None, None, -3),
+                (-2, -3),
             )(self.params, input, is_init, h_0)
         else:
             output, h_n = self.vmap_func_module(
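This dim change is the substantive fix: the centralised branch previously did not map over `h_0` at all (`None`), which is only valid when parameters are shared. A toy sketch of the new `in_dims`/`out_dims` semantics with `torch.func.vmap`, using illustrative shapes and a stand-in function (not BenchMARL's actual module, and dropping the `is_init` argument for brevity):

```python
import torch
from torch.func import vmap

batch, seq, n_agents, n_layers, hidden = 4, 10, 2, 1, 8

def gru_like(params, x, h_0):
    # Stand-in for the functional GRU call: an output per step plus a final state.
    return x * params.mean(), h_0

params = torch.randn(n_agents, 3)                     # one parameter set per agent (mapped at dim 0)
x = torch.randn(batch, seq, hidden)                   # centralised input, shared across agents (None)
h_0 = torch.randn(batch, n_agents, n_layers, hidden)  # per-agent hidden state, agent dim at -3

out, h_n = vmap(gru_like, in_dims=(0, None, -3), out_dims=(-2, -3))(params, x, h_0)
print(out.shape)  # torch.Size([4, 10, 2, 8]) -> agent dim inserted at -2
print(h_n.shape)  # torch.Size([4, 2, 1, 8])  -> agent dim kept at -3
```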
@@ -283,8 +277,8 @@ class Gru(Model):

     The BenchMARL GRU accepts multiple inputs of type array: Tensors of shape ``(*batch,F)``

-    Where `F` is the number of features.
-    The features `F` will be processed to features of `hidden_size` by the GRU.
+    Where `F` is the number of features. These arrays will be concatenated along the F dimension,
+    and the result will be processed to features of `hidden_size` by the GRU.

     Args:
         hidden_size (int): The number of features in the hidden state.
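A small sketch of what the reworded docstring describes (shapes illustrative):

```python
import torch

obs = torch.randn(32, 2, 4)  # (batch, n_agents, F=4)
vel = torch.randn(32, 2, 3)  # (batch, n_agents, F=3)

# Multiple array inputs are concatenated along the feature dimension,
# and the result is what the GRU maps to `hidden_size` features.
x = torch.cat([obs, vel], dim=-1)
print(x.shape)  # torch.Size([32, 2, 7])
```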
@@ -516,10 +510,9 @@ def is_rnn(self) -> bool:
         return True

     def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec:
-        name = f"_hidden_gru_{model_index}"
         spec = CompositeSpec(
             {
-                name: UnboundedContinuousTensorSpec(
+                f"_hidden_gru_{model_index}": UnboundedContinuousTensorSpec(
                     shape=(self.n_layers, self.hidden_size)
                 )
             }
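The refactor just inlines the key; the resulting spec is unchanged. For reference, what the method builds, using illustrative sizes and the torchrl spec classes the file already imports:

```python
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

n_layers, hidden_size, model_index = 1, 128, 0
spec = CompositeSpec(
    {
        # one unbounded hidden-state entry of shape (n_layers, hidden_size)
        f"_hidden_gru_{model_index}": UnboundedContinuousTensorSpec(
            shape=(n_layers, hidden_size)
        )
    }
)
```

The LSTM counterpart added by this PR presumably carries the cell state as well, since `torch.nn.LSTM` tracks both `h` and `c`.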