From f7caeca2c784222b22d323b3a9721d945cf58cff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 6 Jun 2023 09:44:55 +0800 Subject: [PATCH 01/27] add sil policy --- ding/policy/__init__.py | 1 + ding/policy/command_mode_policy_instance.py | 6 + ding/policy/sil.py | 271 ++++++++++++++++++ ding/rl_utils/__init__.py | 1 + ding/rl_utils/sil.py | 40 +++ .../cartpole/config/cartpole_sil_config.py | 48 ++++ 6 files changed, 367 insertions(+) create mode 100644 ding/policy/sil.py create mode 100644 ding/rl_utils/sil.py create mode 100644 dizoo/classic_control/cartpole/config/cartpole_sil_config.py diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 65f3f2757e..ef53fcc0c5 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -50,6 +50,7 @@ from .pc import ProcedureCloningBFSPolicy from .bcq import BCQPolicy +from .sil import SILPolicy # new-type policy from .ppof import PPOFPolicy diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 8b6123c063..6d953d2287 100755 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -36,6 +36,7 @@ from .sql import SQLPolicy from .bc import BehaviourCloningPolicy from .ibc import IBCPolicy +from .sil import SILPolicy from .dqfd import DQFDPolicy from .r2d3 import R2D3Policy @@ -432,3 +433,8 @@ def _get_setting_learn(self, command_info: dict) -> dict: def _get_setting_eval(self, command_info: dict) -> dict: return {} + + +@POLICY_REGISTRY.register('sil_command') +class SILCommandModePolicy(SILPolicy, DummyCommandModePolicy): + pass diff --git a/ding/policy/sil.py b/ding/policy/sil.py new file mode 100644 index 0000000000..15e2abcdd7 --- /dev/null +++ b/ding/policy/sil.py @@ -0,0 +1,271 @@ +from typing import List, Dict, Any, Tuple, Union +from collections import namedtuple +import torch + +from ding.rl_utils import sil_data, sil_error, get_gae_with_default_last_value, get_train_sample +from ding.torch_utils import Adam, to_device +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY, split_data_generator +from ding.utils.data import default_collate, default_decollate +from .base_policy import Policy +from .common_utils import default_preprocess_learn + + +@POLICY_REGISTRY.register('sil') +class SILPolicy(Policy): + r""" + Overview: + Policy class of SIL algorithm, paper link: https://arxiv.org/abs/1806.05635 + """ + config = dict( + # (string) RL policy register name (refer to function "register_policy"). + type='sil', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool) Whether to use on-policy training pipeline(behaviour policy and training policy are the same) + on_policy=True, + priority=False, + # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + learn=dict( + update_per_collect=1, # fixed value, this line should not be modified by users + batch_size=64, + learning_rate=0.001, + # (List[float]) + betas=(0.9, 0.999), + # (float) + eps=1e-8, + # (float) + grad_norm=0.5, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) loss weight of the value network, the weight of policy network is set to 1 + value_weight=0.5, + # (bool) Whether to normalize advantage. Default to False. 
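# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# The learn-time hyperparameters above (learning_rate, betas, eps, grad_norm)
# feed straight into the optimizer that `_init_learn` builds further down in
# this file. A minimal standalone sketch of that wiring with plain torch
# (the policy itself uses ding's Adam wrapper; the Linear is only a stand-in
# for the actor-critic network):
import torch

net = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8)
loss = net(torch.randn(8, 4)).sum()
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)  # grad_norm=0.5
optimizer.step()
# ---- end of aside ----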
+ adv_norm=False, + ignore_done=False, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + # n_sample=80, + unroll_len=1, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) discount factor for future reward, defaults int [0, 1] + discount_factor=0.9, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + ), + eval=dict(), + ) + + def default_model(self) -> Tuple[str, List[str]]: + return 'vac', ['ding.model.template.vac'] + + def _init_learn(self) -> None: + r""" + Overview: + Learn mode init method. Called by ``self.__init__``. + Init the optimizer, algorithm config, main and target models. + """ + # Optimizer + self._optimizer = Adam( + self._model.parameters(), + lr=self._cfg.learn.learning_rate, + betas=self._cfg.learn.betas, + eps=self._cfg.learn.eps + ) + + # Algorithm config + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + self._value_weight = self._cfg.learn.value_weight + self._entropy_weight = self._cfg.learn.entropy_weight + self._adv_norm = self._cfg.learn.adv_norm + self._grad_norm = self._cfg.learn.grad_norm + + # Main and target models + self._learn_model = model_wrap(self._model, wrapper_name='base') + self._learn_model.reset() + + def _forward_learn(self, data: dict) -> Dict[str, Any]: + r""" + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs','adv'] + Returns: + - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + """ + data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + if self._cuda: + data = to_device(data, self._device) + self._learn_model.train() + + for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True): + # forward + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + + adv = batch['adv'] + return_ = batch['value'] + adv + if self._adv_norm: + # norm adv in total train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + + # Calculate SIL loss + sil_loss = sil_error(error_data) + wv, we = self._value_weight, self._entropy_weight + total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + + # ==================== + # SIL-learning update + # ==================== + + self._optimizer.zero_grad() + total_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self._grad_norm, + ) + self._optimizer.step() + + # ============= + # after update + # ============= + # only record last updates information in logger + return { + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'total_loss': total_loss.item(), + 'policy_loss': sil_loss.policy_loss.item(), + 'value_loss': sil_loss.value_loss.item(), + 'adv_abs_max': adv.abs().max().item(), + 'grad_norm': grad_norm, + } + + def _state_dict_learn(self) -> Dict[str, Any]: + return { + 'model': self._learn_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + self._learn_model.load_state_dict(state_dict['model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def 
_init_collect(self) -> None: + r""" + Overview: + Collect mode init method. Called by ``self.__init__``. + Init traj and unroll length, collect model. + """ + + self._unroll_len = self._cfg.collect.unroll_len + self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample') + self._collect_model.reset() + # Algorithm + self._gamma = self._cfg.collect.discount_factor + self._gae_lambda = self._cfg.collect.gae_lambda + + def _forward_collect(self, data: dict) -> dict: + r""" + Overview: + Forward function of collect mode. + Arguments: + - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ + values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. + ReturnsKeys + - necessary: ``action`` + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._collect_model.eval() + with torch.no_grad(): + output = self._collect_model.forward(data, mode='compute_actor_critic') + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + r""" + Overview: + Generate dict type transition data from inputs. + Arguments: + - obs (:obj:`Any`): Env observation + - model_output (:obj:`dict`): Output of collect model, including at least ['action'] + - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \ + (here 'obs' indicates obs after env step). + Returns: + - transition (:obj:`dict`): Dict type transition data. + """ + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'action': model_output['action'], + 'value': model_output['value'], + 'reward': timestep.reward, + 'done': timestep.done, + } + return transition + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + r""" + Overview: + Get the trajectory and the n step return data, then sample from the n_step return data + Arguments: + - data (:obj:`list`): The trajectory's buffer list + Returns: + - samples (:obj:`dict`): The training samples generated + """ + data = get_gae_with_default_last_value( + data, + data[-1]['done'], + gamma=self._gamma, + gae_lambda=self._gae_lambda, + cuda=self._cuda, + ) + return get_train_sample(data, self._unroll_len) + + def _init_eval(self) -> None: + r""" + Overview: + Evaluate mode init method. Called by ``self.__init__``. + Init eval model with argmax strategy. + """ + self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') + self._eval_model.reset() + + def _forward_eval(self, data: dict) -> dict: + r""" + Overview: + Forward function of eval mode, similar to ``self._forward_collect``. + Arguments: + - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ + values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + Returns: + - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. 
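# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# `_get_train_sample` above delegates advantage computation to ding's GAE
# helper. For readers who want the idea without opening ding.rl_utils, a
# self-contained sketch of generalized advantage estimation (function name and
# signature are hypothetical; expects 1-D float tensors):
import torch

def gae_sketch(rewards, values, next_values, dones, gamma=0.9, lam=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    deltas = rewards + gamma * next_values * (1.0 - dones) - values
    adv = torch.zeros_like(rewards)
    running = torch.zeros(())
    for t in reversed(range(rewards.shape[0])):
        running = deltas[t] + gamma * lam * (1.0 - dones[t]) * running
        adv[t] = running
    return adv
# ---- end of aside ----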
+ ReturnsKeys + - necessary: ``action`` + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._eval_model.eval() + with torch.no_grad(): + output = self._eval_model.forward(data, mode='compute_actor') + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _monitor_vars_learn(self) -> List[str]: + return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'grad_norm'] diff --git a/ding/rl_utils/__init__.py b/ding/rl_utils/__init__.py index 080b37ead5..b2b8dcc08b 100644 --- a/ding/rl_utils/__init__.py +++ b/ding/rl_utils/__init__.py @@ -23,3 +23,4 @@ from .acer import acer_policy_error, acer_value_error, acer_trust_region_update from .sampler import ArgmaxSampler, MultinomialSampler, MuSampler, ReparameterizationSampler, HybridStochasticSampler, \ HybridDeterminsticSampler +from .sil import sil_error, sil_data diff --git a/ding/rl_utils/sil.py b/ding/rl_utils/sil.py new file mode 100644 index 0000000000..cf221ba04a --- /dev/null +++ b/ding/rl_utils/sil.py @@ -0,0 +1,40 @@ +from collections import namedtuple +import torch +import torch.nn.functional as F + +sil_data = namedtuple('sil_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight']) +sil_loss = namedtuple('sil_loss', ['policy_loss', 'value_loss']) + + +def sil_error(data: namedtuple) -> namedtuple: + """ + Overview: + Implementation of SIL(Self-Imitation Learning) (arXiv:1806.05635) + Arguments: + - data (:obj:`namedtuple`): SIL input data with fields shown in ``sil_data`` + Returns: + - sil_loss (:obj:`namedtuple`): the SIL loss item, all of them are the differentiable 0-dim tensor + Shapes: + - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim + - action (:obj:`torch.LongTensor`): :math:`(B, )` + - value (:obj:`torch.FloatTensor`): :math:`(B, )` + - adv (:obj:`torch.FloatTensor`): :math:`(B, )` + - return (:obj:`torch.FloatTensor`): :math:`(B, )` + - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )` + - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor + - value_loss (:obj:`torch.FloatTensor`): :math:`()` + """ + logit, action, value, adv, return_, weight = data + if weight is None: + weight = torch.ones_like(value) + dist = torch.distributions.categorical.Categorical(logits=logit) + logp = dist.log_prob(action) + + # Clip the negative part of adv. + adv = adv.clamp_min(0) + policy_loss = -(logp * adv * weight).mean() + + # Clip the negative part of the distance between value and return. 
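# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# Both clamps in `sil_error` keep only the "better than expected" part of the
# batch: a transition contributes to the SIL policy loss only when its
# advantage is positive, and to the value loss only when the observed return
# exceeds the current value estimate. Toy tensors:
import torch

adv = torch.tensor([0.7, -0.3, 0.0])
adv.clamp_min(0)                              # tensor([0.7000, 0.0000, 0.0000])

return_, value = torch.tensor([1.0, -0.5]), torch.tensor([0.2, 0.3])
(return_ - value).clamp_min(0)                # tensor([0.8000, 0.0000])
# ---- end of aside ----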
+ rv_dist = torch.clamp_min((return_ - value), 0) + value_loss = (F.mse_loss(rv_dist, torch.zeros_like(rv_dist), reduction='none') * weight).mean() + return sil_loss(policy_loss, value_loss) diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py new file mode 100644 index 0000000000..318e5267e3 --- /dev/null +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py @@ -0,0 +1,48 @@ +from easydict import EasyDict + +cartpole_sil_config = dict( + exp_name='cartpole_sil_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + ), + policy=dict( + cuda=False, + model=dict( + obs_shape=4, + action_shape=2, + encoder_hidden_size_list=[128, 128, 64], + ), + learn=dict( + batch_size=40, + learning_rate=0.001, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + n_sample=80, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + ), + eval=dict(evaluator=dict(eval_freq=50, )), + ), +) +cartpole_sil_config = EasyDict(cartpole_sil_config) +main_config = cartpole_sil_config + +cartpole_sil_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='sil'), +) +cartpole_sil_create_config = EasyDict(cartpole_sil_create_config) +create_config = cartpole_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c cartpole_sil_config.py -s 0` + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy((main_config, create_config), seed=0) From 25b94143f074ea23c1b8df31ea896b88824a9298 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 6 Jun 2023 09:47:59 +0800 Subject: [PATCH 02/27] add test file --- ding/rl_utils/tests/test_sil.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 ding/rl_utils/tests/test_sil.py diff --git a/ding/rl_utils/tests/test_sil.py b/ding/rl_utils/tests/test_sil.py new file mode 100644 index 0000000000..4390326459 --- /dev/null +++ b/ding/rl_utils/tests/test_sil.py @@ -0,0 +1,26 @@ +import pytest +import torch +from ding.rl_utils import sil_data, sil_error + +random_weight = torch.rand(4) + 1 +weight_args = [None, random_weight] + + +@pytest.mark.unittest +@pytest.mark.parametrize('weight, ', weight_args) +def test_a2c(weight): + B, N = 4, 32 + logit = torch.randn(B, N).requires_grad_(True) + action = torch.randint(0, N, size=(B, )) + value = torch.randn(B).requires_grad_(True) + adv = torch.rand(B) + return_ = torch.randn(B) * 2 + data = sil_data(logit, action, value, adv, return_, weight) + loss = sil_error(data) + assert all([l.shape == tuple() for l in loss]) + assert logit.grad is None + assert value.grad is None + total_loss = sum(loss) + total_loss.backward() + assert isinstance(logit.grad, torch.Tensor) + assert isinstance(value.grad, torch.Tensor) From 9a7d070c1e60e2954a94a0ca1b28cef108d1bd58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 6 Jun 2023 10:42:40 +0800 Subject: [PATCH 03/27] add sil info --- ding/policy/sil.py | 15 +++++++++------ ding/rl_utils/sil.py | 8 ++++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 15e2abcdd7..e72d90ce62 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ 
-42,6 +42,8 @@ class SILPolicy(Policy): # ============================================================== # (float) loss weight of the value network, the weight of policy network is set to 1 value_weight=0.5, + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + entropy_weight=0.01, # (bool) Whether to normalize advantage. Default to False. adv_norm=False, ignore_done=False, @@ -116,16 +118,16 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) # Calculate SIL loss - sil_loss = sil_error(error_data) - wv, we = self._value_weight, self._entropy_weight - total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + sil_loss, sil_info = sil_error(error_data) + wv = self._value_weight + sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss # ==================== # SIL-learning update # ==================== self._optimizer.zero_grad() - total_loss.backward() + sil_total_loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( list(self._learn_model.parameters()), @@ -139,9 +141,10 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: # only record last updates information in logger return { 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'total_loss': total_loss.item(), + 'total_loss': sil_total_loss.item(), 'policy_loss': sil_loss.policy_loss.item(), 'value_loss': sil_loss.value_loss.item(), + 'policy_clipfrac': sil_info.policy_clipfrac, 'adv_abs_max': adv.abs().max().item(), 'grad_norm': grad_norm, } @@ -268,4 +271,4 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: - return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'grad_norm'] + return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'adv_abs_max', 'grad_norm'] diff --git a/ding/rl_utils/sil.py b/ding/rl_utils/sil.py index cf221ba04a..e3405a2455 100644 --- a/ding/rl_utils/sil.py +++ b/ding/rl_utils/sil.py @@ -4,6 +4,7 @@ sil_data = namedtuple('sil_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight']) sil_loss = namedtuple('sil_loss', ['policy_loss', 'value_loss']) +sil_info = namedtuple('sil_info', ['policy_clipfrac', 'value_clipfrac']) def sil_error(data: namedtuple) -> namedtuple: @@ -31,10 +32,13 @@ def sil_error(data: namedtuple) -> namedtuple: logp = dist.log_prob(action) # Clip the negative part of adv. + policy_clipfrac = adv.lt(0).float().mean().item() adv = adv.clamp_min(0) policy_loss = -(logp * adv * weight).mean() # Clip the negative part of the distance between value and return. 
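# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# The `policy_clipfrac` / `value_clipfrac` diagnostics introduced in this patch
# simply report what fraction of the batch each clamp zeroes out:
import torch

adv = torch.tensor([0.7, -0.3, -0.1, 0.2])
policy_clipfrac = adv.lt(0).float().mean().item()  # 0.5 -> half the batch carries no SIL signal
# Values close to 1.0 in the logs mean almost nothing in the batch beats the
# current value estimate, so the self-imitation terms are nearly inactive.
# ---- end of aside ----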
- rv_dist = torch.clamp_min((return_ - value), 0) + rv_dist = return_ - value + value_clipfrac = rv_dist.lt(0).float().mean().item() + rv_dist = rv_dist.clamp_min(0) value_loss = (F.mse_loss(rv_dist, torch.zeros_like(rv_dist), reduction='none') * weight).mean() - return sil_loss(policy_loss, value_loss) + return sil_loss(policy_loss, value_loss), sil_info(policy_clipfrac, value_clipfrac) From 05461eaa7ec07b52ce7cd5e44313f0641f89e333 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 6 Jun 2023 10:43:01 +0800 Subject: [PATCH 04/27] update test file --- ding/rl_utils/tests/test_a2c.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/rl_utils/tests/test_a2c.py b/ding/rl_utils/tests/test_a2c.py index e2c635a7f2..dc0497b94d 100644 --- a/ding/rl_utils/tests/test_a2c.py +++ b/ding/rl_utils/tests/test_a2c.py @@ -18,7 +18,7 @@ def test_a2c(weight): adv = torch.rand(B) return_ = torch.randn(B) * 2 data = a2c_data(logit, action, value, adv, return_, weight) - loss = a2c_error(data) + loss, info = a2c_error(data) assert all([l.shape == tuple() for l in loss]) assert logit.grad is None assert value.grad is None From 8856d6293869976077bd6461baaba22689902be3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 6 Jun 2023 10:46:12 +0800 Subject: [PATCH 05/27] add monitor vars --- ding/policy/sil.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index e72d90ce62..b2259186e7 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -271,4 +271,5 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: - return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'adv_abs_max', 'grad_norm'] + return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'adv_abs_max', 'grad_norm', + 'policy_clipfrac', 'value_clipfrac'] From fa6c8dd2e3f50b73ce401a2eae54745552509704 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 9 Jun 2023 17:17:47 +0800 Subject: [PATCH 06/27] update sil policy --- ding/policy/sil.py | 65 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index b2259186e7..49fefea651 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -2,7 +2,7 @@ from collections import namedtuple import torch -from ding.rl_utils import sil_data, sil_error, get_gae_with_default_last_value, get_train_sample +from ding.rl_utils import sil_data, sil_error, get_gae_with_default_last_value, get_train_sample, a2c_data, a2c_error from ding.torch_utils import Adam, to_device from ding.model import model_wrap from ding.utils import POLICY_REGISTRY, split_data_generator @@ -27,6 +27,8 @@ class SILPolicy(Policy): priority=False, # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, + # (int) Number of epochs to use SIL loss to update the policy. 
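# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# `sil_update_per_collect` sets how many extra self-imitation passes run after
# the regular on-policy update for each collection (over the same collected
# batch at this point in the series; later patches feed these passes from a
# replay buffer instead). Schematic only, with hypothetical helper names:
def train_one_collect_step(policy, collected_batch, sil_update_per_collect=1):
    policy.a2c_update(collected_batch)        # ordinary A2C policy/value/entropy step
    for _ in range(sil_update_per_collect):
        policy.sil_update(collected_batch)    # clamped SIL policy/value step
# ---- end of aside ----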
+ sil_update_per_collect=1, learn=dict( update_per_collect=1, # fixed value, this line should not be modified by users batch_size=64, @@ -115,19 +117,19 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: if self._adv_norm: # norm adv in total train_batch adv = (adv - adv.mean()) / (adv.std() + 1e-8) - error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + error_data = a2c_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) - # Calculate SIL loss - sil_loss, sil_info = sil_error(error_data) - wv = self._value_weight - sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + # Calculate A2C loss + a2c_loss = a2c_error(error_data) + wv, we = self._value_weight, self._entropy_weight + a2c_total_loss = a2c_loss.policy_loss + wv * a2c_loss.value_loss - we * a2c_loss.entropy_loss # ==================== - # SIL-learning update + # A2C-learning update # ==================== self._optimizer.zero_grad() - sil_total_loss.backward() + a2c_total_loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( list(self._learn_model.parameters()), @@ -135,16 +137,51 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: ) self._optimizer.step() + for _ in range(self._cfg.sil_update_per_collect): + for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True): + # forward + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + + adv = batch['adv'] + return_ = batch['value'] + adv + if self._adv_norm: + # norm adv in total train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + + # Calculate SIL loss + sil_loss, sil_info = sil_error(error_data) + wv = self._value_weight + sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + + # ==================== + # SIL-learning update + # ==================== + + self._optimizer.zero_grad() + sil_total_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self._grad_norm, + ) + self._optimizer.step() + # ============= # after update # ============= # only record last updates information in logger return { 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'total_loss': sil_total_loss.item(), - 'policy_loss': sil_loss.policy_loss.item(), - 'value_loss': sil_loss.value_loss.item(), + 'sil_total_loss': sil_total_loss.item(), + 'a2c_total_loss': a2c_total_loss.item(), + 'sil_policy_loss': sil_loss.policy_loss.item(), + 'a2c_policy_loss': a2c_loss.policy_loss.item(), + 'sil_value_loss': sil_loss.value_loss.item(), + 'a2c_value_loss': a2c_loss.value_loss.item(), + 'a2c_entropy_loss': a2c_loss.entropy_loss.item(), 'policy_clipfrac': sil_info.policy_clipfrac, + 'value_clipfrac': sil_info.value_clipfrac, 'adv_abs_max': adv.abs().max().item(), 'grad_norm': grad_norm, } @@ -271,5 +308,7 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: - return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'adv_abs_max', 'grad_norm', - 'policy_clipfrac', 'value_clipfrac'] + return super()._monitor_vars_learn() + ['a2c_policy_loss', 'sil_policy_loss', 'sil_value_loss', + 'a2c_value_loss', 'a2c_total_loss', 'sil_total_loss', + 'a2c_entropy_loss', 'adv_abs_max', 'grad_norm', 'policy_clipfrac', + 'value_clipfrac'] From bc63a246e536504b50716400e69497ba894c7489 Mon Sep 
17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 9 Jun 2023 17:35:48 +0800 Subject: [PATCH 07/27] polish formate --- ding/policy/sil.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 49fefea651..c7e447f807 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -173,6 +173,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: # only record last updates information in logger return { 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'total_loss': sil_total_loss.item() + a2c_total_loss.item(), 'sil_total_loss': sil_total_loss.item(), 'a2c_total_loss': a2c_total_loss.item(), 'sil_policy_loss': sil_loss.policy_loss.item(), @@ -308,7 +309,7 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: - return super()._monitor_vars_learn() + ['a2c_policy_loss', 'sil_policy_loss', 'sil_value_loss', - 'a2c_value_loss', 'a2c_total_loss', 'sil_total_loss', - 'a2c_entropy_loss', 'adv_abs_max', 'grad_norm', 'policy_clipfrac', - 'value_clipfrac'] + return super()._monitor_vars_learn() + [ + 'a2c_policy_loss', 'sil_policy_loss', 'sil_value_loss', 'a2c_value_loss', 'a2c_total_loss', + 'sil_total_loss', 'a2c_entropy_loss', 'adv_abs_max', 'grad_norm', 'policy_clipfrac', 'value_clipfrac' + ] From e1457a981a44cff728fdd09de3b63f430999ef11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sat, 10 Jun 2023 14:31:54 +0800 Subject: [PATCH 08/27] update test file --- ding/rl_utils/tests/test_a2c.py | 2 +- ding/rl_utils/tests/test_sil.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ding/rl_utils/tests/test_a2c.py b/ding/rl_utils/tests/test_a2c.py index dc0497b94d..e2c635a7f2 100644 --- a/ding/rl_utils/tests/test_a2c.py +++ b/ding/rl_utils/tests/test_a2c.py @@ -18,7 +18,7 @@ def test_a2c(weight): adv = torch.rand(B) return_ = torch.randn(B) * 2 data = a2c_data(logit, action, value, adv, return_, weight) - loss, info = a2c_error(data) + loss = a2c_error(data) assert all([l.shape == tuple() for l in loss]) assert logit.grad is None assert value.grad is None diff --git a/ding/rl_utils/tests/test_sil.py b/ding/rl_utils/tests/test_sil.py index 4390326459..6e6afdb3ef 100644 --- a/ding/rl_utils/tests/test_sil.py +++ b/ding/rl_utils/tests/test_sil.py @@ -16,7 +16,7 @@ def test_a2c(weight): adv = torch.rand(B) return_ = torch.randn(B) * 2 data = sil_data(logit, action, value, adv, return_, weight) - loss = sil_error(data) + loss, info = sil_error(data) assert all([l.shape == tuple() for l in loss]) assert logit.grad is None assert value.grad is None From 50c60ebc63c9251712ace5cac8a34169ecfdc600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sat, 10 Jun 2023 14:33:05 +0800 Subject: [PATCH 09/27] update test file --- ding/rl_utils/tests/test_sil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/rl_utils/tests/test_sil.py b/ding/rl_utils/tests/test_sil.py index 6e6afdb3ef..32af5006b4 100644 --- a/ding/rl_utils/tests/test_sil.py +++ b/ding/rl_utils/tests/test_sil.py @@ -8,7 +8,7 @@ @pytest.mark.unittest @pytest.mark.parametrize('weight, ', weight_args) -def test_a2c(weight): +def test_sil(weight): B, N = 4, 32 logit = torch.randn(B, N).requires_grad_(True) action = torch.randint(0, N, size=(B, )) From 23823e1c58d08b9b15f5ed163bbdf473f55c41a4 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sat, 10 Jun 2023 14:34:08 +0800 Subject: [PATCH 10/27] polish config --- dizoo/classic_control/cartpole/config/cartpole_sil_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py index 318e5267e3..b6eca71187 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py @@ -10,6 +10,7 @@ ), policy=dict( cuda=False, + sil_update_per_collect=1, model=dict( obs_shape=4, action_shape=2, From 4015c1ff1821e348c912a061ff182ba780d853f0 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 15 Jun 2023 16:02:47 +0800 Subject: [PATCH 11/27] update offline --- ding/entry/__init__.py | 1 + ding/entry/serial_entry_sil.py | 140 ++++++++++++++++++ ding/policy/__init__.py | 2 +- ding/policy/sil.py | 56 +++---- .../cartpole/config/cartpole_sil_config.py | 4 +- 5 files changed, 173 insertions(+), 30 deletions(-) create mode 100644 ding/entry/serial_entry_sil.py diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index e0501b12db..b49b087139 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -28,3 +28,4 @@ from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream from .serial_entry_bco import serial_pipeline_bco from .serial_entry_pc import serial_pipeline_pc +from .serial_entry_sil import serial_pipeline_sil diff --git a/ding/entry/serial_entry_sil.py b/ding/entry/serial_entry_sil.py new file mode 100644 index 0000000000..7e99bcfb86 --- /dev/null +++ b/ding/entry/serial_entry_sil.py @@ -0,0 +1,140 @@ +from typing import Union, Optional, List, Any, Tuple +import os +import torch +from ditk import logging +from functools import partial +from tensorboardX import SummaryWriter +from copy import deepcopy + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \ + create_serial_collector, create_serial_evaluator +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed +from .utils import random_collect + + +def serial_pipeline_sil( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + dynamic_seed: Optional[bool] = True, +) -> 'Policy': # noqa + """ + Overview: + Serial pipeline entry for off-policy RL. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + - dynamic_seed(:obj:`Optional[bool]`): set dynamic seed for collector. + Returns: + - policy (:obj:`Policy`): Converged policy. 
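# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# Typical invocation of this entry, mirroring the __main__ block of the
# cartpole config added earlier in this series (module path shown as of its
# later rename to cartpole_sil_a2c_config):
from ding.entry import serial_pipeline_sil
from dizoo.classic_control.cartpole.config.cartpole_sil_a2c_config import main_config, create_config

serial_pipeline_sil((main_config, create_config), seed=0)
# ---- end of aside ----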
+ """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + create_cfg.policy.type = create_cfg.policy.type + '_command' + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + if env_setting is None: + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, evaluator_env_cfg = env_setting + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = create_serial_collector( + cfg.policy.collect.collector, + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name + ) + evaluator = create_serial_evaluator( + cfg.policy.eval.evaluator, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name + ) + replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name) + commander = BaseSerialCommander( + cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode + ) + # ========== + # Main loop + # ========== + # Learner's before_run hook. + learner.call_hook('before_run') + + # Accumulate plenty of data at the beginning of training. + if cfg.policy.get('random_collect_size', 0) > 0: + random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer) + while True: + collect_kwargs = commander.step() + # Evaluate policy performance + if evaluator.should_eval(learner.train_iter): + stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + # Collect data by default config n_sample/n_episode + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) + + train_data = {'new_data': new_data, 'replay_data': []} + + # Learn policy from collected data + for i in range(cfg.policy.learn.sil_update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) + if train_data is None: + # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times + logging.warning( + "Replay buffer's data can only train for {} steps. ".format(i) + + "You can modify data collect config, e.g. increasing n_sample, n_episode." 
+ ) + break + train_data['replay_data'].append(train_data) + learner.train(train_data, collector.envstep) + if learner.policy.get_attribute('priority'): + replay_buffer.update(learner.priority_info) + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + # Learner's after_run hook. + learner.call_hook('after_run') + import time + import pickle + import numpy as np + with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f: + eval_value_raw = [d['eval_episode_return'] for d in eval_info] + final_data = { + 'stop': stop, + 'env_step': collector.envstep, + 'train_iter': learner.train_iter, + 'eval_value': np.mean(eval_value_raw), + 'eval_value_raw': eval_value_raw, + 'finish_time': time.ctime(), + } + pickle.dump(final_data, f) + return policy diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index ef53fcc0c5..41de15422c 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -50,7 +50,7 @@ from .pc import ProcedureCloningBFSPolicy from .bcq import BCQPolicy -from .sil import SILPolicy +from .sil import SILA2CPolicy # new-type policy from .ppof import PPOFPolicy diff --git a/ding/policy/sil.py b/ding/policy/sil.py index c7e447f807..aea939682c 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -11,11 +11,11 @@ from .common_utils import default_preprocess_learn -@POLICY_REGISTRY.register('sil') -class SILPolicy(Policy): +@POLICY_REGISTRY.register('sil_a2c') +class SILA2CPolicy(Policy): r""" Overview: - Policy class of SIL algorithm, paper link: https://arxiv.org/abs/1806.05635 + Policy class of SIL algorithm combined with A2C, paper link: https://arxiv.org/abs/1806.05635 """ config = dict( # (string) RL policy register name (refer to function "register_policy"). @@ -103,6 +103,9 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: Returns: - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
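# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# With this entry, `learner.train` no longer receives a flat list of
# transitions. It receives a dict holding the freshly collected on-policy data
# plus the batches sampled from the replay buffer for the SIL passes, roughly:
#
#     data = {
#         'new_data':    [transition, ...],          # this collect step's samples
#         'replay_data': [[transition, ...], ...],   # one sampled list per SIL pass
#     }
#
# `SILA2CPolicy._forward_learn` below unpacks exactly these two keys.
# ---- end of aside ----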
""" + data_sil = data['replay_data'] + data_sil = default_preprocess_learn(data_sil, ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + data = data['new_data'] data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False) if self._cuda: data = to_device(data, self._device) @@ -137,35 +140,34 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: ) self._optimizer.step() - for _ in range(self._cfg.sil_update_per_collect): - for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True): - # forward - output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + for batch in data_sil: + # forward + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') - adv = batch['adv'] - return_ = batch['value'] + adv - if self._adv_norm: - # norm adv in total train_batch - adv = (adv - adv.mean()) / (adv.std() + 1e-8) - error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + adv = batch['adv'] + return_ = batch['value'] + adv + if self._adv_norm: + # norm adv in total train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) - # Calculate SIL loss - sil_loss, sil_info = sil_error(error_data) - wv = self._value_weight - sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + # Calculate SIL loss + sil_loss, sil_info = sil_error(error_data) + wv = self._value_weight + sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss - # ==================== - # SIL-learning update - # ==================== + # ==================== + # SIL-learning update + # ==================== - self._optimizer.zero_grad() - sil_total_loss.backward() + self._optimizer.zero_grad() + sil_total_loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - list(self._learn_model.parameters()), - max_norm=self._grad_norm, - ) - self._optimizer.step() + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self._grad_norm, + ) + self._optimizer.step() # ============= # after update diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py index b6eca71187..a1b866c508 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py @@ -1,7 +1,7 @@ from easydict import EasyDict cartpole_sil_config = dict( - exp_name='cartpole_sil_seed0', + exp_name='cartpole_sil_a2c_seed0', env=dict( collector_env_num=8, evaluator_env_num=5, @@ -38,7 +38,7 @@ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], ), env_manager=dict(type='base'), - policy=dict(type='sil'), + policy=dict(type='sil_a2c'), ) cartpole_sil_create_config = EasyDict(cartpole_sil_create_config) create_config = cartpole_sil_create_config From 45813c598de7e2feaf275b2f9c7a1f3fd1047f83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 15 Jun 2023 16:14:12 +0800 Subject: [PATCH 12/27] debug --- ding/entry/serial_entry_sil.py | 2 +- ding/policy/command_mode_policy_instance.py | 6 +++--- .../classic_control/cartpole/config/cartpole_sil_config.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ding/entry/serial_entry_sil.py b/ding/entry/serial_entry_sil.py index 7e99bcfb86..a197deb4c3 100644 --- a/ding/entry/serial_entry_sil.py 
+++ b/ding/entry/serial_entry_sil.py @@ -104,7 +104,7 @@ def serial_pipeline_sil( train_data = {'new_data': new_data, 'replay_data': []} # Learn policy from collected data - for i in range(cfg.policy.learn.sil_update_per_collect): + for i in range(cfg.policy.sil_update_per_collect): # Learner will train ``update_per_collect`` times in one iteration. train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) if train_data is None: diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 6d953d2287..483f91d678 100755 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -36,7 +36,7 @@ from .sql import SQLPolicy from .bc import BehaviourCloningPolicy from .ibc import IBCPolicy -from .sil import SILPolicy +from .sil import SILA2CPolicy from .dqfd import DQFDPolicy from .r2d3 import R2D3Policy @@ -435,6 +435,6 @@ def _get_setting_eval(self, command_info: dict) -> dict: return {} -@POLICY_REGISTRY.register('sil_command') -class SILCommandModePolicy(SILPolicy, DummyCommandModePolicy): +@POLICY_REGISTRY.register('sil_a2c_command') +class SILA2CCommandModePolicy(SILA2CPolicy, DummyCommandModePolicy): pass diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py index a1b866c508..726ba26827 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_config.py @@ -45,5 +45,5 @@ if __name__ == "__main__": # or you can enter `ding -m serial_onpolicy -c cartpole_sil_config.py -s 0` - from ding.entry import serial_pipeline_onpolicy - serial_pipeline_onpolicy((main_config, create_config), seed=0) + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) From c3030ce9b22ea3c57bf181e518beb4d1cb9bdd08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 15 Jun 2023 16:20:04 +0800 Subject: [PATCH 13/27] debug --- ding/entry/serial_entry_sil.py | 6 +++--- .../{cartpole_sil_config.py => cartpole_sil_a2c_config.py} | 0 2 files changed, 3 insertions(+), 3 deletions(-) rename dizoo/classic_control/cartpole/config/{cartpole_sil_config.py => cartpole_sil_a2c_config.py} (100%) diff --git a/ding/entry/serial_entry_sil.py b/ding/entry/serial_entry_sil.py index a197deb4c3..f353250f4c 100644 --- a/ding/entry/serial_entry_sil.py +++ b/ding/entry/serial_entry_sil.py @@ -101,7 +101,7 @@ def serial_pipeline_sil( new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) - train_data = {'new_data': new_data, 'replay_data': []} + tot_train_data = {'new_data': new_data, 'replay_data': []} # Learn policy from collected data for i in range(cfg.policy.sil_update_per_collect): @@ -114,8 +114,8 @@ def serial_pipeline_sil( "You can modify data collect config, e.g. increasing n_sample, n_episode." 
) break - train_data['replay_data'].append(train_data) - learner.train(train_data, collector.envstep) + tot_train_data['replay_data'].append(train_data) + learner.train(tot_train_data, collector.envstep) if learner.policy.get_attribute('priority'): replay_buffer.update(learner.priority_info) if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py similarity index 100% rename from dizoo/classic_control/cartpole/config/cartpole_sil_config.py rename to dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py From 8ba7da6ee704310c824677641be293833437e37e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 15 Jun 2023 16:37:50 +0800 Subject: [PATCH 14/27] debug --- ding/policy/sil.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index aea939682c..58e7d423f8 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -19,7 +19,7 @@ class SILA2CPolicy(Policy): """ config = dict( # (string) RL policy register name (refer to function "register_policy"). - type='sil', + type='sil_a2c', # (bool) Whether to use cuda for network. cuda=False, # (bool) Whether to use on-policy training pipeline(behaviour policy and training policy are the same) @@ -104,14 +104,20 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. """ data_sil = data['replay_data'] - data_sil = default_preprocess_learn(data_sil, ignore_done=self._cfg.learn.ignore_done, use_nstep=False) - data = data['new_data'] - data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + data_sil = [ + default_preprocess_learn(data_sil[i], ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + for i in range(len(data_sil)) + ] + data_onpolicy = data['new_data'] + data_onpolicy = default_preprocess_learn( + data_onpolicy, ignore_done=self._cfg.learn.ignore_done, use_nstep=False + ) if self._cuda: - data = to_device(data, self._device) + data_onpolicy = to_device(data_onpolicy, self._device) + data_sil = to_device(data_sil, self._device) self._learn_model.train() - for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True): + for batch in split_data_generator(data_onpolicy, self._cfg.learn.batch_size, shuffle=True): # forward output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') From b0cf8b67d1ff5c64259148c814dcf475f8bc2b82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 15 Jun 2023 17:09:49 +0800 Subject: [PATCH 15/27] debug --- ding/entry/serial_entry_sil.py | 3 -- ding/policy/sil.py | 1 + .../config/lunarlander_sil_a2c_config.py | 51 +++++++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py diff --git a/ding/entry/serial_entry_sil.py b/ding/entry/serial_entry_sil.py index f353250f4c..c50a3528a2 100644 --- a/ding/entry/serial_entry_sil.py +++ b/ding/entry/serial_entry_sil.py @@ -87,9 +87,6 @@ def serial_pipeline_sil( # Learner's before_run hook. learner.call_hook('before_run') - # Accumulate plenty of data at the beginning of training. 
- if cfg.policy.get('random_collect_size', 0) > 0: - random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer) while True: collect_kwargs = commander.step() # Evaluate policy performance diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 58e7d423f8..239e4e066b 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -117,6 +117,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: data_sil = to_device(data_sil, self._device) self._learn_model.train() + data_onpolicy = {data_onpolicy[k] for k in ['obs', 'adv', 'value', 'action', 'weight']} for batch in split_data_generator(data_onpolicy, self._cfg.learn.batch_size, shuffle=True): # forward output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') diff --git a/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py b/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py new file mode 100644 index 0000000000..2da0c78069 --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py @@ -0,0 +1,51 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +lunarlander_sil_config = dict( + exp_name='lunarlander_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + env_id='LunarLander-v2', + n_evaluator_episode=evaluator_env_num, + stop_value=200, + ), + policy=dict( + cuda=False, + sil_update_per_collect=1, + model=dict( + obs_shape=8, + action_shape=4, + ), + learn=dict( + batch_size=160, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +lunarlander_sil_config = EasyDict(lunarlander_sil_config) +main_config = lunarlander_sil_config + +lunarlander_sil_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +lunarlander_sil_create_config = EasyDict(lunarlander_sil_create_config) +create_config = lunarlander_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c lunarlander_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) From 0820cabfabf95bf0f5bb11ad4667feb87d41d67e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 15 Jun 2023 18:57:36 +0800 Subject: [PATCH 16/27] add minigrid --- ding/policy/sil.py | 7 ++- .../config/minigrid_sil_a2c_config.py | 57 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 dizoo/minigrid/config/minigrid_sil_a2c_config.py diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 239e4e066b..8e9ac5abd1 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -103,21 +103,26 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: Returns: - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. 
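# ---- Illustrative aside (editor's sketch, not part of this patch) ----
# The hunk just below also replaces an earlier key-filtering line written as a
# set comprehension (which yields a `set` of values, not a `dict`) with a
# per-sample dict comprehension. The difference in isolation:
sample = {'obs': 1, 'adv': 2, 'value': 3, 'action': 4, 'done': 5, 'logit': 6}
keys = ['obs', 'adv', 'value', 'action', 'done']
wrong = {sample[k] for k in keys}      # {1, 2, 3, 4, 5} -- just a set of values
right = {k: sample[k] for k in keys}   # dict restricted to the selected keys
# ---- end of aside ----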
""" + # Extract off-policy data data_sil = data['replay_data'] data_sil = [ default_preprocess_learn(data_sil[i], ignore_done=self._cfg.learn.ignore_done, use_nstep=False) for i in range(len(data_sil)) ] + # Extract on-policy data data_onpolicy = data['new_data'] + for i in range(len(data_onpolicy)): + data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'adv', 'value', 'action', 'done']} data_onpolicy = default_preprocess_learn( data_onpolicy, ignore_done=self._cfg.learn.ignore_done, use_nstep=False ) + data_onpolicy['weight'] = None + # Put data to correct device. if self._cuda: data_onpolicy = to_device(data_onpolicy, self._device) data_sil = to_device(data_sil, self._device) self._learn_model.train() - data_onpolicy = {data_onpolicy[k] for k in ['obs', 'adv', 'value', 'action', 'weight']} for batch in split_data_generator(data_onpolicy, self._cfg.learn.batch_size, shuffle=True): # forward output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') diff --git a/dizoo/minigrid/config/minigrid_sil_a2c_config.py b/dizoo/minigrid/config/minigrid_sil_a2c_config.py new file mode 100644 index 0000000000..c6427bb564 --- /dev/null +++ b/dizoo/minigrid/config/minigrid_sil_a2c_config.py @@ -0,0 +1,57 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +minigrid_sil_config = dict( + exp_name='minigrid_sil_a2c_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + # typical MiniGrid env id: + # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, + # please refer to https://github.com/Farama-Foundation/MiniGrid for details. + env_id='MiniGrid-Empty-8x8-v0', + n_evaluator_episode=5, + max_step=300, + stop_value=0.96, + ), + policy=dict( + cuda=False, + sil_update_per_collect=1, + model=dict( + obs_shape=2835, + action_shape=7, + encoder_hidden_size_list=[256, 128, 64, 64], + ), + learn=dict( + batch_size=64, + learning_rate=0.0003, + value_weight=0.5, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=128, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +minigrid_sil_config = EasyDict(minigrid_sil_config) +main_config = minigrid_sil_config + +minigrid_sil_create_config = dict( + env=dict( + type='minigrid', + import_names=['dizoo.minigrid.envs.minigrid_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +minigrid_sil_create_config = EasyDict(minigrid_sil_create_config) +create_config = minigrid_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c minigrid_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) From e27bdc586b405039f6d0adc61f344a2c4f385a68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 16 Jun 2023 14:42:51 +0800 Subject: [PATCH 17/27] update ppo+sil --- ding/policy/sil.py | 238 +++++++++++++++++- .../config/lunarlander_sil_a2c_config.py | 6 +- .../config/cartpole_sil_a2c_config.py | 2 +- .../config/minigrid_sil_a2c_config.py | 4 +- 4 files changed, 241 insertions(+), 9 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 8e9ac5abd1..9f88220a1b 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -1,13 +1,17 @@ from typing import List, Dict, Any, Tuple, Union from collections import namedtuple import torch +import copy +import numpy as np -from ding.rl_utils import sil_data, sil_error, 
get_gae_with_default_last_value, get_train_sample, a2c_data, a2c_error -from ding.torch_utils import Adam, to_device +from ding.rl_utils import sil_data, sil_error, a2c_data, a2c_error, ppo_data, ppo_error, ppo_policy_error,\ + ppo_policy_data, get_gae_with_default_last_value, get_train_sample, gae, gae_data, ppo_error_continuous, get_gae +from ding.torch_utils import Adam, to_device, to_dtype, unsqueeze from ding.model import model_wrap -from ding.utils import POLICY_REGISTRY, split_data_generator +from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd from ding.utils.data import default_collate, default_decollate from .base_policy import Policy +from .ppo import PPOPolicy from .common_utils import default_preprocess_learn @@ -327,3 +331,231 @@ def _monitor_vars_learn(self) -> List[str]: 'a2c_policy_loss', 'sil_policy_loss', 'sil_value_loss', 'a2c_value_loss', 'a2c_total_loss', 'sil_total_loss', 'a2c_entropy_loss', 'adv_abs_max', 'grad_norm', 'policy_clipfrac', 'value_clipfrac' ] + + +@POLICY_REGISTRY.register('sil_ppo') +class SILPPOPolicy(PPOPolicy): + r""" + Overview: + Policy class of SIL algorithm combined with PPO, paper link: https://arxiv.org/abs/1806.05635 + """ + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='sil_ppo', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used) + on_policy=True, + # (bool) Whether to use priority(priority sample, IS weight, update priority) + priority=False, + # (bool) Whether to use Importance Sampling Weight to correct biased update due to priority. + # If True, priority must be True. + priority_IS_weight=False, + # (bool) Whether to recompurete advantages in each iteration of on-policy PPO + recompute_adv=True, + # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid'] + action_space='discrete', + # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value + nstep_return=False, + # (bool) Whether to enable multi-agent training, i.e.: MAPPO + multi_agent=False, + # (bool) Whether to need policy data in process transition + transition_with_policy_data=True, + # (int) Number of epochs to use SIL loss to update the policy. + sil_update_per_collect=1, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.0, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=True, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=0.5, + ignore_done=False, + ), + collect=dict( + # (int) Only one of [n_sample, n_episode] shoule be set + # n_sample=64, + # (int) Cut trajectories into pieces with length "unroll_len". + unroll_len=1, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) Reward's future discount factor, aka. gamma. 
+ discount_factor=0.99, + # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc) + gae_lambda=0.95, + ), + eval=dict(), + ) + + def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: + r""" + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): Dict type data + Returns: + - info_dict (:obj:`Dict[str, Any]`): + Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \ + adv_abs_max, approx_kl, clipfrac + """ + # Extract off-policy data + data_sil = data['replay_data'] + data_sil = [ + default_preprocess_learn(data_sil[i], ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + for i in range(len(data_sil)) + ] + # Extract on-policy data + data_onpolicy = data['new_data'] + for i in range(len(data_onpolicy)): + data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'adv', 'value', 'action', 'done']} + data_onpolicy = default_preprocess_learn( + data_onpolicy, ignore_done=self._cfg.learn.ignore_done, use_nstep=False + ) + data_onpolicy['weight'] = None + # Put data to correct device. + if self._cuda: + data_onpolicy = to_device(data_onpolicy, self._device) + data_sil = to_device(data_sil, self._device) + self._learn_model.train() + data_onpolicy['obs'] = to_dtype(data_onpolicy['obs'], torch.float32) + if 'next_obs' in data_onpolicy: + data_onpolicy['next_obs'] = to_dtype(data_onpolicy['next_obs'], torch.float32) + data_sil['obs'] = to_dtype(data_sil['obs'], torch.float32) + if 'next_obs' in data_sil: + data_sil['next_obs'] = to_dtype(data_sil['next_obs'], torch.float32) + # ==================== + # PPO forward + # ==================== + return_infos = [] + self._learn_model.train() + + for epoch in range(self._cfg.learn.epoch_per_collect): + if self._recompute_adv: # calculate new value using the new updated value network + with torch.no_grad(): + value = self._learn_model.forward(data['obs'], mode='compute_critic')['value'] + next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value'] + if self._value_norm: + value *= self._running_mean_std.std + next_value *= self._running_mean_std.std + + traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory + compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], traj_flag) + data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda) + + unnormalized_returns = value + data['adv'] + + if self._value_norm: + data['value'] = value / self._running_mean_std.std + data['return'] = unnormalized_returns / self._running_mean_std.std + self._running_mean_std.update(unnormalized_returns.cpu().numpy()) + else: + data['value'] = value + data['return'] = unnormalized_returns + + else: # don't recompute adv + if self._value_norm: + unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std + data['return'] = unnormalized_return / self._running_mean_std.std + self._running_mean_std.update(unnormalized_return.cpu().numpy()) + else: + data['return'] = data['adv'] + data['value'] + + for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True): + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + adv = batch['adv'] + if self._adv_norm: + # Normalize advantage in a train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + # Calculate ppo error + ppo_batch = ppo_data( + output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, + batch['return'], batch['weight'] + ) + ppo_loss, 
ppo_info = ppo_error(ppo_batch, self._clip_ratio) + wv, we = self._value_weight, self._entropy_weight + total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + + self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() + + return_info = { + 'cur_lr': self._optimizer.defaults['lr'], + 'ppo_total_loss': total_loss.item(), + 'policy_loss': ppo_loss.policy_loss.item(), + 'value_loss': ppo_loss.value_loss.item(), + 'entropy_loss': ppo_loss.entropy_loss.item(), + 'adv_max': adv.max().item(), + 'adv_mean': adv.mean().item(), + 'value_mean': output['value'].mean().item(), + 'value_max': output['value'].max().item(), + 'approx_kl': ppo_info.approx_kl, + 'clipfrac': ppo_info.clipfrac, + } + return_infos.append(return_info) + + return_info_real = {k: sum([return_infos[i][k] for i in range(len(return_infos))]) + / len([return_infos[i][k] for i in range(len(return_infos))]) + for k in return_infos[0].keys()} + + for batch in data_sil: + # forward + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + + adv = batch['adv'] + return_ = batch['value'] + adv + if self._adv_norm: + # norm adv in total train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + + # Calculate SIL loss + sil_loss, sil_info = sil_error(error_data) + wv = self._value_weight + sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + + # ==================== + # SIL-learning update + # ==================== + + self._optimizer.zero_grad() + sil_total_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self._grad_norm, + ) + self._optimizer.step() + + sil_learn_info = { + 'total_loss': sil_total_loss.item() + return_info_real['ppo_total_loss'], + 'sil_total_loss': sil_total_loss.item(), + 'sil_policy_loss': sil_loss.policy_loss.item(), + 'sil_value_loss': sil_loss.value_loss.item(), + 'policy_clipfrac': sil_info.policy_clipfrac, + 'value_clipfrac': sil_info.value_clipfrac + } + + return_info_real.update(sil_learn_info) + return return_info_real + + def _monitor_vars_learn(self) -> List[str]: + variables = list(set(super()._monitor_vars_learn() + [ + 'sil_policy_loss', 'sil_value_loss', 'ppo_total_loss', + 'sil_total_loss', 'policy_clipfrac', 'value_clipfrac' + ])) + return variables diff --git a/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py b/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py index 2da0c78069..ebcebdb6be 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py @@ -1,7 +1,7 @@ from easydict import EasyDict -collector_env_num = 8 -evaluator_env_num = 8 +collector_env_num = 4 +evaluator_env_num = 4 lunarlander_sil_config = dict( exp_name='lunarlander_sil_a2c_seed0', env=dict( @@ -13,7 +13,7 @@ ), policy=dict( cuda=False, - sil_update_per_collect=1, + sil_update_per_collect=5, model=dict( obs_shape=8, action_shape=4, diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py index 726ba26827..dc45bdaabd 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py @@ -10,7 +10,7 @@ ), policy=dict( cuda=False, - sil_update_per_collect=1, + sil_update_per_collect=5, model=dict( 
obs_shape=4, action_shape=2, diff --git a/dizoo/minigrid/config/minigrid_sil_a2c_config.py b/dizoo/minigrid/config/minigrid_sil_a2c_config.py index c6427bb564..ed4aa5e6eb 100644 --- a/dizoo/minigrid/config/minigrid_sil_a2c_config.py +++ b/dizoo/minigrid/config/minigrid_sil_a2c_config.py @@ -1,7 +1,7 @@ from easydict import EasyDict -collector_env_num = 8 -evaluator_env_num = 8 +collector_env_num = 4 +evaluator_env_num = 4 minigrid_sil_config = dict( exp_name='minigrid_sil_a2c_seed0', env=dict( From 6025dcdace57915ce3f0ac8b83a4a2705aad69ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 16 Jun 2023 14:46:33 +0800 Subject: [PATCH 18/27] a2c + sil --- ding/policy/sil.py | 161 ++------------------------------------------- 1 file changed, 6 insertions(+), 155 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 9f88220a1b..03526a2500 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -12,11 +12,12 @@ from ding.utils.data import default_collate, default_decollate from .base_policy import Policy from .ppo import PPOPolicy +from .a2c import A2CPolicy from .common_utils import default_preprocess_learn @POLICY_REGISTRY.register('sil_a2c') -class SILA2CPolicy(Policy): +class SILA2CPolicy(A2CPolicy): r""" Overview: Policy class of SIL algorithm combined with A2C, paper link: https://arxiv.org/abs/1806.05635 @@ -69,35 +70,6 @@ class SILA2CPolicy(Policy): eval=dict(), ) - def default_model(self) -> Tuple[str, List[str]]: - return 'vac', ['ding.model.template.vac'] - - def _init_learn(self) -> None: - r""" - Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config, main and target models. - """ - # Optimizer - self._optimizer = Adam( - self._model.parameters(), - lr=self._cfg.learn.learning_rate, - betas=self._cfg.learn.betas, - eps=self._cfg.learn.eps - ) - - # Algorithm config - self._priority = self._cfg.priority - self._priority_IS_weight = self._cfg.priority_IS_weight - self._value_weight = self._cfg.learn.value_weight - self._entropy_weight = self._cfg.learn.entropy_weight - self._adv_norm = self._cfg.learn.adv_norm - self._grad_norm = self._cfg.learn.grad_norm - - # Main and target models - self._learn_model = model_wrap(self._model, wrapper_name='base') - self._learn_model.reset() - def _forward_learn(self, data: dict) -> Dict[str, Any]: r""" Overview: @@ -205,132 +177,11 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'grad_norm': grad_norm, } - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _init_collect(self) -> None: - r""" - Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. - """ - - self._unroll_len = self._cfg.collect.unroll_len - self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample') - self._collect_model.reset() - # Algorithm - self._gamma = self._cfg.collect.discount_factor - self._gae_lambda = self._cfg.collect.gae_lambda - - def _forward_collect(self, data: dict) -> dict: - r""" - Overview: - Forward function of collect mode. 
- Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. - ReturnsKeys - - necessary: ``action`` - """ - data_id = list(data.keys()) - data = default_collate(list(data.values())) - if self._cuda: - data = to_device(data, self._device) - self._collect_model.eval() - with torch.no_grad(): - output = self._collect_model.forward(data, mode='compute_actor_critic') - if self._cuda: - output = to_device(output, 'cpu') - output = default_decollate(output) - return {i: d for i, d in zip(data_id, output)} - - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: - r""" - Overview: - Generate dict type transition data from inputs. - Arguments: - - obs (:obj:`Any`): Env observation - - model_output (:obj:`dict`): Output of collect model, including at least ['action'] - - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \ - (here 'obs' indicates obs after env step). - Returns: - - transition (:obj:`dict`): Dict type transition data. - """ - transition = { - 'obs': obs, - 'next_obs': timestep.obs, - 'action': model_output['action'], - 'value': model_output['value'], - 'reward': timestep.reward, - 'done': timestep.done, - } - return transition - - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - r""" - Overview: - Get the trajectory and the n step return data, then sample from the n_step return data - Arguments: - - data (:obj:`list`): The trajectory's buffer list - Returns: - - samples (:obj:`dict`): The training samples generated - """ - data = get_gae_with_default_last_value( - data, - data[-1]['done'], - gamma=self._gamma, - gae_lambda=self._gae_lambda, - cuda=self._cuda, - ) - return get_train_sample(data, self._unroll_len) - - def _init_eval(self) -> None: - r""" - Overview: - Evaluate mode init method. Called by ``self.__init__``. - Init eval model with argmax strategy. - """ - self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') - self._eval_model.reset() - - def _forward_eval(self, data: dict) -> dict: - r""" - Overview: - Forward function of eval mode, similar to ``self._forward_collect``. - Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. 
- ReturnsKeys - - necessary: ``action`` - """ - data_id = list(data.keys()) - data = default_collate(list(data.values())) - if self._cuda: - data = to_device(data, self._device) - self._eval_model.eval() - with torch.no_grad(): - output = self._eval_model.forward(data, mode='compute_actor') - if self._cuda: - output = to_device(output, 'cpu') - output = default_decollate(output) - return {i: d for i, d in zip(data_id, output)} - def _monitor_vars_learn(self) -> List[str]: - return super()._monitor_vars_learn() + [ - 'a2c_policy_loss', 'sil_policy_loss', 'sil_value_loss', 'a2c_value_loss', 'a2c_total_loss', - 'sil_total_loss', 'a2c_entropy_loss', 'adv_abs_max', 'grad_norm', 'policy_clipfrac', 'value_clipfrac' - ] + return list(set(super()._monitor_vars_learn() + [ + 'sil_policy_loss', 'sil_value_loss', 'a2c_total_loss', + 'sil_total_loss', 'policy_clipfrac', 'value_clipfrac' + ])) @POLICY_REGISTRY.register('sil_ppo') From 39fd3b76913615673cb6fb90b331bdf2e53c4c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 16 Jun 2023 15:03:44 +0800 Subject: [PATCH 19/27] update ppo+sil --- ding/policy/__init__.py | 2 +- ding/policy/command_mode_policy_instance.py | 7 +- ding/policy/sil.py | 32 ++++++---- .../config/lunarlander_ppo_config.py | 53 +++++++++++++++ .../config/lunarlander_sil_ppo_config.py | 54 ++++++++++++++++ .../config/cartpole_sil_ppo_config.py | 57 +++++++++++++++++ .../config/minigrid_sil_ppo_config.py | 64 +++++++++++++++++++ 7 files changed, 256 insertions(+), 13 deletions(-) create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_sil_ppo_config.py create mode 100644 dizoo/classic_control/cartpole/config/cartpole_sil_ppo_config.py create mode 100644 dizoo/minigrid/config/minigrid_sil_ppo_config.py diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 41de15422c..d906c8a021 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -50,7 +50,7 @@ from .pc import ProcedureCloningBFSPolicy from .bcq import BCQPolicy -from .sil import SILA2CPolicy +from .sil import SILA2CPolicy, SILPPOPolicy # new-type policy from .ppof import PPOFPolicy diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 483f91d678..5641fdf3f8 100755 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -36,7 +36,7 @@ from .sql import SQLPolicy from .bc import BehaviourCloningPolicy from .ibc import IBCPolicy -from .sil import SILA2CPolicy +from .sil import SILA2CPolicy, SILPPOPolicy from .dqfd import DQFDPolicy from .r2d3 import R2D3Policy @@ -438,3 +438,8 @@ def _get_setting_eval(self, command_info: dict) -> dict: @POLICY_REGISTRY.register('sil_a2c_command') class SILA2CCommandModePolicy(SILA2CPolicy, DummyCommandModePolicy): pass + + +@POLICY_REGISTRY.register('sil_ppo_command') +class SILPPOCommandModePolicy(SILPPOPolicy, DummyCommandModePolicy): + pass diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 03526a2500..108454f7fb 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -178,10 +178,14 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: } def _monitor_vars_learn(self) -> List[str]: - return list(set(super()._monitor_vars_learn() + [ - 'sil_policy_loss', 'sil_value_loss', 'a2c_total_loss', - 'sil_total_loss', 'policy_clipfrac', 'value_clipfrac' - ])) + return list( + set( + super()._monitor_vars_learn() + [ + 
'sil_policy_loss', 'sil_value_loss', 'a2c_total_loss', 'sil_total_loss', 'policy_clipfrac', + 'value_clipfrac' + ] + ) + ) @POLICY_REGISTRY.register('sil_ppo') @@ -359,9 +363,11 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: } return_infos.append(return_info) - return_info_real = {k: sum([return_infos[i][k] for i in range(len(return_infos))]) - / len([return_infos[i][k] for i in range(len(return_infos))]) - for k in return_infos[0].keys()} + return_info_real = { + k: sum([return_infos[i][k] + for i in range(len(return_infos))]) / len([return_infos[i][k] for i in range(len(return_infos))]) + for k in return_infos[0].keys() + } for batch in data_sil: # forward @@ -405,8 +411,12 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: return return_info_real def _monitor_vars_learn(self) -> List[str]: - variables = list(set(super()._monitor_vars_learn() + [ - 'sil_policy_loss', 'sil_value_loss', 'ppo_total_loss', - 'sil_total_loss', 'policy_clipfrac', 'value_clipfrac' - ])) + variables = list( + set( + super()._monitor_vars_learn() + [ + 'sil_policy_loss', 'sil_value_loss', 'ppo_total_loss', 'sil_total_loss', 'policy_clipfrac', + 'value_clipfrac' + ] + ) + ) return variables diff --git a/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py new file mode 100644 index 0000000000..1668376d1f --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py @@ -0,0 +1,53 @@ +from easydict import EasyDict + +lunarlander_ppo_config = dict( + exp_name='lunarlander_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + env_id='LunarLander-v2', + n_evaluator_episode=8, + stop_value=200, + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=8, + action_shape=4, + ), + learn=dict( + update_per_collect=4, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + nstep=1, + nstep_return=False, + adv_norm=True, + ), + collect=dict( + n_sample=128, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +lunarlander_ppo_config = EasyDict(lunarlander_ppo_config) +main_config = lunarlander_ppo_config +lunarlander_ppo_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo'), +) +lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config) +create_config = lunarlander_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial -c lunarlander_offppo_config.py -s 0` + from ding.entry import serial_pipeline + serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_sil_ppo_config.py b/dizoo/box2d/lunarlander/config/lunarlander_sil_ppo_config.py new file mode 100644 index 0000000000..73b8ab2d9c --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_sil_ppo_config.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +lunarlander_sil_ppo_config = dict( + exp_name='lunarlander_sil_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + env_id='LunarLander-v2', + n_evaluator_episode=8, + stop_value=200, + ), + policy=dict( + cuda=True, + sil_update_per_collect=1, + model=dict( + obs_shape=8, + action_shape=4, + ), + learn=dict( + update_per_collect=4, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + nstep=1, + 
nstep_return=False, + adv_norm=True, + ), + collect=dict( + n_sample=128, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +lunarlander_sil_ppo_config = EasyDict(lunarlander_sil_ppo_config) +main_config = lunarlander_sil_ppo_config +lunarlander_sil_ppo_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_ppo'), +) +lunarlander_sil_ppo_create_config = EasyDict(lunarlander_sil_ppo_create_config) +create_config = lunarlander_sil_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_sil -c lunarlander_sil_ppo_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_ppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_ppo_config.py new file mode 100644 index 0000000000..281d719361 --- /dev/null +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_ppo_config.py @@ -0,0 +1,57 @@ +from easydict import EasyDict + +cartpole_sil_ppo_config = dict( + exp_name='cartpole_sil_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + ), + policy=dict( + cuda=False, + action_space='discrete', + sil_update_per_collect=1, + model=dict( + obs_shape=4, + action_shape=2, + action_space='discrete', + encoder_hidden_size_list=[64, 64, 128], + critic_head_hidden_size=128, + actor_head_hidden_size=128, + ), + learn=dict( + epoch_per_collect=2, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + learner=dict(hook=dict(save_ckpt_after_iter=100)), + ), + collect=dict( + n_sample=256, + unroll_len=1, + discount_factor=0.9, + gae_lambda=0.95, + ), + eval=dict(evaluator=dict(eval_freq=100, ), ), + ), +) +cartpole_sil_ppo_config = EasyDict(cartpole_sil_ppo_config) +main_config = cartpole_sil_ppo_config +cartpole_sil_ppo_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='sil_ppo'), +) +cartpole_sil_ppo_create_config = EasyDict(cartpole_sil_ppo_create_config) +create_config = cartpole_sil_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_sil -c cartpole_sil_ppo_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) diff --git a/dizoo/minigrid/config/minigrid_sil_ppo_config.py b/dizoo/minigrid/config/minigrid_sil_ppo_config.py new file mode 100644 index 0000000000..c743391217 --- /dev/null +++ b/dizoo/minigrid/config/minigrid_sil_ppo_config.py @@ -0,0 +1,64 @@ +from easydict import EasyDict + +collector_env_num = 8 +minigrid_sil_ppo_config = dict( + exp_name="minigrid_sil_ppo_seed0", + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + # typical MiniGrid env id: + # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, + # please refer to https://github.com/Farama-Foundation/MiniGrid for details. 
+ env_id='MiniGrid-Empty-8x8-v0', + max_step=300, + stop_value=0.96, + ), + policy=dict( + cuda=True, + recompute_adv=True, + sil_update_per_collect=1, + action_space='discrete', + model=dict( + obs_shape=2835, + action_shape=7, + action_space='discrete', + encoder_hidden_size_list=[256, 128, 64, 64], + ), + learn=dict( + epoch_per_collect=10, + update_per_collect=1, + batch_size=320, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.001, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + collector_env_num=collector_env_num, + n_sample=int(3200), + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +minigrid_sil_ppo_config = EasyDict(minigrid_sil_ppo_config) +main_config = minigrid_sil_ppo_config +minigrid_sil_ppo_create_config = dict( + env=dict( + type='minigrid', + import_names=['dizoo.minigrid.envs.minigrid_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo_sil'), +) +minigrid_sil_ppo_create_config = EasyDict(minigrid_sil_ppo_create_config) +create_config = minigrid_sil_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_sil -c minigrid_sil_ppo_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) From d206c245a216cb392f272671a690bc8603565e95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 16 Jun 2023 15:09:32 +0800 Subject: [PATCH 20/27] polish ppo+sil --- ding/policy/sil.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 108454f7fb..2c4494e57b 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -1,16 +1,9 @@ -from typing import List, Dict, Any, Tuple, Union -from collections import namedtuple +from typing import List, Dict, Any import torch -import copy -import numpy as np - -from ding.rl_utils import sil_data, sil_error, a2c_data, a2c_error, ppo_data, ppo_error, ppo_policy_error,\ - ppo_policy_data, get_gae_with_default_last_value, get_train_sample, gae, gae_data, ppo_error_continuous, get_gae -from ding.torch_utils import Adam, to_device, to_dtype, unsqueeze -from ding.model import model_wrap -from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd -from ding.utils.data import default_collate, default_decollate -from .base_policy import Policy + +from ding.rl_utils import sil_data, sil_error, a2c_data, a2c_error, ppo_data, ppo_error, gae, gae_data +from ding.torch_utils import to_device, to_dtype +from ding.utils import POLICY_REGISTRY, split_data_generator from .ppo import PPOPolicy from .a2c import A2CPolicy from .common_utils import default_preprocess_learn @@ -285,12 +278,16 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: data_onpolicy = to_device(data_onpolicy, self._device) data_sil = to_device(data_sil, self._device) self._learn_model.train() + # Convert dtype for on-policy data. data_onpolicy['obs'] = to_dtype(data_onpolicy['obs'], torch.float32) - if 'next_obs' in data_onpolicy: + if 'next_obs' in data_onpolicy[0]: data_onpolicy['next_obs'] = to_dtype(data_onpolicy['next_obs'], torch.float32) - data_sil['obs'] = to_dtype(data_sil['obs'], torch.float32) - if 'next_obs' in data_sil: - data_sil['next_obs'] = to_dtype(data_sil['next_obs'], torch.float32) + # Convert dtype for sil-data. 
+ for i in range(len(data_sil)): + data_sil[i]['obs'] = to_dtype(data_sil[i]['obs'], torch.float32) + if 'next_obs' in data_sil[0]: + for i in range(len(data_sil)): + data_sil[i]['next_obs'] = to_dtype(data_sil[i]['next_obs'], torch.float32) # ==================== # PPO forward # ==================== From 6bd0872cf9dcf25e845e1ca3236363451574ba1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 16 Jun 2023 15:19:07 +0800 Subject: [PATCH 21/27] polish ppo+sil --- ding/policy/sil.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 2c4494e57b..3f0dfaf961 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -280,7 +280,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: self._learn_model.train() # Convert dtype for on-policy data. data_onpolicy['obs'] = to_dtype(data_onpolicy['obs'], torch.float32) - if 'next_obs' in data_onpolicy[0]: + if 'next_obs' in data_onpolicy: data_onpolicy['next_obs'] = to_dtype(data_onpolicy['next_obs'], torch.float32) # Convert dtype for sil-data. for i in range(len(data_sil)): @@ -297,35 +297,37 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: for epoch in range(self._cfg.learn.epoch_per_collect): if self._recompute_adv: # calculate new value using the new updated value network with torch.no_grad(): - value = self._learn_model.forward(data['obs'], mode='compute_critic')['value'] - next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value'] + value = self._learn_model.forward(data_onpolicy['obs'], mode='compute_critic')['value'] + next_value = self._learn_model.forward(data_onpolicy['next_obs'], mode='compute_critic')['value'] if self._value_norm: value *= self._running_mean_std.std next_value *= self._running_mean_std.std - traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory - compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], traj_flag) - data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda) + traj_flag = data_onpolicy.get('traj_flag', None) # traj_flag indicates termination of trajectory + compute_adv_data = gae_data( + value, next_value, data_onpolicy['reward'], data_onpolicy['done'], traj_flag + ) + data_onpolicy['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda) - unnormalized_returns = value + data['adv'] + unnormalized_returns = value + data_onpolicy['adv'] if self._value_norm: - data['value'] = value / self._running_mean_std.std - data['return'] = unnormalized_returns / self._running_mean_std.std + data_onpolicy['value'] = value / self._running_mean_std.std + data_onpolicy['return'] = unnormalized_returns / self._running_mean_std.std self._running_mean_std.update(unnormalized_returns.cpu().numpy()) else: - data['value'] = value - data['return'] = unnormalized_returns + data_onpolicy['value'] = value + data_onpolicy['return'] = unnormalized_returns else: # don't recompute adv if self._value_norm: - unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std - data['return'] = unnormalized_return / self._running_mean_std.std + unnormalized_return = data_onpolicy['adv'] + data_onpolicy['value'] * self._running_mean_std.std + data_onpolicy['return'] = unnormalized_return / self._running_mean_std.std self._running_mean_std.update(unnormalized_return.cpu().numpy()) else: - data['return'] = data['adv'] + data['value'] + 
data_onpolicy['return'] = data_onpolicy['adv'] + data_onpolicy['value'] - for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True): + for batch in split_data_generator(data_onpolicy, self._cfg.learn.batch_size, shuffle=True): output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') adv = batch['adv'] if self._adv_norm: From be90f73bd791c41960564e278098d0aaa7b27f6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 16 Jun 2023 15:26:30 +0800 Subject: [PATCH 22/27] polish sil+ppo --- ding/policy/sil.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 3f0dfaf961..2b8a5b5113 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -215,6 +215,7 @@ class SILPPOPolicy(PPOPolicy): epoch_per_collect=10, batch_size=64, learning_rate=3e-4, + grad_norm=0.5, # ============================================================== # The following configs is algorithm-specific # ============================================================== @@ -268,7 +269,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: # Extract on-policy data data_onpolicy = data['new_data'] for i in range(len(data_onpolicy)): - data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'adv', 'value', 'action', 'done']} + data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'adv', 'value', 'action', 'done', 'next_obs', 'reward', 'logit']} data_onpolicy = default_preprocess_learn( data_onpolicy, ignore_done=self._cfg.learn.ignore_done, use_nstep=False ) @@ -393,7 +394,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: grad_norm = torch.nn.utils.clip_grad_norm_( list(self._learn_model.parameters()), - max_norm=self._grad_norm, + max_norm=self.config.learn.grad_norm, ) self._optimizer.step() From eef28a3656fbf29b75e277c23db933e0a0c4cf07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 22 Jun 2023 11:13:16 +0800 Subject: [PATCH 23/27] update recompute adv --- ding/policy/sil.py | 18 +++++++++++++++--- .../minigrid/config/minigrid_sil_ppo_config.py | 2 +- dizoo/minigrid/envs/minigrid_env.py | 4 ++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 2b8a5b5113..98f8c832e5 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -27,6 +27,7 @@ class SILA2CPolicy(A2CPolicy): priority_IS_weight=False, # (int) Number of epochs to use SIL loss to update the policy. 
sil_update_per_collect=1, + sil_recompute_adv=True, learn=dict( update_per_collect=1, # fixed value, this line should not be modified by users batch_size=64, @@ -123,10 +124,21 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: for batch in data_sil: # forward + with torch.no_grad(): + recomputed_value = self._learn_model.forward(data_onpolicy['obs'], mode='compute_critic')['value'] + recomputed_next_value = self._learn_model.forward(data_onpolicy['next_obs'], mode='compute_critic')['value'] + + traj_flag = data_onpolicy.get('traj_flag', None) # traj_flag indicates termination of trajectory + compute_adv_data = gae_data( + recomputed_value, recomputed_next_value, data_onpolicy['reward'], data_onpolicy['done'], traj_flag + ) + recomputed_adv = gae(compute_adv_data, self._gamma, self._gae_lambda) + + recomputed_returns = recomputed_value + recomputed_adv output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') - adv = batch['adv'] - return_ = batch['value'] + adv + adv = batch['adv'] if not self._cfg.sil_recompute_adv else recomputed_adv + return_ = batch['value'] + adv if not self._cfg.sil_recompute_adv else recomputed_returns if self._adv_norm: # norm adv in total train_batch adv = (adv - adv.mean()) / (adv.std() + 1e-8) @@ -394,7 +406,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: grad_norm = torch.nn.utils.clip_grad_norm_( list(self._learn_model.parameters()), - max_norm=self.config.learn.grad_norm, + max_norm=self.config["learn"]["grad_norm"], ) self._optimizer.step() diff --git a/dizoo/minigrid/config/minigrid_sil_ppo_config.py b/dizoo/minigrid/config/minigrid_sil_ppo_config.py index c743391217..77925548c6 100644 --- a/dizoo/minigrid/config/minigrid_sil_ppo_config.py +++ b/dizoo/minigrid/config/minigrid_sil_ppo_config.py @@ -53,7 +53,7 @@ import_names=['dizoo.minigrid.envs.minigrid_env'], ), env_manager=dict(type='subprocess'), - policy=dict(type='ppo_sil'), + policy=dict(type='sil_ppo'), ) minigrid_sil_ppo_create_config = EasyDict(minigrid_sil_ppo_create_config) create_config = minigrid_sil_ppo_create_config diff --git a/dizoo/minigrid/envs/minigrid_env.py b/dizoo/minigrid/envs/minigrid_env.py index e0bdbfbc07..f495dfa59f 100644 --- a/dizoo/minigrid/envs/minigrid_env.py +++ b/dizoo/minigrid/envs/minigrid_env.py @@ -60,7 +60,7 @@ def reset(self) -> np.ndarray: self._env = ObsPlusPrevActRewWrapper(self._env) self._init_flag = True if self._flat_obs: - self._observation_space = gym.spaces.Box(0, 1, shape=(2835, ), dytpe=np.float32) + self._observation_space = gym.spaces.Box(0, 1, shape=(2835, )) else: self._observation_space = self._env.observation_space # to be compatiable with subprocess env manager @@ -70,7 +70,7 @@ def reset(self) -> np.ndarray: self._observation_space.dtype = np.dtype('float32') self._action_space = self._env.action_space self._reward_space = gym.spaces.Box( - low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 + low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ) ) self._eval_episode_return = 0 From 7f9a10a52f6ab6df6bbf7c4bd06391d8e3c94ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 23 Jun 2023 23:18:28 +0800 Subject: [PATCH 24/27] polish config --- ding/policy/sil.py | 16 +++--- .../serial/freeway/freeway_sil_a2c_config.py | 54 +++++++++++++++++++ .../config/serial/frostbite/frostbite.py | 54 +++++++++++++++++++ .../gravitar/gravitar_sil_a2c_config.py | 54 +++++++++++++++++++ 
.../config/serial/hero/hero_sil_a2c_config.py | 54 +++++++++++++++++++ .../montezuma/montezuma_sil_a2c_config.py | 54 +++++++++++++++++++ .../private_eye/private_eye_sil_a2c_config.py | 54 +++++++++++++++++++ 7 files changed, 332 insertions(+), 8 deletions(-) create mode 100644 dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py create mode 100644 dizoo/atari/config/serial/frostbite/frostbite.py create mode 100644 dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py create mode 100644 dizoo/atari/config/serial/hero/hero_sil_a2c_config.py create mode 100644 dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py create mode 100644 dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py diff --git a/ding/policy/sil.py b/ding/policy/sil.py index 98f8c832e5..cf0c601e9d 100644 --- a/ding/policy/sil.py +++ b/ding/policy/sil.py @@ -82,7 +82,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: # Extract on-policy data data_onpolicy = data['new_data'] for i in range(len(data_onpolicy)): - data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'adv', 'value', 'action', 'done']} + data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'next_obs', 'reward', 'adv', 'value', 'action', 'done']} data_onpolicy = default_preprocess_learn( data_onpolicy, ignore_done=self._cfg.learn.ignore_done, use_nstep=False ) @@ -125,12 +125,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: for batch in data_sil: # forward with torch.no_grad(): - recomputed_value = self._learn_model.forward(data_onpolicy['obs'], mode='compute_critic')['value'] - recomputed_next_value = self._learn_model.forward(data_onpolicy['next_obs'], mode='compute_critic')['value'] + recomputed_value = self._learn_model.forward(batch['obs'], mode='compute_critic')['value'] + recomputed_next_value = self._learn_model.forward(batch['next_obs'], mode='compute_critic')['value'] - traj_flag = data_onpolicy.get('traj_flag', None) # traj_flag indicates termination of trajectory + traj_flag = batch.get('traj_flag', None) # traj_flag indicates termination of trajectory compute_adv_data = gae_data( - recomputed_value, recomputed_next_value, data_onpolicy['reward'], data_onpolicy['done'], traj_flag + recomputed_value, recomputed_next_value, batch['reward'], batch['done'], traj_flag ) recomputed_adv = gae(compute_adv_data, self._gamma, self._gae_lambda) @@ -172,10 +172,10 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'sil_total_loss': sil_total_loss.item(), 'a2c_total_loss': a2c_total_loss.item(), 'sil_policy_loss': sil_loss.policy_loss.item(), - 'a2c_policy_loss': a2c_loss.policy_loss.item(), + 'policy_loss': a2c_loss.policy_loss.item(), 'sil_value_loss': sil_loss.value_loss.item(), - 'a2c_value_loss': a2c_loss.value_loss.item(), - 'a2c_entropy_loss': a2c_loss.entropy_loss.item(), + 'value_loss': a2c_loss.value_loss.item(), + 'entropy_loss': a2c_loss.entropy_loss.item(), 'policy_clipfrac': sil_info.policy_clipfrac, 'value_clipfrac': sil_info.value_clipfrac, 'adv_abs_max': adv.abs().max().item(), diff --git a/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py b/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py new file mode 100644 index 0000000000..20905e20d7 --- /dev/null +++ b/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +freeway_sil_config = dict( + exp_name='freeway_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + 
evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='FreewayNoFrameskip-v4', + # 'ALE/freewayRevenge-v5' is available. But special setting is needed after gym make. + stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=3, + encoder_hidden_size_list=[128, 128, 512], + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +freeway_sil_config = EasyDict(freeway_sil_config) +main_config = freeway_sil_config + +freeway_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +freeway_sil_create_config = EasyDict(freeway_sil_create_config) +create_config = freeway_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c freeway_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/frostbite/frostbite.py b/dizoo/atari/config/serial/frostbite/frostbite.py new file mode 100644 index 0000000000..c5e4101973 --- /dev/null +++ b/dizoo/atari/config/serial/frostbite/frostbite.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +frostbite_sil_config = dict( + exp_name='frostbite_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='FrostbiteNoFrameskip-v4', + # 'ALE/frostbiteRevenge-v5' is available. But special setting is needed after gym make. + stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +frostbite_sil_config = EasyDict(frostbite_sil_config) +main_config = frostbite_sil_config + +frostbite_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +frostbite_sil_create_config = EasyDict(frostbite_sil_create_config) +create_config = frostbite_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c frostbite_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py b/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py new file mode 100644 index 0000000000..ee57d134b9 --- /dev/null +++ b/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +gravitar_sil_config = dict( + exp_name='gravitar_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='GravitarNoFrameskip-v4', + # 'ALE/gravitarRevenge-v5' is available. But special setting is needed after gym make. 
+ stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +gravitar_sil_config = EasyDict(gravitar_sil_config) +main_config = gravitar_sil_config + +gravitar_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +gravitar_sil_create_config = EasyDict(gravitar_sil_create_config) +create_config = gravitar_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c gravitar_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py b/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py new file mode 100644 index 0000000000..4961db771a --- /dev/null +++ b/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +hero_sil_config = dict( + exp_name='hero_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='HeroNoFrameskip-v4', + # 'ALE/heroRevenge-v5' is available. But special setting is needed after gym make. + stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +hero_sil_config = EasyDict(hero_sil_config) +main_config = hero_sil_config + +hero_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +hero_sil_create_config = EasyDict(hero_sil_create_config) +create_config = hero_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c hero_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py b/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py new file mode 100644 index 0000000000..28e8c4856b --- /dev/null +++ b/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +montezuma_sil_config = dict( + exp_name='montezuma_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='MontezumaRevengeNoFrameskip-v4', + # 'ALE/MontezumaRevenge-v5' is available. But special setting is needed after gym make. 
+ stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +montezuma_sil_config = EasyDict(montezuma_sil_config) +main_config = montezuma_sil_config + +montezuma_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +montezuma_sil_create_config = EasyDict(montezuma_sil_create_config) +create_config = montezuma_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c montezuma_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py b/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py new file mode 100644 index 0000000000..850fbc922b --- /dev/null +++ b/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +private_eye_sil_config = dict( + exp_name='private_eye_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='PrivateEyeNoFrameskip-v4', + # 'ALE/private_eyeRevenge-v5' is available. But special setting is needed after gym make. + stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +private_eye_sil_config = EasyDict(private_eye_sil_config) +main_config = private_eye_sil_config + +private_eye_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +private_eye_sil_create_config = EasyDict(private_eye_sil_create_config) +create_config = private_eye_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c private_eye_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) From d2e0ef594ebba52a77f1606498ac44c54687c64d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 23 Jun 2023 23:21:42 +0800 Subject: [PATCH 25/27] polish config --- dizoo/minigrid/config/minigrid_sil_a2c_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dizoo/minigrid/config/minigrid_sil_a2c_config.py b/dizoo/minigrid/config/minigrid_sil_a2c_config.py index ed4aa5e6eb..7c29abdf31 100644 --- a/dizoo/minigrid/config/minigrid_sil_a2c_config.py +++ b/dizoo/minigrid/config/minigrid_sil_a2c_config.py @@ -10,7 +10,7 @@ # typical MiniGrid env id: # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, # please refer to 
https://github.com/Farama-Foundation/MiniGrid for details. - env_id='MiniGrid-Empty-8x8-v0', + env_id='MiniGrid-DoorKey-8x8-v0', n_evaluator_episode=5, max_step=300, stop_value=0.96, From a51417088bec79d964d147ee5a84d7e5ec59eac7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 23 Jun 2023 23:40:58 +0800 Subject: [PATCH 26/27] polish config --- dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py | 2 ++ dizoo/atari/config/serial/frostbite/frostbite.py | 2 ++ dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py | 2 ++ dizoo/atari/config/serial/hero/hero_sil_a2c_config.py | 2 ++ dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py | 2 ++ .../config/serial/private_eye/private_eye_sil_a2c_config.py | 2 ++ 6 files changed, 12 insertions(+) diff --git a/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py b/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py index 20905e20d7..157212ee9d 100644 --- a/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py +++ b/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py @@ -20,6 +20,8 @@ obs_shape=[4, 84, 84], action_shape=3, encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, ), learn=dict( batch_size=40, diff --git a/dizoo/atari/config/serial/frostbite/frostbite.py b/dizoo/atari/config/serial/frostbite/frostbite.py index c5e4101973..34f3465a78 100644 --- a/dizoo/atari/config/serial/frostbite/frostbite.py +++ b/dizoo/atari/config/serial/frostbite/frostbite.py @@ -20,6 +20,8 @@ obs_shape=[4, 84, 84], action_shape=18, encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, ), learn=dict( batch_size=40, diff --git a/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py b/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py index ee57d134b9..1be32b3813 100644 --- a/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py +++ b/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py @@ -20,6 +20,8 @@ obs_shape=[4, 84, 84], action_shape=18, encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, ), learn=dict( batch_size=40, diff --git a/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py b/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py index 4961db771a..2cb8bd2ae3 100644 --- a/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py +++ b/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py @@ -20,6 +20,8 @@ obs_shape=[4, 84, 84], action_shape=18, encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, ), learn=dict( batch_size=40, diff --git a/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py b/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py index 28e8c4856b..c14a40ea6e 100644 --- a/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py +++ b/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py @@ -20,6 +20,8 @@ obs_shape=[4, 84, 84], action_shape=18, encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, ), learn=dict( batch_size=40, diff --git a/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py b/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py index 850fbc922b..3870690047 100644 --- a/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py +++ 
b/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py @@ -20,6 +20,8 @@ obs_shape=[4, 84, 84], action_shape=18, encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, ), learn=dict( batch_size=40, From 86f69c2be81492bc86d51db7e51439be36632056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 23 Jun 2023 23:51:06 +0800 Subject: [PATCH 27/27] rename config files --- .../frostbite/{frostbite.py => frostbite_sil_a2c_config.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dizoo/atari/config/serial/frostbite/{frostbite.py => frostbite_sil_a2c_config.py} (100%) diff --git a/dizoo/atari/config/serial/frostbite/frostbite.py b/dizoo/atari/config/serial/frostbite/frostbite_sil_a2c_config.py similarity index 100% rename from dizoo/atari/config/serial/frostbite/frostbite.py rename to dizoo/atari/config/serial/frostbite/frostbite_sil_a2c_config.py
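
The patches above lean on ding.rl_utils.sil_data and sil_error without showing them. As a rough reference for reviewers, the self-imitation objective from the cited paper (https://arxiv.org/abs/1806.05635) only learns from stored transitions whose return exceeds the current value estimate; a minimal PyTorch sketch follows (this is not the ding implementation — the tensor layout, the weight handling and the clipfrac bookkeeping are assumptions):

import torch


def sil_losses(logit, action, value, return_, weight=None):
    # Self-imitation learning (Oh et al., 2018): imitate only those past
    # transitions whose observed return exceeds the current value estimate.
    if weight is None:
        weight = torch.ones_like(value)
    log_prob = torch.log_softmax(logit, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)
    clipped_adv = (return_ - value).clamp(min=0)  # (R - V)_+
    policy_loss = -(log_prob * clipped_adv.detach() * weight).mean()
    value_loss = 0.5 * (clipped_adv.pow(2) * weight).mean()
    # Share of samples that actually contribute, analogous to the
    # policy_clipfrac / value_clipfrac variables logged by the policies.
    clipfrac = (clipped_adv > 0).float().mean().item()
    return policy_loss, value_loss, clipfrac

The objective optimized in the patch then corresponds to policy_loss + value_weight * value_loss, i.e. the sil_total_loss that is backpropagated before gradient clipping.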
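
Both SILA2CPolicy (via sil_recompute_adv) and SILPPOPolicy (via recompute_adv) recompute advantages with gae_data/gae after the value network has been updated. For readers unfamiliar with that helper, here is a minimal single-trajectory sketch of generalized advantage estimation under the same gamma/gae_lambda convention; the real ding helper additionally handles batched time-major data and the traj_flag argument, so treat this signature as illustrative only:

import torch


def gae_1d(value, next_value, reward, done, gamma=0.99, lam=0.95):
    # value, next_value, reward, done: 1-D tensors of length T for a single trajectory.
    adv = torch.zeros_like(reward)
    running = torch.zeros(())
    for t in reversed(range(reward.shape[0])):
        nonterminal = 1.0 - done[t].float()
        delta = reward[t] + gamma * next_value[t] * nonterminal - value[t]
        running = delta + gamma * lam * nonterminal * running
        adv[t] = running
    return adv  # the return/value target is then value + adv, as in the patch

With recomputation enabled, the SIL branch rebuilds return_ = value + adv from these fresh advantages instead of reusing the advantages stored at collection time.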
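
Finally, note that the reworked _forward_learn no longer takes a flat batch: it expects a dict with a 'new_data' field (the freshly collected on-policy transitions, fed to the A2C/PPO update) and a 'replay_data' field (a list of batches drawn from the self-imitation replay buffer, one per SIL update). A self-contained illustration of that layout is sketched below; the keys are taken from what the patched code reads, while the shapes, the helper name and the sample counts are assumptions:

import torch


def fake_transition(obs_dim=4, action_dim=2):
    # Hypothetical transition carrying the keys the patched policies read.
    return {
        'obs': torch.randn(obs_dim),
        'next_obs': torch.randn(obs_dim),
        'action': torch.randint(action_dim, size=()),
        'logit': torch.randn(action_dim),
        'value': torch.randn(()),
        'adv': torch.randn(()),
        'reward': torch.randn(()),
        'done': torch.tensor(False),
    }


n_sample, sil_update_per_collect, batch_size = 128, 1, 64
learn_input = {
    # on-policy transitions consumed by the PPO/A2C part of _forward_learn
    'new_data': [fake_transition() for _ in range(n_sample)],
    # one pre-sampled replay batch per SIL update
    'replay_data': [[fake_transition() for _ in range(batch_size)] for _ in range(sil_update_per_collect)],
}

An entry function such as serial_pipeline_sil is expected to assemble this dict every iteration from the collector output and the replay buffer before calling the learner.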