[Conf] Change reloading and evaluation interval to be in frame units
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 10, 2023
1 parent 1a427fe commit 62224bc
Showing 13 changed files with 77 additions and 39 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -388,20 +388,20 @@ python benchmarl/run.py algorithm=mappo task=vmas/balance "experiment.loggers=[w

### Checkpointing

Experiments can be checkpointed every `experiment.checkpoint_interval` iterations.
Experiments can be checkpointed every `experiment.checkpoint_interval` collected frames.
Experiments use an output folder for logging and checkpointing, which can be specified in `experiment.save_folder`.
If left unspecified, it defaults to the hydra output folder (when using hydra) or, otherwise, to the current directory where the script is launched.
The output folder will contain a sub-folder for each experiment, named after that experiment.
Checkpoints will be stored in a `"checkpoints"` folder within each experiment folder.
```bash
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.checkpoint_interval=1 experiment.save_folder="/my/folder"
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100
```

To load from a checkpoint, pass the absolute checkpoint file name to `experiment.restore_file`.
```bash
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.restore_file="/my/folder/checkpoint/checkpoint_03.pt"
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt"
```

[![Example](https://img.shields.io/badge/Example-blue.svg)](examples/checkpointing/reload_experiment.py)
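
The same frame-based workflow is available through the Python API. Below is a minimal sketch based on the updated `examples/checkpointing/reload_experiment.py` in this commit; the construction of the `Experiment` itself is elided, and the attribute names are the ones shown in the diff.

```python
from benchmarl.experiment import ExperimentConfig

experiment_config = ExperimentConfig.get_from_yaml()

# Checkpoint once per collection iteration: the interval is now expressed in
# collected frames, so it must be a multiple of the per-batch frame count.
experiment_config.checkpoint_interval = (
    experiment_config.on_policy_collected_frames_per_batch
)
experiment_config.max_n_iters = 3

# ... build and run the experiment (see the linked example) ...

# Checkpoints are now named after the total collected frames, so reloading
# the latest one looks like:
# experiment_config.restore_file = (
#     experiment.folder_name / "checkpoints" / f"checkpoint_{experiment.total_frames}.pt"
# )
```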
6 changes: 4 additions & 2 deletions benchmarl/algorithms/isac.py
@@ -109,11 +109,13 @@ def _get_loss(
return loss_module, True

def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
return {
items = {
"loss_actor": list(loss.actor_network_params.flatten_keys().values()),
"loss_qvalue": list(loss.qvalue_network_params.flatten_keys().values()),
"loss_alpha": [loss.log_alpha],
}
if not self.fixed_alpha:
items.update({"loss_alpha": [loss.log_alpha]})
return items

def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
6 changes: 4 additions & 2 deletions benchmarl/algorithms/masac.py
@@ -104,11 +104,13 @@ def _get_loss(
return loss_module, True

def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
return {
items = {
"loss_actor": list(loss.actor_network_params.flatten_keys().values()),
"loss_qvalue": list(loss.qvalue_network_params.flatten_keys().values()),
"loss_alpha": [loss.log_alpha],
}
if not self.fixed_alpha:
items.update({"loss_alpha": [loss.log_alpha]})
return items

def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
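In both `isac.py` and `masac.py` above, the entropy coefficient's log-parameter is now returned only when alpha is learnable, so no optimizer is created for it when `fixed_alpha` is set. A standalone sketch of the pattern (hypothetical function name, not the actual BenchMARL method):

```python
from typing import Dict, Iterable, List

import torch


def get_parameter_groups(
    actor_params: List[torch.Tensor],
    qvalue_params: List[torch.Tensor],
    log_alpha: torch.Tensor,
    fixed_alpha: bool,
) -> Dict[str, Iterable]:
    # When alpha is fixed, log_alpha is left out of the returned groups,
    # so downstream code will not build an optimizer for it.
    items = {
        "loss_actor": list(actor_params),
        "loss_qvalue": list(qvalue_params),
    }
    if not fixed_alpha:
        items["loss_alpha"] = [log_alpha]
    return items
```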
13 changes: 7 additions & 6 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -56,11 +56,11 @@ on_policy_n_minibatch_iters: 45
on_policy_minibatch_size: 400

# Number of frames collected at each experiment iteration
off_policy_collected_frames_per_batch: 1000
off_policy_collected_frames_per_batch: 6000
# Number of environments used for collection
# If the environment is vectorized, this will be the number of batched environments.
# Otherwise batching will be simulated and each env will be run sequentially.
off_policy_n_envs_per_worker: 1
off_policy_n_envs_per_worker: 10
# This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
off_policy_n_optimizer_steps: 1000
# Number of frames used for each off_policy_n_optimizer_steps when training off-policy algorithms
@@ -72,8 +72,8 @@ off_policy_memory_size: 1_000_000
evaluation: True
# Whether to render the evaluation (if rendering is available)
render: True
# Frequency of evaluation in terms of experiment iterations
evaluation_interval: 20
# Frequency of evaluation in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch)
evaluation_interval: 120_000
# Number of episodes that evaluation is run on
evaluation_episodes: 10

@@ -87,5 +87,6 @@ create_json: True
save_folder: null
# Absolute path to a checkpoint file where the experiment was saved. If null the experiment is started fresh.
restore_file: null
# Interval for experiment saving in terms of experiment iterations. Set it to 0 to disable checkpointing
checkpoint_interval: 50
# Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
# Set it to 0 to disable checkpointing
checkpoint_interval: 300_000
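With the new defaults, the frame-based intervals keep the same cadence as the old iteration-based ones (20 and 50 iterations); a quick sanity check in plain Python, using the values from the config above:

```python
frames_per_batch = 6_000       # on/off_policy_collected_frames_per_batch
evaluation_interval = 120_000  # frames
checkpoint_interval = 300_000  # frames

# Both intervals must divide evenly into collection batches.
assert evaluation_interval % frames_per_batch == 0
assert checkpoint_interval % frames_per_batch == 0

print(evaluation_interval // frames_per_batch)  # 20 iterations between evaluations
print(checkpoint_interval // frames_per_batch)  # 50 iterations between checkpoints
```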
44 changes: 33 additions & 11 deletions benchmarl/experiment/experiment.py
@@ -149,6 +149,22 @@ def get_exploration_anneal_frames(self, on_policy: bool):
else self.exploration_anneal_frames
)

def get_evaluation_interval(self, on_policy: bool):
if self.evaluation_interval % self.collected_frames_per_batch(on_policy) != 0:
raise ValueError(
f"evaluation_interval ({self.evaluation_interval}) "
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
)
return self.evaluation_interval

def get_checkpoint_interval(self, on_policy: bool):
if self.checkpoint_interval % self.collected_frames_per_batch(on_policy) != 0:
raise ValueError(
f"checkpoint_interval ({self.checkpoint_interval}) "
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
)
return self.checkpoint_interval

@staticmethod
def get_from_yaml(path: Optional[str] = None):
if path is None:
@@ -351,7 +367,7 @@ def _setup_name(self):
self.folder_name = folder_name / self.name
if (
len(self.config.loggers)
or self.config.checkpoint_interval > 0
or self.config.get_checkpoint_interval(self.on_policy) > 0
or self.config.create_json
):
self.folder_name.mkdir(parents=False, exist_ok=False)
@@ -474,17 +490,23 @@ def _collection_loop(self):
# Evaluation
if (
self.config.evaluation
and self.n_iters_performed % self.config.evaluation_interval == 0
and (
self.total_frames
% self.config.get_evaluation_interval(self.on_policy)
== 0
)
and (len(self.config.loggers) or self.config.create_json)
):
self._evaluation_loop(iter=self.n_iters_performed)
self._evaluation_loop()

# End of step
self.n_iters_performed += 1
self.logger.commit()
if (
self.config.checkpoint_interval > 0
and self.n_iters_performed % self.config.checkpoint_interval == 0
self.config.get_checkpoint_interval(self.on_policy) > 0
and self.total_frames
% self.config.get_checkpoint_interval(self.on_policy)
== 0
):
self.save_trainer()
sampling_start = time.time()
@@ -524,8 +546,6 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:

optimizer.step()
optimizer.zero_grad()
elif loss_name.startswith("loss"):
raise AssertionError
self.replay_buffers[group].update_tensordict_priority(subdata)
if self.target_updaters[group] is not None:
self.target_updaters[group].step()
@@ -546,7 +566,7 @@ def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float:
return float(gn)

@torch.no_grad()
def _evaluation_loop(self, iter: int):
def _evaluation_loop(self):
evaluation_start = time.time()
with set_exploration_type(ExplorationType.MODE):
if self.task.has_render(self.test_env) and self.config.render:
@@ -585,11 +605,13 @@ def callback(env, td):
)
rollouts = list(rollouts.unbind(0))
evaluation_time = time.time() - evaluation_start
self.logger.log({"timers/evaluation_time": evaluation_time}, step=iter)
self.logger.log(
{"timers/evaluation_time": evaluation_time}, step=self.n_iters_performed
)
self.logger.log_evaluation(
rollouts,
video_frames=video_frames,
step=iter,
step=self.n_iters_performed,
total_frames=self.total_frames,
)
# Callback
@@ -628,7 +650,7 @@ def load_state_dict(self, state_dict: Dict) -> None:
def save_trainer(self) -> None:
checkpoint_folder = self.folder_name / "checkpoints"
checkpoint_folder.mkdir(parents=False, exist_ok=True)
checkpoint_file = checkpoint_folder / f"checkpoint_{self.n_iters_performed}.pt"
checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
torch.save(self.state_dict(), checkpoint_file)

def load_trainer(self) -> Experiment:
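The new `get_evaluation_interval` and `get_checkpoint_interval` accessors reject intervals that do not divide evenly into collection batches. A self-contained sketch of that check (simplified standalone function; the real methods live on the experiment config class shown above):

```python
def validate_interval(name: str, interval: int, frames_per_batch: int) -> int:
    # Mirrors the divisibility check in get_evaluation_interval /
    # get_checkpoint_interval above.
    if interval % frames_per_batch != 0:
        raise ValueError(
            f"{name} ({interval}) is not a multiple of "
            f"the collected_frames_per_batch ({frames_per_batch})"
        )
    return interval


validate_interval("evaluation_interval", 120_000, 6_000)  # OK, returns 120000

try:
    validate_interval("evaluation_interval", 100_000, 6_000)
except ValueError as err:
    print(err)  # 100_000 is not a multiple of 6_000
```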
6 changes: 4 additions & 2 deletions examples/checkpointing/reload_experiment.py
@@ -12,7 +12,9 @@
# Save the experiment in the current folder
experiment_config.save_folder = Path(os.path.dirname(os.path.realpath(__file__)))
# Checkpoint at every iteration
experiment_config.checkpoint_interval = 1
experiment_config.checkpoint_interval = (
experiment_config.on_policy_collected_frames_per_batch
)
# Run 3 iterations
experiment_config.max_n_iters = 3

@@ -34,7 +36,7 @@
experiment_config.restore_file = (
experiment.folder_name
/ "checkpoints"
/ f"checkpoint_{experiment_config.n_iters}.pt"
/ f"checkpoint_{experiment.total_frames}.pt"
)
# The experiment will be saved in the same folder as the one it is restoring from
experiment_config.save_folder = None
4 changes: 2 additions & 2 deletions examples/checkpointing/reload_experiment.sh
@@ -1,2 +1,2 @@
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.checkpoint_interval=1
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_03.pt"
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt"
@@ -14,7 +14,7 @@ hydra:
seed: 0

experiment:
sampling_device: "cuda"
sampling_device: "cpu"
train_device: "cuda"

share_policy_params: True
@@ -31,10 +31,10 @@ experiment:

exploration_eps_init: 0.8
exploration_eps_end: 0.01
exploration_anneal_frames: null
exploration_anneal_frames: 1_000_000

max_n_iters: null
max_n_frames: 30_000_000
max_n_frames: 20_000_000

on_policy_collected_frames_per_batch: 6000
on_policy_n_envs_per_worker: 10
@@ -43,11 +43,15 @@

off_policy_collected_frames_per_batch: 6000
off_policy_n_envs_per_worker: 10
off_policy_n_optimizer_steps: 1000
off_policy_n_optimizer_steps: 100
off_policy_train_batch_size: 100
off_policy_memory_size: 1_000_000

evaluation: True
render: True
evaluation_interval: 50
evaluation_interval: 300_000
evaluation_episodes: 20

save_folder: null
restore_file: null
checkpoint_interval: 600_000
2 changes: 1 addition & 1 deletion premade_scripts/smacv2/smacv2_run.py
@@ -6,7 +6,7 @@
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path=".", config_name="config")
@hydra.main(version_base=None, config_path="conf", config_name="config")
def hydra_experiment(cfg: DictConfig) -> None:
hydra_choices = HydraConfig.get().runtime.choices
task_name = hydra_choices.task
@@ -44,11 +44,15 @@ experiment:

off_policy_collected_frames_per_batch: 6000
off_policy_n_envs_per_worker: 60
off_policy_n_optimizer_steps: 6000
off_policy_n_optimizer_steps: 1000
off_policy_train_batch_size: 128
off_policy_memory_size: 1_000_000

evaluation: True
render: True
evaluation_interval: 20
evaluation_interval: 120_000
evaluation_episodes: 200

save_folder: null
restore_file: null
checkpoint_interval: 300_000
2 changes: 1 addition & 1 deletion premade_scripts/vmas/vmas_run.py
@@ -6,7 +6,7 @@
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path=".", config_name="config")
@hydra.main(version_base=None, config_path="conf", config_name="config")
def hydra_experiment(cfg: DictConfig) -> None:
hydra_choices = HydraConfig.get().runtime.choices
task_name = hydra_choices.task
3 changes: 2 additions & 1 deletion test/conftest.py
@@ -29,9 +29,10 @@ def experiment_config(tmp_path) -> ExperimentConfig:
experiment_config.evaluation = True
experiment_config.render = True
experiment_config.evaluation_episodes = 2
experiment_config.evaluation_interval = 500
experiment_config.loggers = ["csv"]
experiment_config.create_json = True
experiment_config.checkpoint_interval = 1
experiment_config.checkpoint_interval = 100
return experiment_config


2 changes: 1 addition & 1 deletion test/utils_experiment.py
@@ -28,7 +28,7 @@ def check_experiment_loading(

experiment_config.max_n_iters = max_n_iters + 3
experiment_config.restore_file = (
exp_folder / "checkpoints" / f"checkpoint_{max_n_iters}.pt"
exp_folder / "checkpoints" / f"checkpoint_{experiment.total_frames}.pt"
)
experiment_config.save_folder = None
experiment = Experiment(
