polish(nyz) polish r2d2 comments
PaParaZz1 committed Sep 24, 2023
1 parent 3d51760 commit ec401c1
Showing 5 changed files with 386 additions and 109 deletions.
32 changes: 30 additions & 2 deletions ding/policy/base_policy.py
@@ -464,7 +464,22 @@ def _get_batch_size(self) -> Union[int, Dict[str, int]]:
# *************************************** collect function ************************************

@abstractmethod
-def _forward_collect(self, data: dict, **kwargs) -> dict:
+def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
+    """
+    Overview:
+        Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+        that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+        data, such as the action to interact with the envs, or the action logits to calculate the loss in learn \
+        mode. This method is left to be implemented by the subclass, and more arguments can be added in the \
+        ``kwargs`` part if necessary.
+    Arguments:
+        - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+          key of the dict is the environment id and the value is the corresponding data of that env.
+    Returns:
+        - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+          other necessary data for learn mode, as defined in the ``self._process_transition`` method. The keys of \
+          the dict are the same as those of the input data, i.e. environment ids.
+    """
raise NotImplementedError
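For illustration, a concrete subclass might satisfy this contract as in the minimal sketch below (the ``MyPolicy`` class and the ``self._collect_model`` attribute are assumptions for the example, not part of this commit):

```python
import torch
from typing import Any, Dict

class MyPolicy:  # hypothetical subclass standing in for Policy
    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        # Batch the per-env observations while remembering the env-id order.
        env_ids = list(data.keys())
        obs = torch.stack([data[i] for i in env_ids])
        with torch.no_grad():
            logit = self._collect_model(obs)  # assumed model attribute
            action = logit.argmax(dim=-1)
        # Return per-env outputs keyed by the same env ids as the input.
        return {i: {'action': action[n], 'logit': logit[n]} for n, i in enumerate(env_ids)}
```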

@abstractmethod
@@ -574,7 +589,20 @@ def _get_n_episode(self) -> Union[int, None]:
# *************************************** eval function ************************************

@abstractmethod
-def _forward_eval(self, data: dict) -> Dict[str, Any]:
+def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+    """
+    Overview:
+        Policy forward function of eval mode (evaluating policy performance, such as interacting with envs or \
+        computing metrics on a validation dataset). Forward means that the policy gets some necessary data (mainly \
+        observation) from the envs and then returns the output data, such as the action to interact with the envs. \
+        This method is left to be implemented by the subclass.
+    Arguments:
+        - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+          key of the dict is the environment id and the value is the corresponding data of that env.
+    Returns:
+        - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+          keys of the dict are the same as those of the input data, i.e. environment ids.
+    """
raise NotImplementedError
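From the caller's side, the env-id keyed contract is used roughly as below (an illustrative sketch; ``policy`` stands for any instantiated Policy subclass, and the ``eval_mode`` facade is assumed to dispatch to ``_forward_eval``):

```python
import torch

# Observations gathered from two running envs; env ids need not be contiguous.
obs = {0: torch.randn(4), 3: torch.randn(4)}  # placeholder observations
output = policy.eval_mode.forward(obs)        # dispatches to _forward_eval
# Actions come back under the same env ids as the input.
actions = {env_id: out['action'] for env_id, out in output.items()}
```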

# don't need to implement _reset_eval method by force
54 changes: 35 additions & 19 deletions ding/policy/dqn.py
@@ -329,18 +329,27 @@ def _init_collect(self) -> None:
def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
"""
Overview:
-    Forward computation graph of collect mode(collect training data), with eps_greedy for exploration.
+    Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+    that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+    data, such as the action to interact with the envs. Besides, this policy also needs the ``eps`` argument for \
+    exploration, i.e., the classic epsilon-greedy exploration strategy.
Arguments:
-    - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
-        values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
-    - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+    - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+      key of the dict is the environment id and the value is the corresponding data of that env.
+    - eps (:obj:`float`): The epsilon value for exploration.
Returns:
-    - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \
-        env and the constructing of transition.
-ArgumentsKeys:
-    - necessary: ``obs``
-ReturnsKeys
-    - necessary: ``logit``, ``action``
+    - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+      other necessary data for learn mode, as defined in the ``self._process_transition`` method. The keys of \
+      the dict are the same as those of the input data, i.e. environment ids.
+
+.. note::
+    The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+    For data types that are not supported, the main reason is that the corresponding model does not support them. \
+    You can implement your own model rather than use the default one. For more information, please raise an issue \
+    in the GitHub repo and we will continue to follow up.
+
+.. note::
+    For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
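For reference, the epsilon-greedy sampling named in the docstring can be sketched as a standalone helper (an illustration of the strategy itself, not the actual wrapper DI-engine applies to the collect model):

```python
import torch

def eps_greedy_sample(logit: torch.Tensor, eps: float) -> torch.Tensor:
    # Greedy action from Q-value logits of shape (B, action_dim).
    greedy_action = logit.argmax(dim=-1)
    # A uniformly random alternative action for each sample in the batch.
    random_action = torch.randint_like(greedy_action, logit.shape[-1])
    # With probability eps, take the random action instead of the greedy one.
    explore = torch.rand_like(logit[..., 0]) < eps
    return torch.where(explore, random_action, greedy_action)
```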
@@ -409,17 +418,24 @@ def _init_eval(self) -> None:
def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
"""
Overview:
-    Forward computation graph of eval mode(evaluate policy performance), at most cases, it is similar to \
-        ``self._forward_collect``.
+    Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
+    means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+    action to interact with the envs.
Arguments:
-    - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
-        values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+    - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+      key of the dict is the environment id and the value is the corresponding data of that env.
Returns:
-    - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
-ArgumentsKeys:
-    - necessary: ``obs``
-ReturnsKeys
-    - necessary: ``action``
+    - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+      keys of the dict are the same as those of the input data, i.e. environment ids.
+
+.. note::
+    The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+    For data types that are not supported, the main reason is that the corresponding model does not support them. \
+    You can implement your own model rather than use the default one. For more information, please raise an issue \
+    in the GitHub repo and we will continue to follow up.
+
+.. note::
+    For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
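In eval mode the exploration step is dropped; a minimal illustrative continuation of this forward pass (the ``model`` call and its output format are assumptions for the sketch, not the file's real code) is:

```python
import torch

with torch.no_grad():
    output = model(data)  # assumed to return {'logit': Tensor of shape (B, action_dim)}
    action = output['logit'].argmax(dim=-1)  # greedy action, no epsilon-greedy
# Split the batched result back into a per-env dict keyed by env id.
result = {env_id: {'action': action[n]} for n, env_id in enumerate(data_id)}
```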
14 changes: 7 additions & 7 deletions ding/policy/pdqn.py
@@ -14,7 +14,7 @@

@POLICY_REGISTRY.register('pdqn')
class PDQNPolicy(Policy):
""":
"""
Overview:
Policy class of PDQN algorithm, which extends the DQN algorithm on discrete-continuous hybrid action spaces.
Paper link: https://arxiv.org/abs/1810.06394.
@@ -59,12 +59,12 @@ class PDQNPolicy(Policy):
| ``_sigma`` | during collection
17 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
| 'linear'].
-18 | ``other.eps. float 0.95 | start value of exploration rate | [0,1]
-   | start``
-19 | ``other.eps. float 0.05 | end value of exploration rate | [0,1]
-   | end``
-20 | ``other.eps. int 10000 | decay length of exploration | greater than 0. set
-   | decay`` | decay=10000 means
+18 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
+   | ``start``
+19 | ``other.eps.`` float 0.05 | end value of exploration rate | [0,1]
+   | ``end``
+20 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set
+   | ``decay`` | decay=10000 means
| the exploration rate
| decay from start
| value to end value
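The exploration-rate rows above correspond to a config fragment like the following (structure and defaults taken from the table; a sketch, not the complete PDQN config):

```python
pdqn_eps_cfg = dict(
    other=dict(
        eps=dict(
            type='exp',    # decay schedule, supports 'exp' and 'linear'
            start=0.95,    # initial exploration rate, in [0, 1]
            end=0.05,      # final exploration rate, in [0, 1]
            decay=10000,   # env steps over which eps decays from start to end
        ),
    ),
)
```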