[Refactor] Pass experiment object to algorithm (#31)
* mappo extractor

* copyright

* copyright

* pass entire experiment to algo

* pass entire experiment to model

* remove kwargs from model input

* empty

* empty

* update

* action_spec to model

* action_spec to model
matteobettini authored Oct 21, 2023
1 parent 4f3cc64 commit 8f26dd2
Showing 19 changed files with 73 additions and 110 deletions.
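The heart of the change is in benchmarl/algorithms/common.py below: Algorithm.__init__ and AlgorithmConfig.get_algorithm now take a single experiment object and read the configs, specs, and group map off it, instead of receiving each one as a separate argument. The following is a minimal, dependency-free sketch of that call shape; the Dummy* classes are illustrative stand-ins, not BenchMARL code.

class DummyExperiment:
    """Stand-in for the Experiment object, with only the attributes that
    Algorithm.__init__ reads off it in this commit."""

    def __init__(self):
        self.config = type("Cfg", (), {"train_device": "cpu"})()
        self.model_config = None
        self.critic_model_config = None
        self.on_policy = True
        self.group_map = {"agents": ["agent_0", "agent_1"]}
        self.observation_spec = None
        self.action_spec = None
        self.state_spec = None
        self.action_mask_spec = None


class DummyAlgorithm:
    # Mirrors the new Algorithm.__init__(self, experiment) signature.
    def __init__(self, experiment):
        self.experiment = experiment
        self.device = experiment.config.train_device
        self.experiment_config = experiment.config
        self.group_map = experiment.group_map
        self.action_spec = experiment.action_spec


# Before: the caller passed experiment_config, model_config, critic_model_config,
# observation_spec, action_spec, state_spec, action_mask_spec and group_map one by one.
# After: a single object carries all of them.
algo = DummyAlgorithm(DummyExperiment())
print(algo.device, algo.group_map)  # cpu {'agents': ['agent_0', 'agent_1']}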
80 changes: 18 additions & 62 deletions benchmarl/algorithms/common.py
@@ -7,12 +7,11 @@
import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
from typing import Any, Dict, Iterable, Optional, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import (
CompositeSpec,
DiscreteTensorSpec,
LazyTensorStorage,
OneHotDiscreteTensorSpec,
@@ -34,40 +33,22 @@ class Algorithm(ABC):
and all abstract methods should be implemented.
Args:
experiment_config (ExperimentConfig): the configuration dataclass for the experiment
model_config (ModelConfig): the configuration dataclass for the policy
critic_model_config (ModelConfig): the configuration dataclass for the (eventual) critic
observation_spec (CompositeSpec): the observation spec of the task
action_spec (CompositeSpec): the action spec of the task
state_spec (CompositeSpec): the state spec of the task
action_mask_spec (CompositeSpec): the action_mask spec of the task
group_map (Dictionary): the group map of the task
on_policy (bool): whether the algorithm has to be trained on policy
experiment (Experiment): the experiment class
"""

def __init__(
self,
experiment_config: "DictConfig", # noqa: F821
model_config: ModelConfig,
critic_model_config: ModelConfig,
observation_spec: CompositeSpec,
action_spec: CompositeSpec,
state_spec: Optional[CompositeSpec],
action_mask_spec: Optional[CompositeSpec],
group_map: Dict[str, List[str]],
on_policy: bool,
):
self.device: DEVICE_TYPING = experiment_config.train_device

self.experiment_config = experiment_config
self.model_config = model_config
self.critic_model_config = critic_model_config
self.on_policy = on_policy
self.group_map = group_map
self.observation_spec = observation_spec
self.action_spec = action_spec
self.state_spec = state_spec
self.action_mask_spec = action_mask_spec
def __init__(self, experiment):
self.experiment = experiment

self.device: DEVICE_TYPING = experiment.config.train_device
self.experiment_config = experiment.config
self.model_config = experiment.model_config
self.critic_model_config = experiment.critic_model_config
self.on_policy = experiment.on_policy
self.group_map = experiment.group_map
self.observation_spec = experiment.observation_spec
self.action_spec = experiment.action_spec
self.state_spec = experiment.state_spec
self.action_mask_spec = experiment.action_mask_spec

# Cached values that will be instantiated only once and then remain fixed
self._losses_and_updaters = {}
@@ -346,43 +327,18 @@ class AlgorithmConfig:
2. implement all abstract methods
"""

def get_algorithm(
self,
experiment_config,
model_config: ModelConfig,
critic_model_config: ModelConfig,
observation_spec: CompositeSpec,
action_spec: CompositeSpec,
state_spec: CompositeSpec,
action_mask_spec: Optional[CompositeSpec],
group_map: Dict[str, List[str]],
) -> Algorithm:
def get_algorithm(self, experiment) -> Algorithm:
"""
Main function to turn the config into the associated algorithm
Args:
experiment_config (ExperimentConfig): the configuration dataclass for the experiment
model_config (ModelConfig): the configuration dataclass for the policy
critic_model_config (ModelConfig): the configuration dataclass for the (eventual) critic
observation_spec (CompositeSpec): the observation spec of the task
action_spec (CompositeSpec): the action spec of the task
state_spec (CompositeSpec): the state spec of the task
action_mask_spec (CompositeSpec): the action_mask spec of the task
group_map (Dictionary): the group map of the task
experiment (Experiment): the experiment class
Returns: the Algorithm
"""
return self.associated_class()(
**self.__dict__, # Passes all the custom config parameters
experiment_config=experiment_config,
model_config=model_config,
critic_model_config=critic_model_config,
observation_spec=observation_spec,
action_spec=action_spec,
state_spec=state_spec,
action_mask_spec=action_mask_spec,
group_map=group_map,
on_policy=self.on_policy(),
experiment=experiment,
)

@staticmethod
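For concrete algorithms (see the Ippo and Isac constructors further down), the practical consequence is that a subclass keeps only its own hyper-parameters as explicit arguments and forwards everything else to Algorithm.__init__ through **kwargs, which now carries just the experiment. A hedged skeleton with hypothetical names, and with Algorithm's remaining abstract methods omitted:

# Skeleton of a custom algorithm under the new constructor contract.
# "MyAlgo" and "my_coef" are hypothetical; only the **kwargs -> super().__init__()
# pattern is taken from the Ippo/Isac constructors in this diff. The other
# abstract methods are omitted, so this class is not instantiable as written.
from benchmarl.algorithms.common import Algorithm


class MyAlgo(Algorithm):
    def __init__(self, my_coef: float, **kwargs):
        super().__init__(**kwargs)  # kwargs now carries experiment=...
        self.my_coef = my_coef
        # The base class exposes everything that used to be passed explicitly:
        # self.action_spec, self.observation_spec, self.state_spec,
        # self.group_map, self.experiment_config, self.model_config, ...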
2 changes: 2 additions & 0 deletions benchmarl/algorithms/iddpg.py
@@ -101,6 +101,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
action_spec=self.action_spec,
)

policy = ProbabilisticActor(
@@ -217,6 +218,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
action_spec=self.action_spec,
)
)

10 changes: 6 additions & 4 deletions benchmarl/algorithms/ippo.py
@@ -30,6 +30,7 @@ def __init__(
critic_coef: float,
loss_critic_type: str,
lmbda: float,
scale_mapping: str,
**kwargs
):
super().__init__(**kwargs)
@@ -40,6 +41,7 @@ def __init__(
self.critic_coef = critic_coef
self.loss_critic_type = loss_critic_type
self.lmbda = lmbda
self.scale_mapping = scale_mapping

#############################
# Overridden abstract methods
@@ -48,7 +50,6 @@ def __init__(
def _get_loss(
self, group: str, policy_for_loss: TensorDictModule, continuous: bool
) -> Tuple[LossModule, bool]:

# Loss
loss_module = ClipPPOLoss(
actor=policy_for_loss,
@@ -83,7 +84,6 @@ def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
if continuous:
logits_shape = list(self.action_spec[group, "action"].shape)
@@ -124,11 +124,12 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
action_spec=self.action_spec,
)

if continuous:
extractor_module = TensorDictModule(
NormalParamExtractor(),
NormalParamExtractor(scale_mapping=self.scale_mapping),
in_keys=[(group, "logits")],
out_keys=[(group, "loc"), (group, "scale")],
)
@@ -261,20 +262,21 @@ def get_critic(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
action_spec=self.action_spec,
)

return value_module


@dataclass
class IppoConfig(AlgorithmConfig):

share_param_critic: bool = MISSING
clip_epsilon: float = MISSING
entropy_coef: float = MISSING
critic_coef: float = MISSING
loss_critic_type: str = MISSING
lmbda: float = MISSING
scale_mapping: str = MISSING

@staticmethod
def associated_class() -> Type[Algorithm]:
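The scale_mapping option added to Ippo above (and to Isac below) is forwarded to torchrl's NormalParamExtractor, which splits the policy network output into loc and scale and pushes the scale through a positivity-enforcing mapping. A small sketch of that behaviour; treating "biased_softplus_1.0" as a valid value for the config string is an assumption, not something stated in this diff.

# Sketch of what the scale_mapping string controls in torchrl's NormalParamExtractor.
import torch
from torchrl.modules import NormalParamExtractor

extractor = NormalParamExtractor(scale_mapping="biased_softplus_1.0")
net_out = torch.randn(5, 8)      # last dimension is split in half
loc, scale = extractor(net_out)  # loc: (5, 4), scale: (5, 4)
assert (scale > 0).all()         # the mapping keeps the std-dev strictly positive
print(loc.shape, scale.shape)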
3 changes: 1 addition & 2 deletions benchmarl/algorithms/iql.py
@@ -62,7 +62,6 @@ def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
logits_shape = [
*self.action_spec[group, "action"].shape,
Expand Down Expand Up @@ -99,6 +98,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
action_spec=self.action_spec,
)
if self.action_mask_spec is not None:
action_mask_key = (group, "action_mask")
@@ -175,7 +175,6 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:

@dataclass
class IqlConfig(AlgorithmConfig):

delay_value: bool = MISSING
loss_function: str = MISSING

10 changes: 7 additions & 3 deletions benchmarl/algorithms/isac.py
@@ -38,6 +38,7 @@ def __init__(
min_alpha: Optional[float],
max_alpha: Optional[float],
fixed_alpha: bool,
scale_mapping: str,
**kwargs
):
super().__init__(**kwargs)
@@ -52,6 +53,7 @@ def __init__(
self.min_alpha = min_alpha
self.max_alpha = max_alpha
self.fixed_alpha = fixed_alpha
self.scale_mapping = scale_mapping

#############################
# Overridden abstract methods
@@ -126,7 +128,6 @@ def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
if continuous:
logits_shape = list(self.action_spec[group, "action"].shape)
@@ -167,11 +168,12 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
action_spec=self.action_spec,
)

if continuous:
extractor_module = TensorDictModule(
NormalParamExtractor(),
NormalParamExtractor(scale_mapping=self.scale_mapping),
in_keys=[(group, "logits")],
out_keys=[(group, "loc"), (group, "scale")],
)
@@ -291,6 +293,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
action_spec=self.action_spec,
)

return value_module
@@ -346,6 +349,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
action_spec=self.action_spec,
)
)

@@ -354,7 +358,6 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:

@dataclass
class IsacConfig(AlgorithmConfig):

share_param_critic: bool = MISSING

num_qvalue_nets: int = MISSING
@@ -367,6 +370,7 @@ class IsacConfig(AlgorithmConfig):
min_alpha: Optional[float] = MISSING
max_alpha: Optional[float] = MISSING
fixed_alpha: bool = MISSING
scale_mapping: str = MISSING

@staticmethod
def associated_class() -> Type[Algorithm]:
6 changes: 3 additions & 3 deletions benchmarl/algorithms/maddpg.py
@@ -62,7 +62,6 @@ def _get_loss(
)

def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:

return {
"loss_actor": list(loss.actor_network_params.flatten_keys().values()),
"loss_value": list(loss.value_network_params.flatten_keys().values()),
@@ -103,6 +102,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
action_spec=self.action_spec,
)

policy = ProbabilisticActor(
@@ -222,11 +222,11 @@ def get_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
action_spec=self.action_spec,
)
)

else:

modules.append(
TensorDictModule(
lambda obs, action: torch.cat([obs, action], dim=-1),
@@ -263,6 +263,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
action_spec=self.action_spec,
)
)

@@ -282,7 +283,6 @@ def get_value_module(self, group: str) -> TensorDictModule:

@dataclass
class MaddpgConfig(AlgorithmConfig):

share_param_critic: bool = MISSING

loss_function: str = MISSING
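Several hunks in this commit also start passing action_spec into the model-building calls ("action_spec to model" in the commit list above). The sketch below is a hypothetical illustration of why a model benefits from receiving the spec: it can size its output head from the spec instead of relying on a caller-computed shape. FakeSpec and SimpleHead are illustrative stand-ins, not BenchMARL's Model API.

# Hypothetical sketch: sizing a policy head from an action spec.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class FakeSpec:
    shape: tuple  # e.g. (n_agents, n_actions_per_agent)


class SimpleHead(nn.Module):
    def __init__(self, in_features: int, action_spec: FakeSpec):
        super().__init__()
        # The last spec dimension gives the per-agent action size.
        self.net = nn.Linear(in_features, action_spec.shape[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


head = SimpleHead(in_features=8, action_spec=FakeSpec(shape=(2, 4)))
print(head(torch.randn(2, 8)).shape)  # torch.Size([2, 4])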
[Diffs for the remaining 13 changed files are not shown]
