diff --git a/.github/unittest/install_meltingpot.sh b/.github/unittest/install_meltingpot.sh new file mode 100644 index 00000000..6842d3b0 --- /dev/null +++ b/.github/unittest/install_meltingpot.sh @@ -0,0 +1,2 @@ + +pip install dm-meltingpot diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ff470a37..05edc0c8 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/meltingpot_tests.yml b/.github/workflows/meltingpot_tests.yml new file mode 100644 index 00000000..77e21a67 --- /dev/null +++ b/.github/workflows/meltingpot_tests.yml @@ -0,0 +1,43 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: +# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + + +name: meltingpot_tests + +on: + push: + branches: [ $default-branch , "main" ] + pull_request: + branches: [ $default-branch , "main" ] + +permissions: + contents: read + +jobs: + tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + bash .github/unittest/install_dependencies_nightly.sh + - name: Install meltingpot + run: | + bash .github/unittest/install_meltingpot.sh + - name: Test with pytest + run: | + pytest test/test_meltingpot.py --doctest-modules --junitxml=junit/test-results.xml --cov=. 
--cov-report=xml --cov-report=html + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + fail_ci_if_error: false diff --git a/.github/workflows/pettingzoo_tests.yml b/.github/workflows/pettingzoo_tests.yml index 3fbd0125..66d5e81c 100644 --- a/.github/workflows/pettingzoo_tests.yml +++ b/.github/workflows/pettingzoo_tests.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v3 @@ -37,9 +37,7 @@ jobs: - name: Test with pytest run: | xvfb-run -s "-screen 0 1024x768x24" pytest test/test_pettingzoo.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html - - - if: matrix.python-version == '3.10' - name: Upload coverage to Codecov + - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: fail_ci_if_error: false diff --git a/.github/workflows/smacv2_tests.yml b/.github/workflows/smacv2_tests.yml index 20761a69..7023afd9 100644 --- a/.github/workflows/smacv2_tests.yml +++ b/.github/workflows/smacv2_tests.yml @@ -21,7 +21,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v3 @@ -44,8 +44,7 @@ jobs: pytest test/test_smacv2.py --doctest-modules --junitxml=junit/test-results.xml --cov=. 
--cov-report=xml --cov-report=html - - if: matrix.python-version == '3.10' - name: Upload coverage to Codecov + - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: fail_ci_if_error: false diff --git a/.github/workflows/torchrl_stable_tests.yml b/.github/workflows/torchrl_stable_tests.yml index 413c10d7..3ca7e448 100644 --- a/.github/workflows/torchrl_stable_tests.yml +++ b/.github/workflows/torchrl_stable_tests.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -38,4 +38,4 @@ jobs: bash .github/unittest/install_pettingzoo.sh - name: Tests run: | - xvfb-run -s "-screen 0 1024x768x24" pytest test/test_algorithm.py test/test_models.py test/test_task.py test/test_vmas.py test/test_pettingzoo.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html + xvfb-run -s "-screen 0 1024x768x24" pytest test/test_algorithm.py test/test_models.py test/test_task.py test/test_vmas.py test/test_pettingzoo.py test/test_meltingpot.py --doctest-modules --junitxml=junit/test-results.xml --cov=. 
--cov-report=xml --cov-report=html diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index a1736419..eb45153b 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10","3.11"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/vmas_tests.yml b/.github/workflows/vmas_tests.yml index c42e8bd7..c6fe5e91 100644 --- a/.github/workflows/vmas_tests.yml +++ b/.github/workflows/vmas_tests.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v3 @@ -38,8 +38,7 @@ jobs: run: | xvfb-run -s "-screen 0 1024x768x24" pytest test/test_vmas.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html - - if: matrix.python-version == '3.10' - name: Upload coverage to Codecov + - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: fail_ci_if_error: false diff --git a/README.md b/README.md index 9e421edb..e63cba6d 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![tests](https://github.com/facebookresearch/BenchMARL/actions/workflows/unit_tests.yml/badge.svg)](test) [![codecov](https://codecov.io/github/facebookresearch/BenchMARL/coverage.svg?branch=main)](https://codecov.io/gh/facebookresearch/BenchMARL) [![Documentation Status](https://readthedocs.org/projects/benchmarl/badge/?version=latest)](https://benchmarl.readthedocs.io/en/latest/?badge=latest) -[![Python](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://www.python.org/downloads/) +[![Python](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue.svg)](https://www.python.org/downloads/) pypi version 
[![Downloads](https://static.pepy.tech/personalized-badge/benchmarl?period=total&units=international_system&left_color=grey&right_color=blue&left_text=Downloads)](https://pepy.tech/project/benchmarl) [![Discord Shield](https://dcbadge.vercel.app/api/server/jEEWCn6T3p?style=flat)](https://discord.gg/jEEWCn6T3p) @@ -113,6 +113,10 @@ pip install vmas pip install "pettingzoo[all]" ``` +##### MeltingPot +```bash +pip install dm-meltingpot +``` ##### SMACv2 Follow the instructions on the environment [repository](https://github.com/oxwhirl/smacv2). @@ -236,12 +240,14 @@ determine the training strategy. Here is a table with the currently implemented challenge to solve. They differ based on many aspects, here is a table with the current environments in BenchMARL -| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized | -|--------------------------------------------------------------------|-------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:| -| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [18](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes | -| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No | -| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No | -| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No | +| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized | 
+|--------------------------------------------------------------------|--------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:| +| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [18](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes | +| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No | +| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No | +| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No | +| [MeltingPot](https://github.com/google-deepmind/meltingpot) | [49](benchmarl/conf/task/meltingpot) | Cooperative + Competitive | Yes | Independent | Discrete | No | + > [!NOTE] > BenchMARL uses the [TorchRL MARL API](https://github.com/pytorch/rl/issues/1463) for grouping agents. 
diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index 4155197d..e9b885b1 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -7,7 +7,7 @@ import pathlib from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Iterable, Optional, Tuple, Type +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type from tensordict import TensorDictBase from tensordict.nn import TensorDictModule, TensorDictSequential @@ -19,6 +19,7 @@ TensorDictReplayBuffer, ) from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement +from torchrl.envs import Compose, Transform from torchrl.objectives import LossModule from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater @@ -132,8 +133,7 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater return self._losses_and_updaters[group] def get_replay_buffer( - self, - group: str, + self, group: str, transforms: List[Transform] = None ) -> ReplayBuffer: """ Get the ReplayBuffer for a specific group. 
@@ -141,6 +141,7 @@ def get_replay_buffer( Args: group (str): agent group of the loss and updater + transforms (optional, list of Transform): Transforms to apply to the replay buffer ``.sample()`` call Returns: ReplayBuffer the group """ @@ -154,6 +155,7 @@ def get_replay_buffer( sampler=sampler, batch_size=sampling_size, priority_key=(group, "td_error"), + transform=Compose(*transforms) if transforms is not None else None, ) def get_policy_for_loss(self, group: str) -> TensorDictModule: diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index 712291f6..2a357d73 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -82,7 +82,7 @@ evaluation_episodes: 10 evaluation_deterministic_actions: True # List of loggers to use, options are: wandb, csv, tensorboard, mflow -loggers: [wandb] +loggers: [] # Create a json folder as part of the output in the format of marl-eval create_json: True diff --git a/benchmarl/conf/task/meltingpot/allelopathic_harvest__open.yaml b/benchmarl/conf/task/meltingpot/allelopathic_harvest__open.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__arena.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/bach_or_stravinsky_in_the_matrix__repeated.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/boat_race__eight_races.yaml b/benchmarl/conf/task/meltingpot/boat_race__eight_races.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles.yaml b/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles.yaml new file mode 100644 index 
00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles_with_plentiful_distractors.yaml b/benchmarl/conf/task/meltingpot/chemistry__three_metabolic_cycles_with_plentiful_distractors.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles.yaml b/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles_with_distractors.yaml b/benchmarl/conf/task/meltingpot/chemistry__two_metabolic_cycles_with_distractors.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__arena.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/chicken_in_the_matrix__repeated.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/clean_up.yaml b/benchmarl/conf/task/meltingpot/clean_up.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/coins.yaml b/benchmarl/conf/task/meltingpot/coins.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__asymmetric.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__asymmetric.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__circuit.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__circuit.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__cramped.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__cramped.yaml new file mode 100644 index 00000000..e69de29b diff --git 
a/benchmarl/conf/task/meltingpot/collaborative_cooking__crowded.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__crowded.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__figure_eight.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__figure_eight.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__forced.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__forced.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/collaborative_cooking__ring.yaml b/benchmarl/conf/task/meltingpot/collaborative_cooking__ring.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/commons_harvest__closed.yaml b/benchmarl/conf/task/meltingpot/commons_harvest__closed.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/commons_harvest__open.yaml b/benchmarl/conf/task/meltingpot/commons_harvest__open.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/commons_harvest__partnership.yaml b/benchmarl/conf/task/meltingpot/commons_harvest__partnership.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/coop_mining.yaml b/benchmarl/conf/task/meltingpot/coop_mining.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/daycare.yaml b/benchmarl/conf/task/meltingpot/daycare.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/externality_mushrooms__dense.yaml b/benchmarl/conf/task/meltingpot/externality_mushrooms__dense.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/factory_commons__either_or.yaml b/benchmarl/conf/task/meltingpot/factory_commons__either_or.yaml new file mode 100644 index 
00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/fruit_market__concentric_rivers.yaml b/benchmarl/conf/task/meltingpot/fruit_market__concentric_rivers.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/gift_refinements.yaml b/benchmarl/conf/task/meltingpot/gift_refinements.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/hidden_agenda.yaml b/benchmarl/conf/task/meltingpot/hidden_agenda.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/paintball__capture_the_flag.yaml b/benchmarl/conf/task/meltingpot/paintball__capture_the_flag.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/paintball__king_of_the_hill.yaml b/benchmarl/conf/task/meltingpot/paintball__king_of_the_hill.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/predator_prey__alley_hunt.yaml b/benchmarl/conf/task/meltingpot/predator_prey__alley_hunt.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/predator_prey__open.yaml b/benchmarl/conf/task/meltingpot/predator_prey__open.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/predator_prey__orchard.yaml b/benchmarl/conf/task/meltingpot/predator_prey__orchard.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/predator_prey__random_forest.yaml b/benchmarl/conf/task/meltingpot/predator_prey__random_forest.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__arena.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__repeated.yaml 
b/benchmarl/conf/task/meltingpot/prisoners_dilemma_in_the_matrix__repeated.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__arena.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/pure_coordination_in_the_matrix__repeated.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__arena.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/rationalizable_coordination_in_the_matrix__repeated.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__arena.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__one_shot.yaml b/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__one_shot.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__repeated.yaml b/benchmarl/conf/task/meltingpot/running_with_scissors_in_the_matrix__repeated.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__arena.yaml b/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__arena.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__repeated.yaml 
b/benchmarl/conf/task/meltingpot/stag_hunt_in_the_matrix__repeated.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/territory__inside_out.yaml b/benchmarl/conf/task/meltingpot/territory__inside_out.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/territory__open.yaml b/benchmarl/conf/task/meltingpot/territory__open.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/conf/task/meltingpot/territory__rooms.yaml b/benchmarl/conf/task/meltingpot/territory__rooms.yaml new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/environments/__init__.py b/benchmarl/environments/__init__.py index 874a6b74..c6359305 100644 --- a/benchmarl/environments/__init__.py +++ b/benchmarl/environments/__init__.py @@ -5,6 +5,7 @@ # from .common import Task +from .meltingpot.common import MeltingPotTask from .pettingzoo.common import PettingZooTask from .smacv2.common import Smacv2Task from .vmas.common import VmasTask @@ -12,7 +13,7 @@ # This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum # It is used by automatically load task enums from yaml files task_config_registry = {} -for env in [VmasTask, Smacv2Task, PettingZooTask]: +for env in [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask]: env_config_registry = { f"{env.env_name()}/{task.name.lower()}": task for task in env } @@ -31,6 +32,7 @@ from .pettingzoo.simple_tag import TaskConfig as SimpleTagConfig from .pettingzoo.simple_world_comm import TaskConfig as SimpleWorldComm from .pettingzoo.waterworld import TaskConfig as WaterworldConfig + from .vmas.balance import TaskConfig as BalanceConfig from .vmas.dispersion import TaskConfig as DispersionConfig from .vmas.dropout import TaskConfig as DropoutConfig @@ -38,7 +40,6 @@ from .vmas.navigation import TaskConfig as NavigationConfig from .vmas.reverse_transport import TaskConfig as ReverseTransportConfig from .vmas.sampling import 
TaskConfig as SamplingConfig - from .vmas.simple_adverasary import TaskConfig as VmasSimpleAdversaryConfig from .vmas.simple_crypto import TaskConfig as VmasSimpleCryptoConfig from .vmas.simple_push import TaskConfig as VmasSimplePushConfig diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py index 32faa7dc..377f9ded 100644 --- a/benchmarl/environments/common.py +++ b/benchmarl/environments/common.py @@ -96,13 +96,13 @@ def get_env_fun( num_envs (int): The number of envs that should be in the batch_size of the returned env. In vectorized envs, this can be used to set the number of batched environments. If your environment is not vectorized, you can just ignore this, and it will be - wrapped in a torchrl.envs.SerialEnv with num_envs automatically. + wrapped in a :class:`torchrl.envs.SerialEnv` with num_envs automatically. continuous_actions (bool): Whether your environment should have continuous or discrete actions. If your environment does not support both, ignore this and refer to the supports_x_actions methods. seed (optional, int): The seed of your env device (str): the device of your env, you can pass this to any torchrl env constructor - Returns: a function that takes no arguments and returns a torchrl.envs.EnvBase object + Returns: a function that takes no arguments and returns a :class:`torchrl.envs.EnvBase` object """ raise NotImplementedError @@ -242,6 +242,28 @@ def get_reward_sum_transform(self, env: EnvBase) -> Transform: reset_keys = env.reset_keys return RewardSum(reset_keys=reset_keys) + def get_env_transforms(self, env: EnvBase) -> List[Transform]: + """ + Returns a list of :class:`torchrl.envs.Transform` to be applied to the env. + + Args: + env (EnvBase): An environment created via self.get_env_fun + + + """ + return [] + + def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]: + """ + Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`. 
+ + Args: + env (EnvBase): An environment created via self.get_env_fun + + + """ + return [] + @staticmethod def render_callback(experiment, env: EnvBase, data: TensorDictBase): try: diff --git a/benchmarl/environments/meltingpot/__init__.py b/benchmarl/environments/meltingpot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py new file mode 100644 index 00000000..f209d8b1 --- /dev/null +++ b/benchmarl/environments/meltingpot/common.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from typing import Callable, Dict, List, Optional + +import torch +from tensordict import TensorDictBase + +from torchrl.data import CompositeSpec +from torchrl.envs import DoubleToFloat, DTypeCastTransform, EnvBase, Transform + +from benchmarl.environments.common import Task +from benchmarl.utils import DEVICE_TYPING + + +class MeltingPotTask(Task): + """Enum for meltingpot tasks.""" + + PREDATOR_PREY__ALLEY_HUNT = None + CLEAN_UP = None + COLLABORATIVE_COOKING__CIRCUIT = None + FRUIT_MARKET__CONCENTRIC_RIVERS = None + COLLABORATIVE_COOKING__FIGURE_EIGHT = None + PAINTBALL__KING_OF_THE_HILL = None + FACTORY_COMMONS__EITHER_OR = None + PURE_COORDINATION_IN_THE_MATRIX__ARENA = None + RUNNING_WITH_SCISSORS_IN_THE_MATRIX__REPEATED = None + COLLABORATIVE_COOKING__CRAMPED = None + RUNNING_WITH_SCISSORS_IN_THE_MATRIX__ARENA = None + PRISONERS_DILEMMA_IN_THE_MATRIX__REPEATED = None + TERRITORY__OPEN = None + STAG_HUNT_IN_THE_MATRIX__REPEATED = None + CHICKEN_IN_THE_MATRIX__REPEATED = None + GIFT_REFINEMENTS = None + PURE_COORDINATION_IN_THE_MATRIX__REPEATED = None + COLLABORATIVE_COOKING__FORCED = None + RATIONALIZABLE_COORDINATION_IN_THE_MATRIX__ARENA = None + BACH_OR_STRAVINSKY_IN_THE_MATRIX__ARENA = None + 
CHEMISTRY__TWO_METABOLIC_CYCLES_WITH_DISTRACTORS = None + COMMONS_HARVEST__PARTNERSHIP = None + PREDATOR_PREY__OPEN = None + TERRITORY__ROOMS = None + HIDDEN_AGENDA = None + COOP_MINING = None + DAYCARE = None + PRISONERS_DILEMMA_IN_THE_MATRIX__ARENA = None + TERRITORY__INSIDE_OUT = None + BACH_OR_STRAVINSKY_IN_THE_MATRIX__REPEATED = None + COMMONS_HARVEST__CLOSED = None + CHEMISTRY__THREE_METABOLIC_CYCLES_WITH_PLENTIFUL_DISTRACTORS = None + STAG_HUNT_IN_THE_MATRIX__ARENA = None + PAINTBALL__CAPTURE_THE_FLAG = None + COLLABORATIVE_COOKING__CROWDED = None + ALLELOPATHIC_HARVEST__OPEN = None + COLLABORATIVE_COOKING__RING = None + COMMONS_HARVEST__OPEN = None + COINS = None + PREDATOR_PREY__ORCHARD = None + PREDATOR_PREY__RANDOM_FOREST = None + COLLABORATIVE_COOKING__ASYMMETRIC = None + RATIONALIZABLE_COORDINATION_IN_THE_MATRIX__REPEATED = None + CHEMISTRY__THREE_METABOLIC_CYCLES = None + RUNNING_WITH_SCISSORS_IN_THE_MATRIX__ONE_SHOT = None + CHEMISTRY__TWO_METABOLIC_CYCLES = None + CHICKEN_IN_THE_MATRIX__ARENA = None + BOAT_RACE__EIGHT_RACES = None + EXTERNALITY_MUSHROOMS__DENSE = None + + def get_env_fun( + self, + num_envs: int, + continuous_actions: bool, + seed: Optional[int], + device: DEVICE_TYPING, + ) -> Callable[[], EnvBase]: + from torchrl.envs.libs.meltingpot import MeltingpotEnv + + return lambda: MeltingpotEnv( + substrate=self.name.lower(), + categorical_actions=True, + **self.config, + ) + + def supports_continuous_actions(self) -> bool: + return False + + def supports_discrete_actions(self) -> bool: + return True + + def has_render(self, env: EnvBase) -> bool: + return True + + def max_steps(self, env: EnvBase) -> int: + return self.config.get("max_steps", 100) + + def group_map(self, env: EnvBase) -> Dict[str, List[str]]: + return env.group_map + + def get_env_transforms(self, env: EnvBase) -> List[Transform]: + return [DoubleToFloat()] + + def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]: + return [ + DTypeCastTransform( + 
dtype_in=torch.uint8, + dtype_out=torch.float, + in_keys=[ + "RGB", + *[ + (group, "observation", "RGB") + for group in self.group_map(env).keys() + ], + ("next", "RGB"), + *[ + ("next", group, "observation", "RGB") + for group in self.group_map(env).keys() + ], + ], + in_keys_inv=[], + ) + ] + + def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]: + observation_spec = env.observation_spec.clone() + for group in self.group_map(env): + del observation_spec[group] + if list(observation_spec.keys()) != ["RGB"]: + raise ValueError( + f"More than one global state key found in observation spec {observation_spec}." + ) + return observation_spec + + def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]: + return None + + def observation_spec(self, env: EnvBase) -> CompositeSpec: + observation_spec = env.observation_spec.clone() + for group_key in list(observation_spec.keys()): + if group_key not in self.group_map(env).keys(): + del observation_spec[group_key] + else: + group_obs_spec = observation_spec[group_key]["observation"] + for key in list(group_obs_spec.keys()): + if key != "RGB": + del group_obs_spec[key] + return observation_spec + + def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]: + observation_spec = env.observation_spec.clone() + for group_key in list(observation_spec.keys()): + if group_key not in self.group_map(env).keys(): + del observation_spec[group_key] + else: + group_obs_spec = observation_spec[group_key]["observation"] + del group_obs_spec["RGB"] + return observation_spec + + def action_spec(self, env: EnvBase) -> CompositeSpec: + return env.full_action_spec + + @staticmethod + def env_name() -> str: + return "meltingpot" + + @staticmethod + def render_callback(experiment, env: EnvBase, data: TensorDictBase): + return data.get("RGB") diff --git a/benchmarl/environments/pettingzoo/common.py b/benchmarl/environments/pettingzoo/common.py index 5b27754f..fbaadf59 100644 --- a/benchmarl/environments/pettingzoo/common.py 
+++ b/benchmarl/environments/pettingzoo/common.py @@ -122,7 +122,6 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]: del observation_spec["state"] if observation_spec.is_empty(): return None - return observation_spec def observation_spec(self, env: EnvBase) -> CompositeSpec: @@ -132,7 +131,8 @@ def observation_spec(self, env: EnvBase) -> CompositeSpec: for key in list(group_obs_spec.keys()): if key != "observation": del group_obs_spec[key] - + if "state" in observation_spec.keys(): + del observation_spec["state"] return observation_spec def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]: @@ -142,10 +142,12 @@ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]: for key in list(group_obs_spec.keys()): if key != "info": del group_obs_spec[key] + if "state" in observation_spec.keys(): + del observation_spec["state"] return observation_spec def action_spec(self, env: EnvBase) -> CompositeSpec: - return env.input_spec["full_action_spec"] + return env.full_action_spec @staticmethod def env_name() -> str: diff --git a/benchmarl/environments/smacv2/common.py b/benchmarl/environments/smacv2/common.py index 47043396..e968c31f 100644 --- a/benchmarl/environments/smacv2/common.py +++ b/benchmarl/environments/smacv2/common.py @@ -88,7 +88,7 @@ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]: return observation_spec def action_spec(self, env: EnvBase) -> CompositeSpec: - return env.input_spec["full_action_spec"] + return env.full_action_spec @staticmethod def log_info(batch: TensorDictBase) -> Dict[str, float]: diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index a839fdbb..f7a75a99 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -387,7 +387,6 @@ def _setup_task(self): device=self.config.sampling_device, ) ) - self.observation_spec = self.task.observation_spec(test_env) self.info_spec = self.task.info_spec(test_env) self.state_spec = 
self.task.state_spec(test_env) @@ -397,24 +396,34 @@ def _setup_task(self): self.train_group_map = copy.deepcopy(self.group_map) self.max_steps = self.task.max_steps(test_env) - transforms = [self.task.get_reward_sum_transform(test_env)] - transform = Compose(*transforms) + transforms_env = self.task.get_env_transforms(test_env) + transforms_training = transforms_env + [ + self.task.get_reward_sum_transform(test_env) + ] + + transforms_env = Compose(*transforms_env) + transforms_training = Compose(*transforms_training) if test_env.batch_size == (): self.env_func = lambda: TransformedEnv( SerialEnv(self.config.n_envs_per_worker(self.on_policy), env_func), - transform.clone(), + transforms_training.clone(), ) else: - self.env_func = lambda: TransformedEnv(env_func(), transform.clone()) + self.env_func = lambda: TransformedEnv( + env_func(), transforms_training.clone() + ) - self.test_env = test_env.to(self.config.sampling_device) + self.test_env = TransformedEnv(test_env, transforms_env.clone()).to( + self.config.sampling_device + ) def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm(experiment=self) self.replay_buffers = { group: self.algorithm.get_replay_buffer( group=group, + transforms=self.task.get_replay_buffer_transforms(self.test_env), ) for group in self.group_map.keys() } diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py index afce12c5..efe77003 100644 --- a/benchmarl/models/cnn.py +++ b/benchmarl/models/cnn.py @@ -229,7 +229,8 @@ def _perform_checks(self): raise ValueError( f"CNN input value {input_key} from {self.input_spec} has an invalid shape" ) - + if not len(self.image_in_keys): + raise ValueError("CNN found no image inputs, maybe use an MLP?") if self.input_has_agent_dim and input_shape_image[-3] != self.n_agents: raise ValueError( "If the CNN input has the agent dimension," @@ -257,7 +258,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather images input = torch.cat( 
[tensordict.get(in_key) for in_key in self.image_in_keys], dim=-1 - ) + ).to(torch.float) # BenchMARL images are X,Y,C -> we convert them to C, X, Y for processing in TorchRL models input = input.transpose(-3, -1).transpose(-2, -1) @@ -316,9 +317,9 @@ class CnnConfig(ModelConfig): """Dataclass config for a :class:`~benchmarl.models.Cnn`.""" cnn_num_cells: Sequence[int] = MISSING - cnn_kernel_sizes: Sequence[int] = MISSING - cnn_strides: Sequence[int] = MISSING - cnn_paddings: Sequence[int] = MISSING + cnn_kernel_sizes: Union[Sequence[int], int] = MISSING + cnn_strides: Union[Sequence[int], int] = MISSING + cnn_paddings: Union[Sequence[int], int] = MISSING cnn_activation_class: Type[nn.Module] = MISSING mlp_num_cells: Sequence[int] = MISSING diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py index 9f98c37b..ea810b00 100644 --- a/benchmarl/models/mlp.py +++ b/benchmarl/models/mlp.py @@ -93,7 +93,7 @@ def _perform_checks(self): ) else: raise ValueError( - f"MLP input value {input_key} from {self.input_spec} has an invalid shape" + f"MLP input value {input_key} from {self.input_spec} has an invalid shape, maybe you need a CNN?" ) if self.input_has_agent_dim: if input_shape[-1] != self.n_agents: diff --git a/benchmarl/utils.py b/benchmarl/utils.py index 5d83ecc6..d2d63ae6 100644 --- a/benchmarl/utils.py +++ b/benchmarl/utils.py @@ -21,6 +21,8 @@ def _read_yaml_config(config_file: str) -> Dict[str, Any]: with open(config_file) as config: yaml_string = config.read() config_dict = yaml.safe_load(yaml_string) + if config_dict is None: + config_dict = {} if "defaults" in config_dict.keys(): del config_dict["defaults"] return config_dict diff --git a/docs/source/concepts/components.rst b/docs/source/concepts/components.rst index 95be5e70..60698590 100644 --- a/docs/source/concepts/components.rst +++ b/docs/source/concepts/components.rst @@ -73,7 +73,7 @@ Environments Tasks are scenarios from a specific environment which constitute the MARL challenge to solve. 
-They differ based on many aspects, here is a table with the current environments in BenchMARL +They differ based on many aspects, here is a table with the current environments in BenchMARL: .. _environment-table: @@ -89,6 +89,9 @@ They differ based on many aspects, here is a table with the current environments +-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+ | :class:`~benchmarl.environments.PettingZooTask` | 10 | Cooperative + Competitive | Yes + No | Shared + Independent | Continuous + Discrete | No | +-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+ + | :class:`~benchmarl.environments.MeltingPotTask` | 49 | Cooperative + Competitive | Yes | Independent | Discrete | No | + +-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+ + Models diff --git a/docs/source/modules/environments.rst b/docs/source/modules/environments.rst index 78437d0a..020e6ae3 100644 --- a/docs/source/modules/environments.rst +++ b/docs/source/modules/environments.rst @@ -44,7 +44,7 @@ PettingZoo SMACv2 ----------- +------ .. autosummary:: :nosignatures: @@ -52,3 +52,13 @@ SMACv2 :template: autosummary/class.rst Smacv2Task + +MeltingPot +---------- + +.. autosummary:: + :nosignatures: + :toctree: ../generated + :template: autosummary/class.rst + + MeltingPotTask diff --git a/docs/source/usage/installation.rst b/docs/source/usage/installation.rst index 135c599a..19c55a27 100644 --- a/docs/source/usage/installation.rst +++ b/docs/source/usage/installation.rst @@ -52,6 +52,15 @@ PettingZoo pip install "pettingzoo[all]" +MeltingPot +^^^^^^^^^^ +:github:`null` `GitHub <https://github.com/google-deepmind/meltingpot>`__ + + +.. 
code-block:: console + + pip install "dm-meltingpot" + SMACv2 ^^^^^^ diff --git a/test/conftest.py b/test/conftest.py index ee40e177..d2a9f186 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,7 +8,7 @@ import torch_geometric.nn.conv from benchmarl.experiment import ExperimentConfig -from benchmarl.models import GnnConfig, MlpConfig +from benchmarl.models import CnnConfig, GnnConfig, MlpConfig from benchmarl.models.common import ModelConfig, SequenceModelConfig from torch import nn @@ -54,6 +54,26 @@ def mlp_sequence_config() -> ModelConfig: ) +@pytest.fixture +def cnn_sequence_config() -> ModelConfig: + return SequenceModelConfig( + model_configs=[ + CnnConfig( + cnn_num_cells=[4, 3], + cnn_kernel_sizes=[3, 2], + cnn_strides=1, + cnn_paddings=0, + cnn_activation_class=nn.Tanh, + mlp_num_cells=[4], + mlp_activation_class=nn.Tanh, + mlp_layer_class=nn.Linear, + ), + MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear), + ], + intermediate_sizes=[5], + ) + + @pytest.fixture def mlp_gnn_sequence_config() -> ModelConfig: return SequenceModelConfig( diff --git a/test/test_meltingpot.py b/test/test_meltingpot.py new file mode 100644 index 00000000..f86a8b99 --- /dev/null +++ b/test/test_meltingpot.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + + +import packaging.version +import pytest +import torchrl + +from benchmarl.algorithms import ( + algorithm_config_registry, + IppoConfig, + MasacConfig, + QmixConfig, +) +from benchmarl.algorithms.common import AlgorithmConfig +from benchmarl.environments import MeltingPotTask, Task +from benchmarl.experiment import Experiment + +from utils import _has_meltingpot +from utils_experiment import ExperimentUtils + + +def _get_unique_envs(names): + prefixes = set() + result = [] + for env in names: + prefix = env.name.split("_")[0] + if prefix not in prefixes: + prefixes.add(prefix) + result.append(env) + return result + + +@pytest.mark.skipif(not _has_meltingpot, reason="Meltingpot not found") +@pytest.mark.skipif( + packaging.version.parse(torchrl.__version__).release <= (0, 3, 1), # numeric release tuple; comparing version strings mis-orders e.g. "0.10.0" vs "0.3.1" + reason="TorchRL <= 0.3.1 does not support meltingpot", +) +class TestMeltingPot: + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) + @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN]) + def test_all_algos( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_sequence_config, + ): + # To not run unsupported algo-task pairs + if not algo_config.supports_discrete_actions(): + pytest.skip() + + task = task.get_from_yaml() + experiment_config.checkpoint_interval = 0 + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", [MasacConfig]) + @pytest.mark.parametrize("task", _get_unique_envs(list(MeltingPotTask))) + def test_all_tasks( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_sequence_config, + ): + task = task.get_from_yaml() + experiment_config.checkpoint_interval = 0 + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_sequence_config, + seed=0, +
config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) + @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN]) + def test_reloading_trainer( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_sequence_config, + ): + # To not run unsupported algo-task pairs + if not algo_config.supports_discrete_actions(): + pytest.skip() + + algo_config = algo_config.get_from_yaml() + + ExperimentUtils.check_experiment_loading( + algo_config=algo_config, + model_config=cnn_sequence_config, + experiment_config=experiment_config, + task=task.get_from_yaml(), + ) + + @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig]) + @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN]) + @pytest.mark.parametrize("share_params", [True, False]) + def test_share_policy_params( + self, + algo_config: AlgorithmConfig, + task: Task, + share_params, + experiment_config, + cnn_sequence_config, + ): + experiment_config.share_policy_params = share_params + task = task.get_from_yaml() + experiment_config.checkpoint_interval = 0 + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() diff --git a/test/test_models.py b/test/test_models.py index b88435eb..4cff2d59 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -3,11 +3,13 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# -import importlib + from typing import List +import packaging.version import pytest import torch +import torchrl from benchmarl.hydra_config import load_model_config_from_hydra from benchmarl.models import model_config_registry @@ -77,7 +79,10 @@ def test_models_forward_shape( pytest.skip() # this combination should never happen if ("gnn" in model_name) and centralised: pytest.skip("gnn model is always decentralized") - if importlib.metadata.version("torchrl") <= "0.3.1" and "cnn" in model_name: + if ( + packaging.version.parse(torchrl.__version__).release <= (0, 3, 1) # numeric release tuple, not a lexicographic string comparison + and "cnn" in model_name + ): pytest.skip("TorchRL <= 0.3.1 does not support MultiAgentCNN") torch.manual_seed(0) diff --git a/test/test_pettingzoo.py b/test/test_pettingzoo.py index e4935e49..d8799bb5 100644 --- a/test/test_pettingzoo.py +++ b/test/test_pettingzoo.py @@ -14,7 +14,6 @@ MaddpgConfig, MasacConfig, QmixConfig, - VdnConfig, ) from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import PettingZooTask, Task @@ -118,9 +117,7 @@ def test_reloading_trainer( ): experiment_config.prefer_continuous_actions = prefer_continuous algo_config = algo_config.get_from_yaml() - if isinstance(algo_config, VdnConfig): - # There are some bugs currently in TorchRL https://github.com/pytorch/rl/issues/1593 - pytest.skip() + ExperimentUtils.check_experiment_loading( algo_config=algo_config, model_config=mlp_sequence_config, diff --git a/test/test_vmas.py b/test/test_vmas.py index cf4e4d52..3dfc4f81 100644 --- a/test/test_vmas.py +++ b/test/test_vmas.py @@ -15,7 +15,6 @@ MaddpgConfig, MasacConfig, QmixConfig, - VdnConfig, ) from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task, VmasTask @@ -107,9 +106,7 @@ def test_reloading_trainer( mlp_sequence_config, ): algo_config = algo_config.get_from_yaml() - if isinstance(algo_config, VdnConfig): - # There are some bugs currently in TorchRL https://github.com/pytorch/rl/issues/1593 - pytest.skip() + 
ExperimentUtils.check_experiment_loading( algo_config=algo_config, model_config=mlp_sequence_config, diff --git a/test/utils.py b/test/utils.py index 52214c46..848b149b 100644 --- a/test/utils.py +++ b/test/utils.py @@ -9,3 +9,4 @@ _has_vmas = importlib.util.find_spec("vmas") is not None _has_smacv2 = importlib.util.find_spec("smacv2") is not None _has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None +_has_meltingpot = importlib.util.find_spec("meltingpot") is not None