
Commit

Merge branch 'refs/heads/more_vmas_tasks' into simplfy_extending_and_examples
matteobettini committed Jun 10, 2024
2 parents 94d7ed6 + 9f4a5de commit 1c7b9db
Showing 20 changed files with 148 additions and 189 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -253,7 +253,7 @@ They differ based on many aspects, here is a table with the current environments
> BenchMARL uses the [TorchRL MARL API](https://github.com/pytorch/rl/issues/1463) for grouping agents.
> In competitive environments like MPE, for example, teams will be in different groups. Each group has its own loss,
> models, buffers, and so on. Parameter sharing options refer to sharing within the group. See the example on [creating
> a custom algorithm](examples/extending/algorithm/customalgorithm.py) for more info.
> a custom algorithm](examples/extending/algorithm/custom_algorithm.py) for more info.
**Models**. Models are neural networks used to process data. They can be used as actors (policies) or,
when requested, as critics. We provide a set of base models (layers) and a SequenceModel to concatenate
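The README text above refers to composing layer models into a SequenceModel. As a point of reference, here is a minimal sketch of what such a composition can look like with BenchMARL's config classes; the specific layers, import locations, and intermediate sizes are illustrative assumptions, not part of this diff.

```python
from benchmarl.models import GnnConfig, MlpConfig, SequenceModelConfig

# Hypothetical composition: MLP -> GNN -> MLP chained into a single model,
# each layer loaded with its default yaml configuration.
model_config = SequenceModelConfig(
    model_configs=[
        MlpConfig.get_from_yaml(),
        GnnConfig.get_from_yaml(),
        MlpConfig.get_from_yaml(),
    ],
    intermediate_sizes=[32, 32],  # assumed output sizes between consecutive layers
)
```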
2 changes: 1 addition & 1 deletion benchmarl/__init__.py
@@ -35,6 +35,6 @@ def _load_hydra_schemas():
cs.store(name=f"{algo_name}_config", group="algorithm", node=algo_schema)
# Load task schemas
for task_schema_name, task_schema in _task_class_registry.items():
cs.store(name=f"{task_schema_name}_config", group="task", node=task_schema)
cs.store(name=task_schema_name, group="task", node=task_schema)

_load_hydra_schemas()
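For context, `cs.store` above is Hydra's ConfigStore API for registering structured-config schemas that validate yaml files. A standalone sketch of the same mechanism follows; the dataclass and its fields are illustrative stand-ins for a task schema, not taken from this diff.

```python
from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


@dataclass
class BalanceTaskSchema:
    # Illustrative stand-in for a task's TaskConfig dataclass
    max_steps: int = 100
    n_agents: int = 3


cs = ConfigStore.instance()
# After this call, hydra can type-check any yaml config that lists this schema
# name in its defaults list.
cs.store(name="vmas_balance_config", group="task", node=BalanceTaskSchema)
```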
11 changes: 5 additions & 6 deletions benchmarl/algorithms/common.py
@@ -358,15 +358,14 @@ def get_from_yaml(cls, path: Optional[str] = None):
Returns: the loaded AlgorithmConfig
"""

if path is None:
config = AlgorithmConfig._load_from_yaml(
name=cls.associated_class().__name__
return cls(
**AlgorithmConfig._load_from_yaml(
name=cls.associated_class().__name__,
)
)

else:
config = _read_yaml_config(path)
return cls(**config)
return cls(**_read_yaml_config(path))

@staticmethod
@abstractmethod
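In user code, `get_from_yaml` is the entry point for building an algorithm config from the packaged yaml defaults. A short usage sketch; the algorithm choice and the custom path are illustrative:

```python
from benchmarl.algorithms import MappoConfig

# Load the defaults shipped with benchmarl for this algorithm
algorithm_config = MappoConfig.get_from_yaml()

# Or load from a user-provided yaml file (hypothetical path)
# algorithm_config = MappoConfig.get_from_yaml(path="my_conf/mappo.yaml")
```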
100 changes: 74 additions & 26 deletions benchmarl/environments/__init__.py
@@ -4,38 +4,86 @@
# LICENSE file in the root directory of this source tree.
#

from .common import _get_task_config_class, Task
from .common import Task
from .meltingpot.common import MeltingPotTask
from .pettingzoo.common import PettingZooTask
from .smacv2.common import Smacv2Task
from .vmas.common import VmasTask

# The enum classes for the environments available.
# This is the only object in this file you need to modify when adding a new environment.
tasks = [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask]

# This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum
# It is used to automatically load task enums from yaml files.
# It is populated automatically.
_task_config_registry = {}
# It is used to automatically load task enums from yaml files
task_config_registry = {}
for env in [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask]:
env_config_registry = {
f"{env.env_name()}/{task.name.lower()}": task for task in env
}
task_config_registry.update(env_config_registry)

# This is a registry mapping "envname_taskname" to the TaskConfig python dataclass of the task.
# It is used by hydra to validate loaded configs.
# You will see the "envname_taskname" strings in the hydra defaults at the top of yaml files.
# This is optional and, if a task does not possess an associated TaskConfig, this entry will be simply skipped.
# It is populated automatically.
_task_class_registry = {}

# Automatic population of registries
for env in tasks:
env_config_registry = {}
environemnt_name = env.env_name()
for task in env:
task_name = task.name.lower()
full_task_name = f"{environemnt_name}/{task_name}"
env_config_registry[full_task_name] = task
from .pettingzoo.multiwalker import TaskConfig as MultiwalkerConfig
from .pettingzoo.simple_adversary import TaskConfig as SimpleAdversaryConfig
from .pettingzoo.simple_crypto import TaskConfig as SimpleCryptoConfig
from .pettingzoo.simple_push import TaskConfig as SimplePushConfig
from .pettingzoo.simple_reference import TaskConfig as SimpleReferenceConfig
from .pettingzoo.simple_speaker_listener import (
TaskConfig as SimpleSpeakerListenerConfig,
)
from .pettingzoo.simple_spread import TaskConfig as SimpleSpreadConfig
from .pettingzoo.simple_tag import TaskConfig as SimpleTagConfig
from .pettingzoo.simple_world_comm import TaskConfig as SimpleWorldComm
from .pettingzoo.waterworld import TaskConfig as WaterworldConfig

from .vmas.balance import TaskConfig as BalanceConfig
from .vmas.dispersion import TaskConfig as DispersionConfig
from .vmas.dropout import TaskConfig as DropoutConfig
from .vmas.give_way import TaskConfig as GiveWayConfig
from .vmas.navigation import TaskConfig as NavigationConfig
from .vmas.reverse_transport import TaskConfig as ReverseTransportConfig
from .vmas.sampling import TaskConfig as SamplingConfig
from .vmas.simple_adversary import TaskConfig as VmasSimpleAdversaryConfig
from .vmas.simple_crypto import TaskConfig as VmasSimpleCryptoConfig
from .vmas.simple_push import TaskConfig as VmasSimplePushConfig
from .vmas.simple_reference import TaskConfig as VmasSimpleReferenceConfig
from .vmas.simple_speaker_listener import TaskConfig as VmasSimpleSpeakerListenerConfig
from .vmas.simple_spread import TaskConfig as VmasSimpleSpreadConfig
from .vmas.simple_tag import TaskConfig as VmasSimpleTagConfig
from .vmas.simple_world_comm import TaskConfig as VmasSimpleWorldComm
from .vmas.transport import TaskConfig as TransportConfig
from .vmas.wheel import TaskConfig as WheelConfig
from .vmas.wind_flocking import TaskConfig as WindFlockingConfig

task_config_class = _get_task_config_class(environemnt_name, task_name)
if task_config_class is not None:
_task_class_registry[full_task_name.replace("/", "_")] = task_config_class
_task_config_registry.update(env_config_registry)

# This is a registry mapping task config schemas names to their python dataclass
# It is used by hydra to validate loaded configs.
# You will see the "envname_taskname_config" strings in the hydra defaults at the top of yaml files.
# This feature is optional.
_task_class_registry = {
"vmas_balance_config": BalanceConfig,
"vmas_sampling_config": SamplingConfig,
"vmas_navigation_config": NavigationConfig,
"vmas_transport_config": TransportConfig,
"vmas_reverse_transport_config": ReverseTransportConfig,
"vmas_wheel_config": WheelConfig,
"vmas_dispersion_config": DispersionConfig,
"vmas_give_way_config": GiveWayConfig,
"vmas_wind_flocking_config": WindFlockingConfig,
"vmas_dropout_config": DropoutConfig,
"vmas_simple_adversary_config": VmasSimpleAdversaryConfig,
"vmas_simple_crypto_config": VmasSimpleCryptoConfig,
"vmas_simple_push_config": VmasSimplePushConfig,
"vmas_simple_reference_config": VmasSimpleReferenceConfig,
"vmas_simple_speaker_listener_config": VmasSimpleSpeakerListenerConfig,
"vmas_simple_spread_config": VmasSimpleSpreadConfig,
"vmas_simple_tag_config": VmasSimpleTagConfig,
"vmas_simple_world_comm_config": VmasSimpleWorldComm,
"pettingzoo_multiwalker_config": MultiwalkerConfig,
"pettingzoo_waterworld_config": WaterworldConfig,
"pettingzoo_simple_adversary_config": SimpleAdversaryConfig,
"pettingzoo_simple_crypto_config": SimpleCryptoConfig,
"pettingzoo_simple_push_config": SimplePushConfig,
"pettingzoo_simple_reference_config": SimpleReferenceConfig,
"pettingzoo_simple_speaker_listener_config": SimpleSpeakerListenerConfig,
"pettingzoo_simple_spread_config": SimpleSpreadConfig,
"pettingzoo_simple_tag_config": SimpleTagConfig,
"pettingzoo_simple_world_comm_config": SimpleWorldComm,
}
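To make the two registries concrete, the sketch below shows how they are keyed after this change, assuming both names remain importable from `benchmarl.environments` as defined above:

```python
from benchmarl.environments import VmasTask, _task_class_registry, task_config_registry

# "envname/task_name" -> task enum member (built by the loop above)
assert task_config_registry["vmas/balance"] is VmasTask.BALANCE

# "envname_taskname_config" -> TaskConfig dataclass (the explicit dict above)
balance_schema = _task_class_registry["vmas_balance_config"]  # BalanceConfig
```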
64 changes: 21 additions & 43 deletions benchmarl/environments/common.py
@@ -9,7 +9,6 @@
import importlib
import os
import os.path as osp
import warnings
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
@@ -21,44 +20,25 @@
from benchmarl.utils import _read_yaml_config, DEVICE_TYPING


def _type_check_task_config(
environemnt_name: str,
task_name: str,
config: Dict[str, Any],
warn_on_missing_dataclass: bool = True,
):

task_config_class = _get_task_config_class(environemnt_name, task_name)

if task_config_class is not None:
return task_config_class(**config).__dict__
else:
if warn_on_missing_dataclass:
warnings.warn(
"TaskConfig python dataclass not foud, task is being loaded without type checks"
)
return config
def _load_config(name: str, config: Dict[str, Any]):
if not name.endswith(".py"):
name += ".py"

pathname = None
for dirpath, _, filenames in os.walk(osp.dirname(__file__)):
if pathname is None:
for filename in filenames:
if filename == name:
pathname = os.path.join(dirpath, filename)
break

def _get_task_config_class(environemnt_name: str, task_name: str):
if not task_name.endswith(".py"):
task_name += ".py"
if pathname is None:
raise ValueError(f"Task {name} not found.")

pathname = None
for dirpath, _, filenames in os.walk(
Path(osp.dirname(__file__)) / environemnt_name
):
if task_name in filenames:
pathname = os.path.join(dirpath, task_name)
break

if pathname is not None:
spec = importlib.util.spec_from_file_location("", pathname)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.TaskConfig
else:
return None
spec = importlib.util.spec_from_file_location("", pathname)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.TaskConfig(**config).__dict__

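`_load_config` uses the standard `importlib.util` pattern to load a module directly from a file path and then instantiates the `TaskConfig` it defines. The mechanism in isolation, with a hypothetical path and field name:

```python
import importlib.util

path = "/path/to/benchmarl/environments/vmas/balance.py"  # hypothetical path

# Build a module object straight from the .py file on disk
spec = importlib.util.spec_from_file_location("balance_task_config", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# The module is expected to define a TaskConfig dataclass, as in _load_config above
config_dict = module.TaskConfig(max_steps=100).__dict__  # field name is illustrative
```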

class Task(Enum):
@@ -334,12 +314,10 @@ def get_from_yaml(self, path: Optional[str] = None) -> Task:
Returns: the task with the loaded config
"""
environment_name = self.env_name()
task_name = self.name.lower()
full_name = str(Path(environment_name) / Path(task_name))
if path is None:
config = Task._load_from_yaml(full_name)
task_name = self.name.lower()
return self.update_config(
Task._load_from_yaml(str(Path(self.env_name()) / Path(task_name)))
)
else:
config = _read_yaml_config(path)
config = _type_check_task_config(environment_name, task_name, config)
return self.update_config(config)
return self.update_config(**_read_yaml_config(path))
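A quick usage sketch of the method above; the task choice and custom path are illustrative:

```python
from benchmarl.environments import VmasTask

# Load the task with the default config shipped with benchmarl
task = VmasTask.BALANCE.get_from_yaml()

# Or load the same task from a user-provided yaml file (hypothetical path)
# task = VmasTask.BALANCE.get_from_yaml(path="my_conf/balance.yaml")
```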
11 changes: 2 additions & 9 deletions benchmarl/hydra_config.py
Expand Up @@ -4,11 +4,9 @@
# LICENSE file in the root directory of this source tree.
#
import importlib
from dataclasses import is_dataclass

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, task_config_registry
from benchmarl.environments.common import _type_check_task_config
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import model_config_registry
from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig
@@ -60,14 +58,9 @@ def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> Task:
:class:`~benchmarl.environments.Task`
"""
environment_name, inner_task_name = task_name.split("/")
cfg_dict_checked = OmegaConf.to_object(cfg)
if is_dataclass(cfg_dict_checked):
cfg_dict_checked = cfg_dict_checked.__dict__
cfg_dict_checked = _type_check_task_config(
environment_name, inner_task_name, cfg_dict_checked
return task_config_registry[task_name].update_config(
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
)
return task_config_registry[task_name].update_config(cfg_dict_checked)


def load_experiment_config_from_hydra(cfg: DictConfig) -> ExperimentConfig:
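The new implementation relies on `OmegaConf.to_container` to turn the Hydra `DictConfig` into a plain dictionary, resolving interpolations and raising on missing values. A minimal illustration of that call:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"max_steps": 100, "horizon": "${max_steps}"})
plain = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
print(plain)  # {'max_steps': 100, 'horizon': 100} -- a regular python dict
```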
13 changes: 8 additions & 5 deletions benchmarl/models/common.py
@@ -295,7 +295,8 @@ def _load_from_yaml(name: str) -> Dict[str, Any]:
/ "layers"
/ f"{name.lower()}.yaml"
)
return _read_yaml_config(str(yaml_path.resolve()))
cfg = _read_yaml_config(str(yaml_path.resolve()))
return parse_model_config(cfg)

@classmethod
def get_from_yaml(cls, path: Optional[str] = None):
@@ -310,11 +311,13 @@ def get_from_yaml(cls, path: Optional[str] = None):
Returns: the loaded AlgorithmConfig
"""
if path is None:
config = ModelConfig._load_from_yaml(name=cls.associated_class().__name__)
return cls(
**ModelConfig._load_from_yaml(
name=cls.associated_class().__name__,
)
)
else:
config = _read_yaml_config(path)
config = parse_model_config(config)
return cls(**config)
return cls(**parse_model_config(_read_yaml_config(path)))


@dataclass
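Usage mirrors the algorithm counterpart; a short sketch with an illustrative model and path:

```python
from benchmarl.models import MlpConfig

# Load the defaults shipped with benchmarl for this layer
model_config = MlpConfig.get_from_yaml()

# Or load from a custom yaml file, which is passed through parse_model_config
# model_config = MlpConfig.get_from_yaml(path="my_conf/mlp.yaml")
```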
9 changes: 5 additions & 4 deletions examples/extending/algorithm/README.md
@@ -1,20 +1,21 @@

# Creating a new algorithm

Here are the steps to create a new algorithm.
Here are the steps to create a new algorithm. You can find the custom IQL algorithm
created for this example in [`custom_algorithm.py`](custom_algorithm.py).

1. Create your `CustomAlgorithm` and `CustomAlgorithmConfig` following the example
in [`algorithms/customalgorithm.py`](algorithms/customalgorithm.py). These will be the algorithm code
in [`custom_algorithm.py`](custom_algorithm.py). These will be the algorithm code
and an associated dataclass to validate loaded configs.
2. Create a `customalgorithm.yaml` with the configuration parameters you defined
in your script. Make sure it has `customalgorithm_config` within its defaults at
the top of the file to let hydra know which python dataclass it is
associated to. You can see [`conf/algorithm/customalgorithm.yaml`](conf/algorithm/customalgorithm.yaml)
associated to. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml)
for an example.
3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and
your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to
override from)
4. Add `{"customalgorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py)
4. Add `{"customagorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py)
5. Load it with
```bash
python benchmarl/run.py algorithm=customalgorithm task=...
```
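As a rough sketch of step 4, the registry entry could look like the following; the existing keys mentioned in the comment and the import path (which follows step 3) are assumptions:

```python
# Sketch of benchmarl/algorithms/__init__.py after step 4 (existing entries elided)
from benchmarl.algorithms.customalgorithm import CustomAlgorithmConfig

algorithm_config_registry = {
    # ... existing algorithms, e.g. "mappo": MappoConfig, "qmix": QmixConfig ...
    "customalgorithm": CustomAlgorithmConfig,
}
```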
@@ -1,6 +1,4 @@
# Copyright (c) 2024.
# ProrokLab (https://www.proroklab.org/)
# All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -19,7 +17,7 @@
from torchrl.objectives import DQNLoss, LossModule, ValueEstimators


class CustomAlgorithm(Algorithm):
class CustomIqlAlgorithm(Algorithm):
def __init__(
self, delay_value: bool, loss_function: str, my_custom_arg: int, **kwargs
):
@@ -215,7 +213,7 @@ def my_custom_method(self):


@dataclass
class CustomAlgorithmConfig(AlgorithmConfig):
class CustomIqlConfig(AlgorithmConfig):
# This is a class representing the configuration of your algorithm
# It will be used to validate loaded configs, so that every time you load this algorithm
# we know exactly which and what parameters to expect with their types
@@ -228,7 +226,7 @@ class CustomAlgorithmConfig(AlgorithmConfig):
@staticmethod
def associated_class() -> Type[Algorithm]:
# The associated algorithm class
return CustomAlgorithm
return CustomIqlAlgorithm

@staticmethod
def supports_continuous_actions() -> bool:
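Given the constructor arguments shown above (`delay_value`, `loss_function`, `my_custom_arg`), the config can be built directly or loaded from its yaml defaults. A hedged sketch; the import path, values, and yaml file name are assumptions based on this example's README:

```python
from custom_algorithm import CustomIqlConfig  # hypothetical import of the file above

# Direct construction with the fields taken by CustomIqlAlgorithm.__init__
algorithm_config = CustomIqlConfig(
    delay_value=True,
    loss_function="l2",
    my_custom_arg=3,
)

# Or, once customiqlalgorithm.yaml sits in benchmarl/conf/algorithm:
# algorithm_config = CustomIqlConfig.get_from_yaml()
```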
8 changes: 4 additions & 4 deletions examples/extending/model/README.md
@@ -4,16 +4,16 @@
Here are the steps to create a new model.

1. Create your `CustomModel` and `CustomModelConfig` following the example
in [`models/custommodel.py`](models/custommodel.py). These will be the model code
in [`custom_model.py`](custom_model.py). These will be the model code
and an associated dataclass to validate loaded configs.
2. Create a `custommodel.yaml` with the configuration parameters you defined
in your script. Make sure it has a `name` entry equal to `custommodel` to let hydra know which python dataclass it is
associated to. You can see [`conf/model/layers/custommodel.yaml`](conf/model/layers/custommodel.yaml)
in your script. Make sure it has a `name` entry equal to `custom_model` to let hydra know which python dataclass it is
associated to. You can see [`custommodel.yaml`](custommodel.yaml)
for an example.
3. Place your model script in [`benchmarl/models`](../../../benchmarl/models) and
your config in [`benchmarl/conf/model/layers`](../../../benchmarl/conf/model/layers) (or any other place you want to
override from)
4. Add `{"custommodel": CustomModelConfig}` to the [`benchmarl.models.model_config_registry`](../../../benchmarl/models/__init__.py)
4. Add `{"custom_model": CustomModelConfig}` to the [`benchmarl.models.model_config_registry`](../../../benchmarl/models/__init__.py)
5. Load it with
```bash
python benchmarl/run.py model=layers/custommodel algorithm=... task=...
```
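Once registered, the new model can also be plugged into an experiment from python; a sketch where every other config and value is illustrative:

```python
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import MlpConfig

from custom_model import CustomModelConfig  # hypothetical import of the new model config

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=MappoConfig.get_from_yaml(),
    model_config=CustomModelConfig.get_from_yaml(),  # the new model as the policy model
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()
```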
File renamed without changes.
@@ -1,4 +1,4 @@

name: custommodel
name: custom_model
custom_param: 3
activation_class: torch.nn.Tanh