polish(nyz): complete r2d2 comments
PaParaZz1 committed Oct 12, 2023
1 parent 5462ee6 commit 1a3e259
Showing 2 changed files with 69 additions and 33 deletions.
4 changes: 2 additions & 2 deletions ding/policy/dqn.py
@@ -342,8 +342,8 @@ def _init_collect(self) -> None:
This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
.. note::
If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
.. tip::
Some variables need to initialize independently in different modes, such as gamma and nstep in DQN. This \
98 changes: 67 additions & 31 deletions ding/policy/r2d2.py
@@ -24,7 +24,7 @@ class R2D2Policy(Policy):
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str dqn | RL policy register name, refer to | This arg is optional,
1 ``type`` str r2d2 | RL policy register name, refer to | This arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
| erent from modes
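
The two rows shown above map onto keys of the policy config. As a quick orientation, a hypothetical fragment of such a config dict might look like the sketch below; only the fields visible in this part of the table are included, and nothing beyond them is implied by this diff.

# Hypothetical config fragment mirroring only the two rows shown above;
# the full R2D2 default config defines many more fields.
r2d2_cfg_fragment = dict(
    type='r2d2',  # RL policy register name, looked up in POLICY_REGISTRY
    cuda=False,   # whether to run the network on GPU
)
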
@@ -155,24 +155,28 @@ def default_model(self) -> Tuple[str, List[str]]:
.. note::
The user can define and use a customized network model but must obey the same interface definition indicated \
by import_names path. For example about DQN, its registered name is ``drqn`` and the import_names is \
by import_names path. For example about R2D2, its registered name is ``drqn`` and the import_names is \
``ding.model.template.q_learning``.
"""
return 'drqn', ['ding.model.template.q_learning']
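
For orientation, here is a rough sketch of how a caller could consume the (registered_name, import_names) pair returned above; the actual model lookup in DI-engine goes through its registry machinery, so treat this purely as an illustration.

import importlib

# Illustrative only: import the modules named by import_names so that the
# model class registered under 'drqn' becomes available to the framework.
# The registry lookup itself is elided here.
registered_name, import_names = 'drqn', ['ding.model.template.q_learning']
for module_path in import_names:
    importlib.import_module(module_path)
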

def _init_learn(self) -> None:
r"""
"""
Overview:
Init the learner model of R2D2Policy
Arguments:
- learning_rate (:obj:`float`): The learning rate of the optimizer
- gamma (:obj:`float`): The discount factor
- nstep (:obj:`int`): The number of steps of the n-step return
- value_rescale (:obj:`bool`): Whether to use the value-rescaled loss in the algorithm
- burnin_step (:obj:`int`): The number of burn-in steps
Initialize the learn mode of policy, including some attributes and modules. For R2D2, it mainly contains the \
optimizer, algorithm-specific arguments such as burnin_step, value_rescale and gamma, and the main and target models.
This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
The ``_init_learn`` method takes its arguments from ``self._cfg.learn`` in the config file.
If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
@@ -204,16 +208,15 @@ def _init_learn(self) -> None:
self._learn_model.reset()
self._target_model.reset()
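
To make the docstring's list of algorithm-specific arguments concrete, a hypothetical set of learn-mode settings might look like the sketch below; the key names and values are assumptions for illustration and should be checked against the actual default config in ding/policy/r2d2.py.

# Hypothetical learn-mode settings named after the arguments in the docstring;
# exact key names and defaults live in the real R2D2 config, not here.
learn_settings = dict(
    learning_rate=1e-4,     # optimizer step size (assumed value)
    discount_factor=0.997,  # gamma for the n-step return (assumed key name and value)
    nstep=5,                # horizon of the n-step TD target (assumed value)
    burnin_step=20,         # steps used only to warm up the RNN hidden state
    value_rescale=True,     # use the invertible value-rescaling transform on targets
)
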

def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
"""
Overview:
Preprocess the data to fit the required data format for learning
Arguments:
- data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
- data (:obj:`List[Dict[str, Any]]`): The data collected from collect function
Returns:
- data (:obj:`Dict[str, Any]`): the processed data, including at least \
- data (:obj:`Dict[str, torch.Tensor]`): The processed data, including at least \
['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
- data_info (:obj:`dict`): the data info, such as replay_buffer_idx, replay_unique_id
"""
# data preprocess
data = timestep_collate(data)
@@ -271,18 +274,35 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:

return data
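
The keys listed above (main_obs, target_obs, burnin_obs) correspond to a time-major split of the collated observation sequence. The standalone sketch below illustrates the idea; the slice boundaries and key names are assumptions for illustration, not a copy of the implementation.

import torch

def split_burnin(obs: torch.Tensor, burnin_step: int, nstep: int) -> dict:
    """Illustrative burn-in / main / target split of a time-major obs tensor."""
    # obs: (T, B, *obs_shape), as produced by timestep collation
    return {
        'burnin_obs': obs[:burnin_step],          # warms up the hidden state only
        'main_obs': obs[burnin_step:-nstep],      # states for the online Q estimate
        'target_obs': obs[burnin_step + nstep:],  # states shifted by nstep for the target Q
    }

obs = torch.randn(42, 8, 4)  # unroll_len=42, batch=8, obs_dim=4 (made-up shapes)
parts = split_burnin(obs, burnin_step=20, nstep=5)
print({k: tuple(v.shape) for k, v in parts.items()})
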

def _forward_learn(self, data: dict) -> Dict[str, Any]:
def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
"""
Overview:
Forward and backward function of learn mode.
Acquire the data, calculate the loss and optimize learner model.
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy inputs some training batch data (trajectories for R2D2) from the replay buffer and then \
returns the output result, including various training information such as loss, q value and priority.
Arguments:
- data (:obj:`dict`): Dict type data, including at least \
['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
- data (:obj:`List[List[Dict[str, Any]]]`): The input data used for policy forward, including a batch of \
training samples. For each dict element, the key of the dict is the name of data items and the \
value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
combinations. In the ``_forward_learn`` method, data often needs to first be stacked along the time and \
batch dimensions by the utility function ``self._data_preprocess_learn``. \
For R2D2, each element in the list is a trajectory of length ``unroll_len``, and each element in the \
trajectory list is a dict containing at least the following keys: ``obs``, ``action``, ``prev_state``, \
``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
and ``value_gamma``.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
- cur_lr (:obj:`float`): Current learning rate
- total_loss (:obj:`float`): The calculated loss
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
recorded in text log and tensorboard; values must be python scalars or a list of scalars. For the \
detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
.. note::
The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
For data types that are not supported, the main reason is that the corresponding model does not support them. \
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
.. note::
For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``.
"""
# forward
data = self._data_preprocess_learn(data) # output datatype: Dict
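
To visualize the input format described in the docstring, the following sketch builds one such batch by hand: a list over the batch dimension, each element a trajectory of unroll_len per-step dicts. All field shapes and values are placeholders, not the exact layout produced by the collector.

import torch

unroll_len, batch_size, obs_dim = 42, 2, 4  # made-up sizes
batch = [
    [
        {
            'obs': torch.randn(obs_dim),
            'next_obs': torch.randn(obs_dim),
            'action': torch.tensor(0),
            'prev_state': None,          # RNN hidden state carried by the collector (placeholder)
            'reward': torch.zeros(5),    # assumed n-step reward vector with nstep=5
            'done': torch.tensor(False),
        }
        for _ in range(unroll_len)
    ]
    for _ in range(batch_size)
]
# A training call would then look roughly like: info = policy._forward_learn(batch)
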
@@ -416,10 +436,21 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
self._optimizer.load_state_dict(state_dict['optimizer'])

def _init_collect(self) -> None:
r"""
"""
Overview:
Collect mode init method. Called by ``self.__init__``.
Init traj and unroll length, collect model.
Initialize the collect mode of policy, including related attributes and modules. For R2D2, it contains the \
collect_model to balance exploration and exploitation with an epsilon-greedy sampling mechanism and to \
maintain the hidden state of the RNN. Besides, there are some initialization operations for other \
algorithm-specific arguments such as burnin_step, unroll_len and nstep.
This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
.. note::
If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
.. tip::
Some variables need to initialize independently in different modes, such as gamma and nstep in R2D2. This \
design is for the convenience of parallel execution of different policy modes.
"""
self._nstep = self._cfg.nstep
self._burnin_step = self._cfg.burnin_step
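
The epsilon-greedy sampling mentioned in the docstring is handled by the collect model wrapper; the snippet below is only a minimal, self-contained sketch of that selection rule over a batch of Q-values, not the wrapper itself.

import torch

def eps_greedy_action(q_values: torch.Tensor, eps: float) -> torch.Tensor:
    """Pick argmax actions, replacing each with a random action with probability eps."""
    greedy = q_values.argmax(dim=-1)  # (B,)
    random_actions = torch.randint(q_values.shape[-1], greedy.shape)
    explore = torch.rand(greedy.shape) < eps
    return torch.where(explore, random_actions, greedy)

actions = eps_greedy_action(torch.randn(8, 6), eps=0.05)  # 8 envs, 6 actions (made-up)
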
@@ -497,11 +528,11 @@ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.
"""
Overview:
Process and pack one timestep transition data into a dict, which can be directly used for training and \
saved in replay buffer. For DQN, it contains obs, action, prev_state, reward, done.
saved in replay buffer. For R2D2, it contains obs, action, prev_state, reward, done.
Arguments:
- obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
- policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
as input. For DQN, it contains the action and the prev_state of RNN.
as input. For R2D2, it contains the action and the prev_state of RNN.
- timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
reward, done, info, etc.
@@ -538,10 +569,15 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str,
return get_train_sample(transitions, self._unroll_len)
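
As a rough, hypothetical illustration of the two steps above: packing a single transition with the keys the docstring lists, and then cutting a trajectory into fixed-length unroll segments (the real slicing is delegated to ``get_train_sample``).

def make_transition(obs, policy_output, timestep):
    # Illustrative packing of one timestep with the keys named in the docstring.
    return {
        'obs': obs,
        'action': policy_output['action'],
        'prev_state': policy_output.get('prev_state'),  # RNN hidden state before this step
        'reward': timestep.reward,
        'done': timestep.done,
    }

def cut_trajectory(transitions, unroll_len):
    # Illustrative fixed-length slicing; drops any incomplete tail segment.
    return [
        transitions[i:i + unroll_len]
        for i in range(0, len(transitions) - unroll_len + 1, unroll_len)
    ]

samples = cut_trajectory(list(range(100)), unroll_len=42)  # -> 2 segments of length 42
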

def _init_eval(self) -> None:
r"""
"""
Overview:
Evaluate mode init method. Called by ``self.__init__``.
Init eval model with argmax strategy.
Initialize the eval mode of policy, including related attributes and modules. For R2D2, it contains the \
eval model to greedily select actions with an argmax q_value mechanism and to maintain the hidden state.
This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
.. note::
If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
"""
self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
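
The two wrappers above add per-environment hidden-state bookkeeping and greedy (argmax) action selection. A hand-rolled sketch of the same behaviour, with an assumed model(obs, hidden) signature, would look roughly like this:

import torch

# Illustrative only: the real eval model handles this through the
# 'hidden_state' and 'argmax_sample' wrappers rather than an explicit dict.
hidden_states = {env_id: None for env_id in range(4)}  # eval.env_num assumed to be 4

def eval_step(model, obs: torch.Tensor, env_id: int) -> torch.Tensor:
    q_values, hidden_states[env_id] = model(obs, hidden_states[env_id])  # assumed signature
    return q_values.argmax(dim=-1)
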
