Skip to content

Commit

Permalink
[Example] Collecting with gradient (#77)
Browse files Browse the repository at this point in the history
* collect with grad

* amend

* amend

* amend
  • Loading branch information
matteobettini authored Jun 10, 2024
1 parent 87b4f3e commit a104cad
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 17 deletions.
2 changes: 2 additions & 0 deletions benchmarl/conf/experiment/base_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ buffer_device: "cpu"
share_policy_params: True
# If an algorithm and an env support both continuous and discrete actions, what should be preferred
prefer_continuous_actions: True
# If False collection is done using a collector (under no grad). If True, collection is done with gradients.
collect_with_grad: False

# Discount factor
gamma: 0.9
Expand Down
74 changes: 57 additions & 17 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torchrl.collectors import SyncDataCollector
from torchrl.envs import SerialEnv, TransformedEnv
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.record.loggers import generate_exp_name
from tqdm import tqdm

Expand Down Expand Up @@ -54,6 +54,7 @@ class ExperimentConfig:

share_policy_params: bool = MISSING
prefer_continuous_actions: bool = MISSING
collect_with_grad: bool = MISSING

gamma: float = MISSING
lr: float = MISSING
Expand Down Expand Up @@ -456,17 +457,26 @@ def _setup_collector(self):
assert len(group_policy) == 1
self.group_policies.update({group: group_policy[0]})

self.collector = SyncDataCollector(
self.env_func,
self.policy,
device=self.config.sampling_device,
storing_device=self.config.train_device,
frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
total_frames=self.config.get_max_n_frames(self.on_policy),
init_random_frames=(
self.config.off_policy_init_random_frames if not self.on_policy else 0
),
)
if not self.config.collect_with_grad:
self.collector = SyncDataCollector(
self.env_func,
self.policy,
device=self.config.sampling_device,
storing_device=self.config.train_device,
frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
total_frames=self.config.get_max_n_frames(self.on_policy),
init_random_frames=(
self.config.off_policy_init_random_frames
if not self.on_policy
else 0
),
)
else:
if self.config.off_policy_init_random_frames and not self.on_policy:
raise TypeError(
"Collection via rollouts does not support initial random frames as of now."
)
self.rollout_env = self.env_func().to(self.config.sampling_device)

def _setup_name(self):
self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
Expand Down Expand Up @@ -544,8 +554,31 @@ def _collection_loop(self):
)
sampling_start = time.time()

if not self.config.collect_with_grad:
iterator = iter(self.collector)
else:
reset_batch = self.rollout_env.reset()

# Training/collection iterations
for batch in self.collector:
for _ in range(
self.n_iters_performed, self.config.get_max_n_iters(self.on_policy)
):
if not self.config.collect_with_grad:
batch = next(iterator)
else:
with set_exploration_type(ExplorationType.RANDOM):
batch = self.rollout_env.rollout(
max_steps=-(
-self.config.collected_frames_per_batch(self.on_policy)
// self.rollout_env.batch_size.numel()
),
policy=self.policy,
break_when_any_done=False,
auto_reset=False,
tensordict=reset_batch,
)
reset_batch = step_mdp(batch[..., -1])

# Logging collection
collection_time = time.time() - sampling_start
current_frames = batch.numel()
Expand All @@ -560,6 +593,7 @@ def _collection_loop(self):

# Callback
self._on_batch_collected(batch)
batch = batch.detach()

# Loop over groups
training_start = time.time()
Expand Down Expand Up @@ -593,7 +627,8 @@ def _collection_loop(self):
explore_layer.step(current_frames)

# Update policy in collector
self.collector.update_policy_weights_()
if not self.config.collect_with_grad:
self.collector.update_policy_weights_()

# Timers
training_time = time.time() - training_start
Expand Down Expand Up @@ -635,7 +670,10 @@ def _collection_loop(self):

def close(self):
"""Close the experiment."""
self.collector.shutdown()
if not self.config.collect_with_grad:
self.collector.shutdown()
else:
self.rollout_env.close()
self.test_env.close()
self.logger.finish()

Expand Down Expand Up @@ -766,13 +804,14 @@ def state_dict(self) -> OrderedDict:
)
state_dict = OrderedDict(
state=state,
collector=self.collector.state_dict(),
**{f"loss_{k}": item.state_dict() for k, item in self.losses.items()},
**{
f"buffer_{k}": item.state_dict()
for k, item in self.replay_buffers.items()
},
)
if not self.config.collect_with_grad:
state_dict.update({"collector": self.collector.state_dict()})
return state_dict

def load_state_dict(self, state_dict: Dict) -> None:
Expand All @@ -785,7 +824,8 @@ def load_state_dict(self, state_dict: Dict) -> None:
for group in self.group_map.keys():
self.losses[group].load_state_dict(state_dict[f"loss_{group}"])
self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"])
self.collector.load_state_dict(state_dict["collector"])
if not self.config.collect_with_grad:
self.collector.load_state_dict(state_dict["collector"])
self.total_time = state_dict["state"]["total_time"]
self.total_frames = state_dict["state"]["total_frames"]
self.n_iters_performed = state_dict["state"]["n_iters_performed"]
Expand Down

0 comments on commit a104cad

Please sign in to comment.