Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 30, 2024
1 parent fd5444d commit 7f1ff80
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
28 changes: 28 additions & 0 deletions test/test_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

from benchmarl.algorithms import (
algorithm_config_registry,
IddpgConfig,
IppoConfig,
IsacConfig,
MaddpgConfig,
MappoConfig,
MasacConfig,
QmixConfig,
)
Expand Down Expand Up @@ -104,6 +106,32 @@ def test_gnn(
)
experiment.run()

@pytest.mark.parametrize(
"algo_config", [IddpgConfig, MaddpgConfig, IppoConfig, MappoConfig, QmixConfig]
)
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
def test_gru(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
gru_mlp_sequence_config,
):
algo_config = algo_config.get_from_yaml()
if algo_config.has_critic():
algo_config.share_param_critic = False
experiment_config.share_policy_params = False
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config,
model_config=gru_mlp_sequence_config,
critic_model_config=gru_mlp_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("prefer_continuous", [True, False])
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
Expand Down
20 changes: 20 additions & 0 deletions test/test_smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,23 @@ def test_gnn(
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", [QmixConfig])
@pytest.mark.parametrize("task", [Smacv2Task.PROTOSS_5_VS_5])
def test_gru(
self,
algo_config,
task,
experiment_config,
gru_mlp_sequence_config,
):
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=gru_mlp_sequence_config,
critic_model_config=gru_mlp_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

0 comments on commit 7f1ff80

Please sign in to comment.