Skip to content

Commit

Permalink
algorithm docstrings
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Oct 9, 2023
1 parent 67c1a78 commit 7ae6572
Show file tree
Hide file tree
Showing 15 changed files with 197 additions and 172 deletions.
2 changes: 2 additions & 0 deletions benchmarl/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from .qmix import Qmix, QmixConfig
from .vdn import Vdn, VdnConfig

# A registry mapping "algoname" to its config dataclass
# This is used to aid loading of algorithms from yaml
algorithm_config_registry = {
"mappo": MappoConfig,
"ippo": IppoConfig,
Expand Down
177 changes: 172 additions & 5 deletions benchmarl/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@


class Algorithm(ABC):
"""
Abstract class for an algorithm.
This should be overridden by implemented algorithms
and all abstract methods should be implemented.
Args:
experiment_config (ExperimentConfig): the configuration dataclass for the experiment
model_config (ModelConfig): the configuration dataclass for the policy
critic_model_config (ModelConfig): the configuration dataclass for the (eventual) critic
observation_spec (CompositeSpec): the observation spec of the task
action_spec (CompositeSpec): the action spec of the task
state_spec (CompositeSpec): the state spec of the task
action_mask_spec (CompositeSpec): the action_mask spec of the task
group_map (Dictionary): the group map of the task
on_policy (bool): whether the algorithm has to be trained on policy
"""

def __init__(
self,
experiment_config: "DictConfig", # noqa: F821
Expand All @@ -46,6 +63,7 @@ def __init__(
self.state_spec = state_spec
self.action_mask_spec = action_mask_spec

# Cached values that will be instantiated only once and then remain fixed
self._losses_and_updaters = {}
self._policies_for_loss = {}
self._policies_for_collection = {}
Expand Down Expand Up @@ -97,6 +115,17 @@ def _check_specs(self):
)

def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater]:
"""
Get the LossModule and TargetNetUpdater for a specific group.
This function calls the abstract self._get_loss() which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Args:
group (str): agent group of the loss and updater
Returns: LossModule and TargetNetUpdater for the group
"""
if group not in self._losses_and_updaters.keys():
action_space = self.action_spec[group, "action"]
continuous = not isinstance(
Expand Down Expand Up @@ -126,6 +155,15 @@ def get_replay_buffer(
self,
group: str,
) -> ReplayBuffer:
"""
Get the ReplayBuffer for a specific group.
This function will check self.on_policy and create the buffer accordingly
Args:
group (str): agent group of the loss and updater
Returns: ReplayBuffer the group
"""
memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
storing_device = self.device
Expand All @@ -138,6 +176,16 @@ def get_replay_buffer(
)

def get_policy_for_loss(self, group: str) -> TensorDictModule:
"""
Get the non-explorative policy for a specific group loss.
This function calls the abstract self._get_policy_for_loss() which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Args:
group (str): agent group of the policy
Returns: TensorDictModule representing the policy
"""
if group not in self._policies_for_loss.keys():
action_space = self.action_spec[group, "action"]
continuous = not isinstance(
Expand All @@ -155,6 +203,13 @@ def get_policy_for_loss(self, group: str) -> TensorDictModule:
return self._policies_for_loss[group]

def get_policy_for_collection(self) -> TensorDictSequential:
"""
Get the explorative policy for all groups together.
This function calls the abstract self._get_policy_for_collection() which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Returns: TensorDictSequential representing all explorative policies
"""
policies = []
for group in self.group_map.keys():
if group not in self._policies_for_collection.keys():
Expand All @@ -173,6 +228,12 @@ def get_policy_for_collection(self) -> TensorDictSequential:
return TensorDictSequential(*policies)

def get_parameters(self, group: str) -> Dict[str, Iterable]:
"""
Get the dictionary mapping loss names to the relative parameters to optimize for a given group.
This function calls the abstract self._get_parameters() which needs to be implemented.
Returns: a dictionary mapping loss names to a parameters' list
"""
return self._get_parameters(
group=group,
loss=self.get_loss_and_updater(group)[0],
Expand All @@ -186,36 +247,99 @@ def get_parameters(self, group: str) -> Dict[str, Iterable]:
def _get_loss(
self, group: str, policy_for_loss: TensorDictModule, continuous: bool
) -> Tuple[LossModule, bool]:
"""
Implement this function to return the LossModule for a specific group.
Args:
group (str): agent group of the loss
policy_for_loss (TensorDictModule): the policy to use in the loss
continuous (bool): whether to return a loss for continuous or discrete actions
Returns: LossModule and a bool representing if the loss should have target parameters
"""
raise NotImplementedError

@abstractmethod
def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
"""
Get the dictionary mapping loss names to the relative parameters to optimize for a given group loss.
Returns: a dictionary mapping loss names to a parameters' list
"""
raise NotImplementedError

@abstractmethod
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:
"""
Get the non-explorative policy for a specific group.
Args:
group (str): agent group of the policy
model_config (ModelConfig): model config class
continuous (bool): whether the policy should be continuous or discrete
Returns: TensorDictModule representing the policy
"""
raise NotImplementedError

@abstractmethod
def _get_policy_for_collection(
self, policy_for_loss: TensorDictModule, group: str, continuous: bool
) -> TensorDictModule:
"""
Implement this function to add an explorative layer to the policy used in the loss.
Args:
policy_for_loss (TensorDictModule): the group policy used in the loss
group (str): agent group
continuous (bool): whether the policy is continuous or discrete
Returns: TensorDictModule representing the explorative policy
"""
raise NotImplementedError

@abstractmethod
def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
"""
This function can be used to reshape data coming from collection before it is passed to the policy.
Args:
group (str): agent group
batch (TensorDictBase): the batch of data coming from the collector
Returns: the processed batch
"""
raise NotImplementedError

def process_loss_vals(
self, group: str, loss_vals: TensorDictBase
) -> TensorDictBase:
"""
Here you can modify the loss_vals tensordict containing entries loss_name->loss_value
For example, you can sum two entries in a new entry, to optimize them together.
Args:
group (str): agent group
loss_vals (TensorDictBase): the tensordict returned by the loss forward method
Returns: the processed loss_vals
"""
return loss_vals


@dataclass
class AlgorithmConfig:
"""
Dataclass representing an algorithm configuration.
This should be overridden by implemented algorithms.
Implementors should:
1. add configuration parameters for their algorithm
2. implement all abstract methods
"""

def get_algorithm(
self,
experiment_config,
Expand All @@ -227,8 +351,23 @@ def get_algorithm(
action_mask_spec: Optional[CompositeSpec],
group_map: Dict[str, List[str]],
) -> Algorithm:
"""
Main function to turn the config into the associated algorithm
Args:
experiment_config (ExperimentConfig): the configuration dataclass for the experiment
model_config (ModelConfig): the configuration dataclass for the policy
critic_model_config (ModelConfig): the configuration dataclass for the (eventual) critic
observation_spec (CompositeSpec): the observation spec of the task
action_spec (CompositeSpec): the action spec of the task
state_spec (CompositeSpec): the state spec of the task
action_mask_spec (CompositeSpec): the action_mask spec of the task
group_map (Dictionary): the group map of the task
Returns: the Algorithm
"""
return self.associated_class()(
**self.__dict__,
**self.__dict__, # Passes all the custom config parameters
experiment_config=experiment_config,
model_config=model_config,
critic_model_config=critic_model_config,
Expand All @@ -250,27 +389,55 @@ def _load_from_yaml(name: str) -> Dict[str, Any]:
)
return read_yaml_config(str(yaml_path.resolve()))

@staticmethod
@abstractmethod
def get_from_yaml(path: Optional[str] = None):
raise NotImplementedError
@classmethod
def get_from_yaml(cls, path: Optional[str] = None):
"""
Load the algorithm configuration from yaml
Args:
path (str, optional): The full path of the yaml file to load from.
If None, it will default to
benchmarl/conf/algorithm/self.associated_class().__name__
Returns: the loaded AlgorithmConfig
"""
if path is None:
return cls(
**AlgorithmConfig._load_from_yaml(
name=cls.associated_class().__name__,
)
)
else:
return cls(**read_yaml_config(path))

@staticmethod
@abstractmethod
def associated_class() -> Type[Algorithm]:
"""
The algorithm class associated to the config
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def on_policy() -> bool:
"""
If the algorithm has to be run on policy or off policy
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def supports_continuous_actions() -> bool:
"""
If the algorithm supports continuous actions
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def supports_discrete_actions() -> bool:
"""
If the algorithm supports discrete actions
"""
raise NotImplementedError
14 changes: 1 addition & 13 deletions benchmarl/algorithms/iddpg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Optional, Tuple, Type
from typing import Dict, Iterable, Tuple, Type

import torch
from tensordict import TensorDictBase
Expand All @@ -10,7 +10,6 @@

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import read_yaml_config


class Iddpg(Algorithm):
Expand Down Expand Up @@ -239,14 +238,3 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def get_from_yaml(path: Optional[str] = None):
if path is None:
return IddpgConfig(
**AlgorithmConfig._load_from_yaml(
name=IddpgConfig.associated_class().__name__,
)
)
else:
return IddpgConfig(**read_yaml_config(path))
14 changes: 1 addition & 13 deletions benchmarl/algorithms/ippo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Optional, Tuple, Type
from typing import Dict, Iterable, Tuple, Type

import torch
from tensordict import TensorDictBase
Expand All @@ -13,7 +13,6 @@

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import read_yaml_config


class Ippo(Algorithm):
Expand Down Expand Up @@ -286,14 +285,3 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return True

@staticmethod
def get_from_yaml(path: Optional[str] = None):
if path is None:
return IppoConfig(
**AlgorithmConfig._load_from_yaml(
name=IppoConfig.associated_class().__name__,
)
)
else:
return IppoConfig(**read_yaml_config(path))
14 changes: 1 addition & 13 deletions benchmarl/algorithms/iql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Optional, Tuple, Type
from typing import Dict, Iterable, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
Expand All @@ -9,7 +9,6 @@

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import read_yaml_config


class Iql(Algorithm):
Expand Down Expand Up @@ -189,14 +188,3 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def get_from_yaml(path: Optional[str] = None):
if path is None:
return IqlConfig(
**AlgorithmConfig._load_from_yaml(
name=IqlConfig.associated_class().__name__,
)
)
else:
return IqlConfig(**read_yaml_config(path))
Loading

0 comments on commit 7ae6572

Please sign in to comment.