Commit: amend

Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 9, 2023
1 parent a7dfc96 commit 664c21f
Showing 9 changed files with 48 additions and 11 deletions.
6 changes: 4 additions & 2 deletions benchmarl/algorithms/isac.py
@@ -109,11 +109,13 @@ def _get_loss(
         return loss_module, True
 
     def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
-        return {
+        items = {
             "loss_actor": list(loss.actor_network_params.flatten_keys().values()),
             "loss_qvalue": list(loss.qvalue_network_params.flatten_keys().values()),
-            "loss_alpha": [loss.log_alpha],
         }
+        if not self.fixed_alpha:
+            items.update({"loss_alpha": [loss.log_alpha]})
+        return items
 
     def _get_policy_for_loss(
         self, group: str, model_config: ModelConfig, continuous: bool
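For context, a hedged sketch of how these per-loss parameter groups are presumably consumed: BenchMARL builds one optimizer per returned entry, so dropping the `loss_alpha` entry when `fixed_alpha` is set keeps the frozen entropy coefficient out of any optimizer. `make_optimizers` and the toy parameters below are illustrative, not BenchMARL API:

```python
import torch
from torch import nn

# Illustrative only: one optimizer per parameter-group entry, mirroring how the
# dict returned by _get_parameters is presumably used downstream.
def make_optimizers(named_params: dict, lr: float = 3e-4) -> dict:
    return {name: torch.optim.Adam(params, lr=lr)
            for name, params in named_params.items()}

fixed_alpha = True
log_alpha = torch.zeros(1, requires_grad=True)

items = {"loss_actor": list(nn.Linear(4, 2).parameters())}
if not fixed_alpha:
    # only give alpha an optimizer when it is meant to be learned
    items["loss_alpha"] = [log_alpha]

print(sorted(make_optimizers(items)))  # ['loss_actor'] when fixed_alpha=True
```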
6 changes: 4 additions & 2 deletions benchmarl/algorithms/masac.py
@@ -104,11 +104,13 @@ def _get_loss(
         return loss_module, True
 
     def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
-        return {
+        items = {
             "loss_actor": list(loss.actor_network_params.flatten_keys().values()),
             "loss_qvalue": list(loss.qvalue_network_params.flatten_keys().values()),
-            "loss_alpha": [loss.log_alpha],
         }
+        if not self.fixed_alpha:
+            items.update({"loss_alpha": [loss.log_alpha]})
+        return items
 
     def _get_policy_for_loss(
         self, group: str, model_config: ModelConfig, continuous: bool
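The same guard appears in MASAC. The reason `loss_alpha` must be dropped rather than left in place: with a fixed entropy coefficient, `log_alpha` is typically stored as a non-trainable buffer (this is the pattern TorchRL's SAC loss follows), so an optimizer over it would never receive gradients. A toy sketch of that pattern, not TorchRL's actual code:

```python
import math
import torch
from torch import nn

# Toy sketch of the usual fixed-alpha pattern in SAC losses; TorchRL's
# internals may differ in detail.
class ToySACLoss(nn.Module):
    def __init__(self, alpha_init: float = 1e-6, fixed_alpha: bool = True):
        super().__init__()
        init = torch.tensor(math.log(alpha_init))
        if fixed_alpha:
            self.register_buffer("log_alpha", init)  # constant, no gradients
        else:
            self.log_alpha = nn.Parameter(init)      # learned jointly

print(ToySACLoss(fixed_alpha=True).log_alpha.requires_grad)   # False
print(ToySACLoss(fixed_alpha=False).log_alpha.requires_grad)  # True
```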
2 changes: 0 additions & 2 deletions benchmarl/experiment/experiment.py
@@ -524,8 +524,6 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
 
                 optimizer.step()
                 optimizer.zero_grad()
-            elif loss_name.startswith("loss"):
-                raise AssertionError
         self.replay_buffers[group].update_tensordict_priority(subdata)
         if self.target_updaters[group] is not None:
             self.target_updaters[group].step()
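The deleted `elif` asserted that every `loss_*` key in the loss output had a matching optimizer. With `fixed_alpha=True` the SAC loss can still report a `loss_alpha` value even though no optimizer exists for it, so the loop now has to skip unmatched keys silently. A hedged sketch of the loop shape (not BenchMARL's exact code; the strings stand in for real losses and torch optimizers):

```python
loss_vals = {"loss_actor": 0.12, "loss_qvalue": 0.34, "loss_alpha": 0.0}
optimizers = {"loss_actor": "adam_actor", "loss_qvalue": "adam_qvalue"}

for loss_name, loss_value in loss_vals.items():
    optimizer = optimizers.get(loss_name)
    if optimizer is None:
        # e.g. "loss_alpha" when fixed_alpha=True; before this commit,
        # reaching here raised AssertionError
        continue
    ...  # backward(), optimizer.step(), optimizer.zero_grad() in the real loop
```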
17 changes: 17 additions & 0 deletions premade_scripts/smacv2/conf/algorithm/isac.yaml
@@ -0,0 +1,17 @@
+defaults:
+  - isac_config
+  - _self_
+
+
+share_param_critic: True
+
+num_qvalue_nets: 2
+loss_function: "l2"
+delay_qvalue: True
+target_entropy: "auto"
+discrete_target_entropy_weight: 0.2
+
+alpha_init: 0.000001
+min_alpha: null
+max_alpha: null
+fixed_alpha: True
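With `fixed_alpha: True` and `alpha_init: 0.000001`, the entropy coefficient stays frozen at a value small enough to make the entropy bonus negligible, effectively disabling entropy regularization in this isac.yaml (and in the masac.yaml below). A toy calculation, illustrative rather than TorchRL's loss code:

```python
import torch

alpha = torch.tensor(1e-6)     # frozen at alpha_init
log_prob = torch.tensor(-1.3)  # log pi(a|s) of a sampled action
q_value = torch.tensor(0.7)    # critic estimate

# SAC's actor maximizes Q - alpha * log_prob; at alpha=1e-6 the entropy
# term contributes only ~1.3e-6, so the update is driven almost purely by Q.
print(q_value - alpha * log_prob)  # tensor(0.7000)
```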
18 changes: 18 additions & 0 deletions premade_scripts/smacv2/conf/algorithm/masac.yaml
@@ -0,0 +1,18 @@
+defaults:
+  - masac_config
+  - _self_
+
+
+
+share_param_critic: True
+
+num_qvalue_nets: 2
+loss_function: "l2"
+delay_qvalue: True
+target_entropy: "auto"
+discrete_target_entropy_weight: 0.2
+
+alpha_init: 0.000001
+min_alpha: null
+max_alpha: null
+fixed_alpha: True
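Both override files sit in the `algorithm` config group, so they can be selected by name. A quick compositional check using Hydra's compose API (a sketch: it assumes you run from the repository root and that the main config declares an `algorithm` default group; add further overrides, e.g. a task, if the main config requires them):

```python
from hydra import compose, initialize

# Path follows the repository layout shown above; adjust if running elsewhere.
with initialize(version_base=None, config_path="premade_scripts/smacv2/conf"):
    cfg = compose(config_name="config", overrides=["algorithm=masac"])
print(cfg.algorithm.fixed_alpha)  # True, per the file above
```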
6 changes: 3 additions & 3 deletions premade_scripts/smacv2/conf/config.yaml
@@ -14,7 +14,7 @@ hydra:
 seed: 0
 
 experiment:
-  sampling_device: "cuda"
+  sampling_device: "cpu"
   train_device: "cuda"
 
   share_policy_params: True
@@ -31,10 +31,10 @@ experiment:
 
   exploration_eps_init: 0.8
   exploration_eps_end: 0.01
-  exploration_anneal_frames: null
+  exploration_anneal_frames: 1_000_000
 
   max_n_iters: null
-  max_n_frames: 30_000_000
+  max_n_frames: 20_000_000
 
   on_policy_collected_frames_per_batch: 6000
   on_policy_n_envs_per_worker: 10
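Two values are retuned here: exploration epsilon now anneals over an explicit 1,000,000 frames instead of a `null` (default-derived) horizon, and the total frame budget drops from 30M to 20M. Assuming a linear schedule, a common choice (the exact BenchMARL schedule may differ):

```python
def eps_at(frame: int, init: float = 0.8, end: float = 0.01,
           anneal_frames: int = 1_000_000) -> float:
    # Linearly interpolate from init to end, then hold at end.
    frac = min(frame / anneal_frames, 1.0)
    return init + frac * (end - init)

print(eps_at(0))          # 0.8
print(eps_at(500_000))    # 0.405
print(eps_at(2_000_000))  # 0.01 for the rest of the 20M-frame run
```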
2 changes: 1 addition & 1 deletion premade_scripts/smacv2/smacv2_run.py
@@ -6,7 +6,7 @@
 from omegaconf import DictConfig, OmegaConf
 
 
-@hydra.main(version_base=None, config_path=".", config_name="config")
+@hydra.main(version_base=None, config_path="conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
     hydra_choices = HydraConfig.get().runtime.choices
     task_name = hydra_choices.task
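In Hydra, `config_path` is resolved relative to the directory of the file defining the decorated function, so `"conf"` matches the YAML files moving into a `conf/` subdirectory; the identical fix is applied to vmas_run.py below. A minimal standalone example of the same decorator usage:

```python
import hydra
from omegaconf import DictConfig

# Expects <this_script_dir>/conf/config.yaml to exist.
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    print(cfg)

if __name__ == "__main__":
    main()
```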
File renamed without changes.
2 changes: 1 addition & 1 deletion premade_scripts/vmas/vmas_run.py
@@ -6,7 +6,7 @@
 from omegaconf import DictConfig, OmegaConf
 
 
-@hydra.main(version_base=None, config_path=".", config_name="config")
+@hydra.main(version_base=None, config_path="conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
     hydra_choices = HydraConfig.get().runtime.choices
     task_name = hydra_choices.task
