diff --git a/.github/unittest/install_meltingpot.sh b/.github/unittest/install_meltingpot.sh
new file mode 100644
index 00000000..6842d3b0
--- /dev/null
+++ b/.github/unittest/install_meltingpot.sh
@@ -0,0 +1,2 @@
+
+pip install dm-meltingpot
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index ff470a37..05edc0c8 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.10"]
+ python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
diff --git a/.github/workflows/meltingpot_tests.yml b/.github/workflows/meltingpot_tests.yml
new file mode 100644
index 00000000..77e21a67
--- /dev/null
+++ b/.github/workflows/meltingpot_tests.yml
@@ -0,0 +1,43 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see:
+# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+
+name: meltingpot_tests
+
+on:
+ push:
+ branches: [ $default-branch , "main" ]
+ pull_request:
+ branches: [ $default-branch , "main" ]
+
+permissions:
+ contents: read
+
+jobs:
+ tests:
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ["3.11"]
+
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v3
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ bash .github/unittest/install_dependencies_nightly.sh
+ - name: Install meltingpot
+ run: |
+ bash .github/unittest/install_meltingpot.sh
+ - name: Test with pytest
+ run: |
+ pytest test/test_meltingpot.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ fail_ci_if_error: false
diff --git a/.github/workflows/pettingzoo_tests.yml b/.github/workflows/pettingzoo_tests.yml
index 3fbd0125..66d5e81c 100644
--- a/.github/workflows/pettingzoo_tests.yml
+++ b/.github/workflows/pettingzoo_tests.yml
@@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
@@ -37,9 +37,7 @@ jobs:
- name: Test with pytest
run: |
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_pettingzoo.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
-
- - if: matrix.python-version == '3.10'
- name: Upload coverage to Codecov
+ - name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
diff --git a/.github/workflows/smacv2_tests.yml b/.github/workflows/smacv2_tests.yml
index 20761a69..7023afd9 100644
--- a/.github/workflows/smacv2_tests.yml
+++ b/.github/workflows/smacv2_tests.yml
@@ -21,7 +21,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.10"]
+ python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
@@ -44,8 +44,7 @@ jobs:
pytest test/test_smacv2.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
- - if: matrix.python-version == '3.10'
- name: Upload coverage to Codecov
+ - name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
diff --git a/.github/workflows/torchrl_stable_tests.yml b/.github/workflows/torchrl_stable_tests.yml
index 413c10d7..3ca7e448 100644
--- a/.github/workflows/torchrl_stable_tests.yml
+++ b/.github/workflows/torchrl_stable_tests.yml
@@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.10"]
+ python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@@ -38,4 +38,4 @@ jobs:
bash .github/unittest/install_pettingzoo.sh
- name: Tests
run: |
- xvfb-run -s "-screen 0 1024x768x24" pytest test/test_algorithm.py test/test_models.py test/test_task.py test/test_vmas.py test/test_pettingzoo.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
+ xvfb-run -s "-screen 0 1024x768x24" pytest test/test_algorithm.py test/test_models.py test/test_task.py test/test_vmas.py test/test_pettingzoo.py test/test_meltingpot.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index a1736419..eb45153b 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10","3.11"]
steps:
- uses: actions/checkout@v3
diff --git a/.github/workflows/vmas_tests.yml b/.github/workflows/vmas_tests.yml
index c42e8bd7..c6fe5e91 100644
--- a/.github/workflows/vmas_tests.yml
+++ b/.github/workflows/vmas_tests.yml
@@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
@@ -38,8 +38,7 @@ jobs:
run: |
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_vmas.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
- - if: matrix.python-version == '3.10'
- name: Upload coverage to Codecov
+ - name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
diff --git a/README.md b/README.md
index 9e421edb..e63cba6d 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
[![tests](https://github.com/facebookresearch/BenchMARL/actions/workflows/unit_tests.yml/badge.svg)](test)
[![codecov](https://codecov.io/github/facebookresearch/BenchMARL/coverage.svg?branch=main)](https://codecov.io/gh/facebookresearch/BenchMARL)
[![Documentation Status](https://readthedocs.org/projects/benchmarl/badge/?version=latest)](https://benchmarl.readthedocs.io/en/latest/?badge=latest)
-[![Python](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://www.python.org/downloads/)
+[![Python](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue.svg)](https://www.python.org/downloads/)
[![Downloads](https://static.pepy.tech/personalized-badge/benchmarl?period=total&units=international_system&left_color=grey&right_color=blue&left_text=Downloads)](https://pepy.tech/project/benchmarl)
[![Discord Shield](https://dcbadge.vercel.app/api/server/jEEWCn6T3p?style=flat)](https://discord.gg/jEEWCn6T3p)
@@ -113,6 +113,10 @@ pip install vmas
pip install "pettingzoo[all]"
```
+##### MeltingPot
+```bash
+pip install dm-meltingpot
+```
##### SMACv2
Follow the instructions on the environment [repository](https://github.com/oxwhirl/smacv2).
@@ -236,12 +240,14 @@ determine the training strategy. Here is a table with the currently implemented
challenge to solve.
They differ based on many aspects, here is a table with the current environments in BenchMARL
-| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized |
-|--------------------------------------------------------------------|-------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:|
-| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [18](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes |
-| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No |
-| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No |
-| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No |
+| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized |
+|--------------------------------------------------------------------|--------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:|
+| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [18](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes |
+| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No |
+| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No |
+| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No |
+| [MeltingPot](https://github.com/google-deepmind/meltingpot) | [49](benchmarl/conf/task/meltingpot) | Cooperative + Competitive | Yes | Independent | Discrete | No |
+
> [!NOTE]
> BenchMARL uses the [TorchRL MARL API](https://github.com/pytorch/rl/issues/1463) for grouping agents.
diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py
index 4155197d..e9b885b1 100644
--- a/benchmarl/algorithms/common.py
+++ b/benchmarl/algorithms/common.py
@@ -7,7 +7,7 @@
import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Any, Dict, Iterable, Optional, Tuple, Type
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
@@ -19,6 +19,7 @@
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
+from torchrl.envs import Compose, Transform
from torchrl.objectives import LossModule
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater
@@ -132,8 +133,7 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater
return self._losses_and_updaters[group]
def get_replay_buffer(
- self,
- group: str,
+ self, group: str, transforms: List[Transform] = None
) -> ReplayBuffer:
"""
Get the ReplayBuffer for a specific group.
@@ -141,6 +141,7 @@ def get_replay_buffer(
Args:
group (str): agent group of the loss and updater
+ transforms (optional, list of Transform): Transforms to apply to the replay buffer ``.sample()`` call
Returns: ReplayBuffer the group
"""
@@ -154,6 +155,7 @@ def get_replay_buffer(
sampler=sampler,
batch_size=sampling_size,
priority_key=(group, "td_error"),
+ transform=Compose(*transforms) if transforms is not None else None,
)
def get_policy_for_loss(self, group: str) -> TensorDictModule:
diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index 712291f6..2a357d73 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -82,7 +82,7 @@ evaluation_episodes: 10
evaluation_deterministic_actions: True
# List of loggers to use, options are: wandb, csv, tensorboard, mflow
-loggers: [wandb]
+loggers: []
# Create a json folder as part of the output in the format of marl-eval
create_json: True
diff --git a/benchmarl/conf/task/meltingpot/allelopathic_harvest__open.yaml b/benchmarl/conf/task/meltingpot/allelopathic_harvest__open.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__arena.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__repeated.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/boat_race__eight_races.yaml b/benchmarl/conf/task/meltingpot/boat_race__eight_races.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles.yaml b/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles_with_plentiful_distractors.yaml b/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles_with_plentiful_distractors.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles.yaml b/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles_with_distractors.yaml b/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles_with_distractors.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__arena.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__repeated.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/clean_up.yaml b/benchmarl/conf/task/meltingpot/clean_up.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/coins.yaml b/benchmarl/conf/task/meltingpot/coins.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__asymmetric.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__asymmetric.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__circuit.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__circuit.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__cramped.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__cramped.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__crowded.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__crowded.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__figure_eight.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__figure_eight.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__forced.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__forced.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__ring.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__ring.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/commons_harvest__closed.yaml b/benchmarl/conf/task/meltingpot/commons_harvest__closed.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/commons_harvest__open.yaml b/benchmarl/conf/task/meltingpot/commons_harvest__open.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/commons_harvest__partnership.yaml b/benchmarl/conf/task/meltingpot/commons_harvest__partnership.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/coop_mining.yaml b/benchmarl/conf/task/meltingpot/coop_mining.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/daycare.yaml b/benchmarl/conf/task/meltingpot/daycare.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/externality_mushrooms__dense.yaml b/benchmarl/conf/task/meltingpot/externality_mushrooms__dense.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/factory_commons__either_or.yaml b/benchmarl/conf/task/meltingpot/factory_commons__either_or.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/fruit_market__concentric_rivers.yaml b/benchmarl/conf/task/meltingpot/fruit_market__concentric_rivers.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/gift_refinements.yaml b/benchmarl/conf/task/meltingpot/gift_refinements.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/hidden_agenda.yaml b/benchmarl/conf/task/meltingpot/hidden_agenda.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/paintball__capture_the_flag.yaml b/benchmarl/conf/task/meltingpot/paintball__capture_the_flag.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/paintball__king_of_the_hill.yaml b/benchmarl/conf/task/meltingpot/paintball__king_of_the_hill.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/predator_prey__alley_hunt.yaml b/benchmarl/conf/task/meltingpot/predator_prey__alley_hunt.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/predator_prey__open.yaml b/benchmarl/conf/task/meltingpot/predator_prey__open.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/predator_prey__orchard.yaml b/benchmarl/conf/task/meltingpot/predator_prey__orchard.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/predator_prey__random_forest.yaml b/benchmarl/conf/task/meltingpot/predator_prey__random_forest.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__arena.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__repeated.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__arena.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__repeated.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__arena.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__repeated.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__arena.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__one_shot.yaml b/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__one_shot.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__repeated.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__arena.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__repeated.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/territory__inside_out.yaml b/benchmarl/conf/task/meltingpot/territory__inside_out.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/territory__open.yaml b/benchmarl/conf/task/meltingpot/territory__open.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/conf/task/meltingpot/territory__rooms.yaml b/benchmarl/conf/task/meltingpot/territory__rooms.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/environments/__init__.py b/benchmarl/environments/__init__.py
index 874a6b74..c6359305 100644
--- a/benchmarl/environments/__init__.py
+++ b/benchmarl/environments/__init__.py
@@ -5,6 +5,7 @@
#
from .common import Task
+from .meltingpot.common import MeltingPotTask
from .pettingzoo.common import PettingZooTask
from .smacv2.common import Smacv2Task
from .vmas.common import VmasTask
@@ -12,7 +13,7 @@
# This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum
# It is used by automatically load task enums from yaml files
task_config_registry = {}
-for env in [VmasTask, Smacv2Task, PettingZooTask]:
+for env in [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask]:
env_config_registry = {
f"{env.env_name()}/{task.name.lower()}": task for task in env
}
@@ -31,6 +32,7 @@
from .pettingzoo.simple_tag import TaskConfig as SimpleTagConfig
from .pettingzoo.simple_world_comm import TaskConfig as SimpleWorldComm
from .pettingzoo.waterworld import TaskConfig as WaterworldConfig
+
from .vmas.balance import TaskConfig as BalanceConfig
from .vmas.dispersion import TaskConfig as DispersionConfig
from .vmas.dropout import TaskConfig as DropoutConfig
@@ -38,7 +40,6 @@
from .vmas.navigation import TaskConfig as NavigationConfig
from .vmas.reverse_transport import TaskConfig as ReverseTransportConfig
from .vmas.sampling import TaskConfig as SamplingConfig
-
from .vmas.simple_adverasary import TaskConfig as VmasSimpleAdversaryConfig
from .vmas.simple_crypto import TaskConfig as VmasSimpleCryptoConfig
from .vmas.simple_push import TaskConfig as VmasSimplePushConfig
diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py
index 32faa7dc..377f9ded 100644
--- a/benchmarl/environments/common.py
+++ b/benchmarl/environments/common.py
@@ -96,13 +96,13 @@ def get_env_fun(
num_envs (int): The number of envs that should be in the batch_size of the returned env.
In vectorized envs, this can be used to set the number of batched environments.
If your environment is not vectorized, you can just ignore this, and it will be
- wrapped in a torchrl.envs.SerialEnv with num_envs automatically.
+ wrapped in a :class:`torchrl.envs.SerialEnv` with num_envs automatically.
continuous_actions (bool): Whether your environment should have continuous or discrete actions.
If your environment does not support both, ignore this and refer to the supports_x_actions methods.
seed (optional, int): The seed of your env
device (str): the device of your env, you can pass this to any torchrl env constructor
- Returns: a function that takes no arguments and returns a torchrl.envs.EnvBase object
+ Returns: a function that takes no arguments and returns a :class:`torchrl.envs.EnvBase` object
"""
raise NotImplementedError
@@ -242,6 +242,28 @@ def get_reward_sum_transform(self, env: EnvBase) -> Transform:
reset_keys = env.reset_keys
return RewardSum(reset_keys=reset_keys)
+ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
+ """
+ Returns a list of :class:`torchrl.envs.Transform` to be applied to the env.
+
+ Args:
+ env (EnvBase): An environment created via self.get_env_fun
+
+
+ """
+ return []
+
+ def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
+ """
+ Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`.
+
+ Args:
+ env (EnvBase): An environment created via self.get_env_fun
+
+
+ """
+ return []
+
@staticmethod
def render_callback(experiment, env: EnvBase, data: TensorDictBase):
try:
diff --git a/benchmarl/environments/meltingpot/__init__.py b/benchmarl/environments/meltingpot/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py
new file mode 100644
index 00000000..f209d8b1
--- /dev/null
+++ b/benchmarl/environments/meltingpot/common.py
@@ -0,0 +1,170 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+from typing import Callable, Dict, List, Optional
+
+import torch
+from tensordict import TensorDictBase
+
+from torchrl.data import CompositeSpec
+from torchrl.envs import DoubleToFloat, DTypeCastTransform, EnvBase, Transform
+
+from benchmarl.environments.common import Task
+from benchmarl.utils import DEVICE_TYPING
+
+
+class MeltingPotTask(Task):
+ """Enum for meltingpot tasks."""
+
+ PREDATOR_PREY__ALLEY_HUNT = None
+ CLEAN_UP = None
+ COLLABORATIVE_COOKING__CIRCUIT = None
+ FRUIT_MARKET__CONCENTRIC_RIVERS = None
+ COLLABORATIVE_COOKING__FIGURE_EIGHT = None
+ PAINTBALL__KING_OF_THE_HILL = None
+ FACTORY_COMMONS__EITHER_OR = None
+ PURE_COORDINATION_IN_THE_MATRIX__ARENA = None
+ RUNNING_WITH_SCISSORS_IN_THE_MATRIX__REPEATED = None
+ COLLABORATIVE_COOKING__CRAMPED = None
+ RUNNING_WITH_SCISSORS_IN_THE_MATRIX__ARENA = None
+ PRISONERS_DILEMMA_IN_THE_MATRIX__REPEATED = None
+ TERRITORY__OPEN = None
+ STAG_HUNT_IN_THE_MATRIX__REPEATED = None
+ CHICKEN_IN_THE_MATRIX__REPEATED = None
+ GIFT_REFINEMENTS = None
+ PURE_COORDINATION_IN_THE_MATRIX__REPEATED = None
+ COLLABORATIVE_COOKING__FORCED = None
+ RATIONALIZABLE_COORDINATION_IN_THE_MATRIX__ARENA = None
+ BACH_OR_STRAVINSKY_IN_THE_MATRIX__ARENA = None
+ CHEMISTRY__TWO_METABOLIC_CYCLES_WITH_DISTRACTORS = None
+ COMMONS_HARVEST__PARTNERSHIP = None
+ PREDATOR_PREY__OPEN = None
+ TERRITORY__ROOMS = None
+ HIDDEN_AGENDA = None
+ COOP_MINING = None
+ DAYCARE = None
+ PRISONERS_DILEMMA_IN_THE_MATRIX__ARENA = None
+ TERRITORY__INSIDE_OUT = None
+ BACH_OR_STRAVINSKY_IN_THE_MATRIX__REPEATED = None
+ COMMONS_HARVEST__CLOSED = None
+ CHEMISTRY__THREE_METABOLIC_CYCLES_WITH_PLENTIFUL_DISTRACTORS = None
+ STAG_HUNT_IN_THE_MATRIX__ARENA = None
+ PAINTBALL__CAPTURE_THE_FLAG = None
+ COLLABORATIVE_COOKING__CROWDED = None
+ ALLELOPATHIC_HARVEST__OPEN = None
+ COLLABORATIVE_COOKING__RING = None
+ COMMONS_HARVEST__OPEN = None
+ COINS = None
+ PREDATOR_PREY__ORCHARD = None
+ PREDATOR_PREY__RANDOM_FOREST = None
+ COLLABORATIVE_COOKING__ASYMMETRIC = None
+ RATIONALIZABLE_COORDINATION_IN_THE_MATRIX__REPEATED = None
+ CHEMISTRY__THREE_METABOLIC_CYCLES = None
+ RUNNING_WITH_SCISSORS_IN_THE_MATRIX__ONE_SHOT = None
+ CHEMISTRY__TWO_METABOLIC_CYCLES = None
+ CHICKEN_IN_THE_MATRIX__ARENA = None
+ BOAT_RACE__EIGHT_RACES = None
+ EXTERNALITY_MUSHROOMS__DENSE = None
+
+ def get_env_fun(
+ self,
+ num_envs: int,
+ continuous_actions: bool,
+ seed: Optional[int],
+ device: DEVICE_TYPING,
+ ) -> Callable[[], EnvBase]:
+ from torchrl.envs.libs.meltingpot import MeltingpotEnv
+
+ return lambda: MeltingpotEnv(
+ substrate=self.name.lower(),
+ categorical_actions=True,
+ **self.config,
+ )
+
+ def supports_continuous_actions(self) -> bool:
+ return False
+
+ def supports_discrete_actions(self) -> bool:
+ return True
+
+ def has_render(self, env: EnvBase) -> bool:
+ return True
+
+ def max_steps(self, env: EnvBase) -> int:
+ return self.config.get("max_steps", 100)
+
+ def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
+ return env.group_map
+
+ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
+ return [DoubleToFloat()]
+
+ def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
+ return [
+ DTypeCastTransform(
+ dtype_in=torch.uint8,
+ dtype_out=torch.float,
+ in_keys=[
+ "RGB",
+ *[
+ (group, "observation", "RGB")
+ for group in self.group_map(env).keys()
+ ],
+ ("next", "RGB"),
+ *[
+ ("next", group, "observation", "RGB")
+ for group in self.group_map(env).keys()
+ ],
+ ],
+ in_keys_inv=[],
+ )
+ ]
+
+ def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
+ observation_spec = env.observation_spec.clone()
+ for group in self.group_map(env):
+ del observation_spec[group]
+ if list(observation_spec.keys()) != ["RGB"]:
+ raise ValueError(
+ f"More than one global state key found in observation spec {observation_spec}."
+ )
+ return observation_spec
+
+ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
+ return None
+
+ def observation_spec(self, env: EnvBase) -> CompositeSpec:
+ observation_spec = env.observation_spec.clone()
+ for group_key in list(observation_spec.keys()):
+ if group_key not in self.group_map(env).keys():
+ del observation_spec[group_key]
+ else:
+ group_obs_spec = observation_spec[group_key]["observation"]
+ for key in list(group_obs_spec.keys()):
+ if key != "RGB":
+ del group_obs_spec[key]
+ return observation_spec
+
+ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
+ observation_spec = env.observation_spec.clone()
+ for group_key in list(observation_spec.keys()):
+ if group_key not in self.group_map(env).keys():
+ del observation_spec[group_key]
+ else:
+ group_obs_spec = observation_spec[group_key]["observation"]
+ del group_obs_spec["RGB"]
+ return observation_spec
+
+ def action_spec(self, env: EnvBase) -> CompositeSpec:
+ return env.full_action_spec
+
+ @staticmethod
+ def env_name() -> str:
+ return "meltingpot"
+
+ @staticmethod
+ def render_callback(experiment, env: EnvBase, data: TensorDictBase):
+ return data.get("RGB")
diff --git a/benchmarl/environments/pettingzoo/common.py b/benchmarl/environments/pettingzoo/common.py
index 5b27754f..fbaadf59 100644
--- a/benchmarl/environments/pettingzoo/common.py
+++ b/benchmarl/environments/pettingzoo/common.py
@@ -122,7 +122,6 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
del observation_spec["state"]
if observation_spec.is_empty():
return None
-
return observation_spec
def observation_spec(self, env: EnvBase) -> CompositeSpec:
@@ -132,7 +131,8 @@ def observation_spec(self, env: EnvBase) -> CompositeSpec:
for key in list(group_obs_spec.keys()):
if key != "observation":
del group_obs_spec[key]
-
+ if "state" in observation_spec.keys():
+ del observation_spec["state"]
return observation_spec
def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
@@ -142,10 +142,12 @@ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
for key in list(group_obs_spec.keys()):
if key != "info":
del group_obs_spec[key]
+ if "state" in observation_spec.keys():
+ del observation_spec["state"]
return observation_spec
def action_spec(self, env: EnvBase) -> CompositeSpec:
- return env.input_spec["full_action_spec"]
+ return env.full_action_spec
@staticmethod
def env_name() -> str:
diff --git a/benchmarl/environments/smacv2/common.py b/benchmarl/environments/smacv2/common.py
index 47043396..e968c31f 100644
--- a/benchmarl/environments/smacv2/common.py
+++ b/benchmarl/environments/smacv2/common.py
@@ -88,7 +88,7 @@ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
return observation_spec
def action_spec(self, env: EnvBase) -> CompositeSpec:
- return env.input_spec["full_action_spec"]
+ return env.full_action_spec
@staticmethod
def log_info(batch: TensorDictBase) -> Dict[str, float]:
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index a839fdbb..f7a75a99 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -387,7 +387,6 @@ def _setup_task(self):
device=self.config.sampling_device,
)
)
-
self.observation_spec = self.task.observation_spec(test_env)
self.info_spec = self.task.info_spec(test_env)
self.state_spec = self.task.state_spec(test_env)
@@ -397,24 +396,34 @@ def _setup_task(self):
self.train_group_map = copy.deepcopy(self.group_map)
self.max_steps = self.task.max_steps(test_env)
- transforms = [self.task.get_reward_sum_transform(test_env)]
- transform = Compose(*transforms)
+ transforms_env = self.task.get_env_transforms(test_env)
+ transforms_training = transforms_env + [
+ self.task.get_reward_sum_transform(test_env)
+ ]
+
+ transforms_env = Compose(*transforms_env)
+ transforms_training = Compose(*transforms_training)
if test_env.batch_size == ():
self.env_func = lambda: TransformedEnv(
SerialEnv(self.config.n_envs_per_worker(self.on_policy), env_func),
- transform.clone(),
+ transforms_training.clone(),
)
else:
- self.env_func = lambda: TransformedEnv(env_func(), transform.clone())
+ self.env_func = lambda: TransformedEnv(
+ env_func(), transforms_training.clone()
+ )
- self.test_env = test_env.to(self.config.sampling_device)
+ self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
+ self.config.sampling_device
+ )
def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)
self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
+ transforms=self.task.get_replay_buffer_transforms(self.test_env),
)
for group in self.group_map.keys()
}
diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py
index afce12c5..efe77003 100644
--- a/benchmarl/models/cnn.py
+++ b/benchmarl/models/cnn.py
@@ -229,7 +229,8 @@ def _perform_checks(self):
raise ValueError(
f"CNN input value {input_key} from {self.input_spec} has an invalid shape"
)
-
+ if not len(self.image_in_keys):
+ raise ValueError("CNN found no image inputs, maybe use an MLP?")
if self.input_has_agent_dim and input_shape_image[-3] != self.n_agents:
raise ValueError(
"If the CNN input has the agent dimension,"
@@ -257,7 +258,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# Gather images
input = torch.cat(
[tensordict.get(in_key) for in_key in self.image_in_keys], dim=-1
- )
+ ).to(torch.float)
# BenchMARL images are X,Y,C -> we convert them to C, X, Y for processing in TorchRL models
input = input.transpose(-3, -1).transpose(-2, -1)
@@ -316,9 +317,9 @@ class CnnConfig(ModelConfig):
"""Dataclass config for a :class:`~benchmarl.models.Cnn`."""
cnn_num_cells: Sequence[int] = MISSING
- cnn_kernel_sizes: Sequence[int] = MISSING
- cnn_strides: Sequence[int] = MISSING
- cnn_paddings: Sequence[int] = MISSING
+ cnn_kernel_sizes: Union[Sequence[int], int] = MISSING
+ cnn_strides: Union[Sequence[int], int] = MISSING
+ cnn_paddings: Union[Sequence[int], int] = MISSING
cnn_activation_class: Type[nn.Module] = MISSING
mlp_num_cells: Sequence[int] = MISSING
diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py
index 9f98c37b..ea810b00 100644
--- a/benchmarl/models/mlp.py
+++ b/benchmarl/models/mlp.py
@@ -93,7 +93,7 @@ def _perform_checks(self):
)
else:
raise ValueError(
- f"MLP input value {input_key} from {self.input_spec} has an invalid shape"
+ f"MLP input value {input_key} from {self.input_spec} has an invalid shape, maybe you need a CNN?"
)
if self.input_has_agent_dim:
if input_shape[-1] != self.n_agents:
diff --git a/benchmarl/utils.py b/benchmarl/utils.py
index 5d83ecc6..d2d63ae6 100644
--- a/benchmarl/utils.py
+++ b/benchmarl/utils.py
@@ -21,6 +21,8 @@ def _read_yaml_config(config_file: str) -> Dict[str, Any]:
with open(config_file) as config:
yaml_string = config.read()
config_dict = yaml.safe_load(yaml_string)
+ if config_dict is None:
+ config_dict = {}
if "defaults" in config_dict.keys():
del config_dict["defaults"]
return config_dict
diff --git a/docs/source/concepts/components.rst b/docs/source/concepts/components.rst
index 95be5e70..60698590 100644
--- a/docs/source/concepts/components.rst
+++ b/docs/source/concepts/components.rst
@@ -73,7 +73,7 @@ Environments
Tasks are scenarios from a specific environment which constitute the MARL
challenge to solve.
-They differ based on many aspects, here is a table with the current environments in BenchMARL
+They differ based on many aspects, here is a table with the current environments in BenchMARL:
.. _environment-table:
@@ -89,6 +89,9 @@ They differ based on many aspects, here is a table with the current environments
+-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+
| :class:`~benchmarl.environments.PettingZooTask` | 10 | Cooperative + Competitive | Yes + No | Shared + Independent | Continuous + Discrete | No |
+-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+
+ | :class:`~benchmarl.environments.MeltingPotTask` | 49 | Cooperative + Competitive | Yes | Independent | Discrete | No |
+ +-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+
+
Models
diff --git a/docs/source/modules/environments.rst b/docs/source/modules/environments.rst
index 78437d0a..020e6ae3 100644
--- a/docs/source/modules/environments.rst
+++ b/docs/source/modules/environments.rst
@@ -44,7 +44,7 @@ PettingZoo
SMACv2
-----------
+------
.. autosummary::
:nosignatures:
@@ -52,3 +52,13 @@ SMACv2
:template: autosummary/class.rst
Smacv2Task
+
+MeltingPot
+----------
+
+.. autosummary::
+ :nosignatures:
+ :toctree: ../generated
+ :template: autosummary/class.rst
+
+ MeltingPotTask
diff --git a/docs/source/usage/installation.rst b/docs/source/usage/installation.rst
index 135c599a..19c55a27 100644
--- a/docs/source/usage/installation.rst
+++ b/docs/source/usage/installation.rst
@@ -52,6 +52,15 @@ PettingZoo
pip install "pettingzoo[all]"
+MeltingPot
+^^^^^^^^^^
+:github:`null` `GitHub `__
+
+
+.. code-block:: console
+
+ pip install "dm-meltingpot"
+
SMACv2
^^^^^^
diff --git a/test/conftest.py b/test/conftest.py
index ee40e177..d2a9f186 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -8,7 +8,7 @@
import torch_geometric.nn.conv
from benchmarl.experiment import ExperimentConfig
-from benchmarl.models import GnnConfig, MlpConfig
+from benchmarl.models import CnnConfig, GnnConfig, MlpConfig
from benchmarl.models.common import ModelConfig, SequenceModelConfig
from torch import nn
@@ -54,6 +54,26 @@ def mlp_sequence_config() -> ModelConfig:
)
+@pytest.fixture
+def cnn_sequence_config() -> ModelConfig:
+ return SequenceModelConfig(
+ model_configs=[
+ CnnConfig(
+ cnn_num_cells=[4, 3],
+ cnn_kernel_sizes=[3, 2],
+ cnn_strides=1,
+ cnn_paddings=0,
+ cnn_activation_class=nn.Tanh,
+ mlp_num_cells=[4],
+ mlp_activation_class=nn.Tanh,
+ mlp_layer_class=nn.Linear,
+ ),
+ MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear),
+ ],
+ intermediate_sizes=[5],
+ )
+
+
@pytest.fixture
def mlp_gnn_sequence_config() -> ModelConfig:
return SequenceModelConfig(
diff --git a/test/test_meltingpot.py b/test/test_meltingpot.py
new file mode 100644
index 00000000..f86a8b99
--- /dev/null
+++ b/test/test_meltingpot.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+
+import packaging
+import pytest
+import torchrl
+
+from benchmarl.algorithms import (
+ algorithm_config_registry,
+ IppoConfig,
+ MasacConfig,
+ QmixConfig,
+)
+from benchmarl.algorithms.common import AlgorithmConfig
+from benchmarl.environments import MeltingPotTask, Task
+from benchmarl.experiment import Experiment
+
+from utils import _has_meltingpot
+from utils_experiment import ExperimentUtils
+
+
+def _get_unique_envs(names):
+ prefixes = set()
+ result = []
+ for env in names:
+ prefix = env.name.split("_")[0]
+ if prefix not in prefixes:
+ prefixes.add(prefix)
+ result.append(env)
+ return result
+
+
+@pytest.mark.skipif(not _has_meltingpot, reason="Meltingpot not found")
+@pytest.mark.skipif(
+ packaging.version.parse(torchrl.__version__).base_version <= "0.3.1",
+ reason="TorchRL <= 0.3.1 does nto support meltingpot",
+)
+class TestMeltingPot:
+ @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
+ @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN])
+ def test_all_algos(
+ self,
+ algo_config: AlgorithmConfig,
+ task: Task,
+ experiment_config,
+ cnn_sequence_config,
+ ):
+ # To not run unsupported algo-task pairs
+ if not algo_config.supports_discrete_actions():
+ pytest.skip()
+
+ task = task.get_from_yaml()
+ experiment_config.checkpoint_interval = 0
+ experiment = Experiment(
+ algorithm_config=algo_config.get_from_yaml(),
+ model_config=cnn_sequence_config,
+ seed=0,
+ config=experiment_config,
+ task=task,
+ )
+ experiment.run()
+
+ @pytest.mark.parametrize("algo_config", [MasacConfig])
+ @pytest.mark.parametrize("task", _get_unique_envs(list(MeltingPotTask)))
+ def test_all_tasks(
+ self,
+ algo_config: AlgorithmConfig,
+ task: Task,
+ experiment_config,
+ cnn_sequence_config,
+ ):
+ task = task.get_from_yaml()
+ experiment_config.checkpoint_interval = 0
+ experiment = Experiment(
+ algorithm_config=algo_config.get_from_yaml(),
+ model_config=cnn_sequence_config,
+ seed=0,
+ config=experiment_config,
+ task=task,
+ )
+ experiment.run()
+
+ @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
+ @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN])
+ def test_reloading_trainer(
+ self,
+ algo_config: AlgorithmConfig,
+ task: Task,
+ experiment_config,
+ cnn_sequence_config,
+ ):
+ # To not run unsupported algo-task pairs
+ if not algo_config.supports_discrete_actions():
+ pytest.skip()
+
+ algo_config = algo_config.get_from_yaml()
+
+ ExperimentUtils.check_experiment_loading(
+ algo_config=algo_config,
+ model_config=cnn_sequence_config,
+ experiment_config=experiment_config,
+ task=task.get_from_yaml(),
+ )
+
+ @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
+ @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN])
+ @pytest.mark.parametrize("share_params", [True, False])
+ def test_share_policy_params(
+ self,
+ algo_config: AlgorithmConfig,
+ task: Task,
+ share_params,
+ experiment_config,
+ cnn_sequence_config,
+ ):
+ experiment_config.share_policy_params = share_params
+ task = task.get_from_yaml()
+ experiment_config.checkpoint_interval = 0
+ experiment = Experiment(
+ algorithm_config=algo_config.get_from_yaml(),
+ model_config=cnn_sequence_config,
+ seed=0,
+ config=experiment_config,
+ task=task,
+ )
+ experiment.run()
diff --git a/test/test_models.py b/test/test_models.py
index b88435eb..4cff2d59 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -3,11 +3,13 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
-import importlib
+
from typing import List
+import packaging
import pytest
import torch
+import torchrl
from benchmarl.hydra_config import load_model_config_from_hydra
from benchmarl.models import model_config_registry
@@ -77,7 +79,10 @@ def test_models_forward_shape(
pytest.skip() # this combination should never happen
if ("gnn" in model_name) and centralised:
pytest.skip("gnn model is always decentralized")
- if importlib.metadata.version("torchrl") <= "0.3.1" and "cnn" in model_name:
+ if (
+ packaging.version.parse(torchrl.__version__).base_version <= "0.3.1"
+ and "cnn" in model_name
+ ):
pytest.skip("TorchRL <= 0.3.1 does not support MultiAgentCNN")
torch.manual_seed(0)
diff --git a/test/test_pettingzoo.py b/test/test_pettingzoo.py
index e4935e49..d8799bb5 100644
--- a/test/test_pettingzoo.py
+++ b/test/test_pettingzoo.py
@@ -14,7 +14,6 @@
MaddpgConfig,
MasacConfig,
QmixConfig,
- VdnConfig,
)
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import PettingZooTask, Task
@@ -118,9 +117,7 @@ def test_reloading_trainer(
):
experiment_config.prefer_continuous_actions = prefer_continuous
algo_config = algo_config.get_from_yaml()
- if isinstance(algo_config, VdnConfig):
- # There are some bugs currently in TorchRL https://github.com/pytorch/rl/issues/1593
- pytest.skip()
+
ExperimentUtils.check_experiment_loading(
algo_config=algo_config,
model_config=mlp_sequence_config,
diff --git a/test/test_vmas.py b/test/test_vmas.py
index cf4e4d52..3dfc4f81 100644
--- a/test/test_vmas.py
+++ b/test/test_vmas.py
@@ -15,7 +15,6 @@
MaddpgConfig,
MasacConfig,
QmixConfig,
- VdnConfig,
)
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, VmasTask
@@ -107,9 +106,7 @@ def test_reloading_trainer(
mlp_sequence_config,
):
algo_config = algo_config.get_from_yaml()
- if isinstance(algo_config, VdnConfig):
- # There are some bugs currently in TorchRL https://github.com/pytorch/rl/issues/1593
- pytest.skip()
+
ExperimentUtils.check_experiment_loading(
algo_config=algo_config,
model_config=mlp_sequence_config,
diff --git a/test/utils.py b/test/utils.py
index 52214c46..848b149b 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -9,3 +9,4 @@
_has_vmas = importlib.util.find_spec("vmas") is not None
_has_smacv2 = importlib.util.find_spec("smacv2") is not None
_has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None
+_has_meltingpot = importlib.util.find_spec("meltingpot") is not None