From d85fab070df7e0629f6fd95bf76eae0ce69eec5e Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Mon, 9 Oct 2023 23:09:31 +0800 Subject: [PATCH] polish(nyz): polish cql/dt comments --- ding/policy/cql.py | 328 ++++++++++++++------------------------------- ding/policy/dt.py | 88 +++++++----- 2 files changed, 157 insertions(+), 259 deletions(-) diff --git a/ding/policy/cql.py b/ding/policy/cql.py index 78c3eda642..c18d74fa18 100644 --- a/ding/policy/cql.py +++ b/ding/policy/cql.py @@ -12,147 +12,118 @@ from ding.utils import POLICY_REGISTRY from ding.utils.data import default_collate, default_decollate from .sac import SACPolicy -from .dqn import DQNPolicy +from .qrdqn import QRDQNPolicy from .common_utils import default_preprocess_learn @POLICY_REGISTRY.register('cql') class CQLPolicy(SACPolicy): """ - Overview: - Policy class of CQL algorithm. - - Config: - == ==================== ======== ============= ================================= ======================= - ID Symbol Type Default Value Description Other(Shape) - == ==================== ======== ============= ================================= ======================= - 1 ``type`` str td3 | RL policy register name, refer | this arg is optional, - | to registry ``POLICY_REGISTRY`` | a placeholder - 2 ``cuda`` bool True | Whether to use cuda for network | - 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for - | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ - | | buffer when training starts. | TD3. - 4 | ``model.policy_`` int 256 | Linear layer size for policy | - | ``embedding_size`` | network. | - 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q | - | ``embedding_size`` | network. | - 6 | ``model.value_`` int 256 | Linear layer size for value | Defalut to None when - | ``embedding_size`` | network. | model.value_network - | | | is False. - 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3, when - | ``_rate_q`` | network. | model.value_network - | | | is True. - 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3, when - | ``_rate_policy`` | network. | model.value_network - | | | is True. - 9 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to None when - | ``_rate_value`` | network. | model.value_network - | | | is False. - 10 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- - | | coefficient. | zation for auto - | | | `alpha`, when - | | | auto_alpha is True - 11 | ``learn.repara_`` bool True | Determine whether to use | - | ``meterization`` | reparameterization trick. | - 12 | ``learn.`` bool False | Determine whether to use | Temperature parameter - | ``auto_alpha`` | auto temperature parameter | determines the - | | `alpha`. | relative importance - | | | of the entropy term - | | | against the reward. - 13 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only - | ``ignore_done`` | done flag. | in halfcheetah env. - 14 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation - | ``target_theta`` | target network. | factor in polyak aver - | | | aging for target - | | | networks. - == ==================== ======== ============= ================================= ======================= - """ + Overview: + Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779. 
+
+    Config:
+        == ====================  ========    =============  =================================  =======================
        ID Symbol                Type        Default Value  Description                        Other(Shape)
        == ====================  ========    =============  =================================  =======================
        1  ``type``              str         cql            | RL policy register name, refer    | this arg is optional,
                                                            | to registry ``POLICY_REGISTRY``   | a placeholder
        2  ``cuda``              bool        True           | Whether to use cuda for network   |
        3  | ``random_``         int         10000          | Number of randomly collected      | Default to 10000 for
           | ``collect_size``                               | training samples in replay        | SAC, 25000 for DDPG/
           |                                                | buffer when training starts.      | TD3.
        4  | ``model.policy_``   int         256            | Linear layer size for policy      |
           | ``embedding_size``                             | network.                          |
        5  | ``model.soft_q_``   int         256            | Linear layer size for soft q      |
           | ``embedding_size``                             | network.                          |
        6  | ``model.value_``    int         256            | Linear layer size for value       | Default to None when
           | ``embedding_size``                             | network.                          | model.value_network
           |                                                |                                   | is False.
        7  | ``learn.learning``  float       3e-4           | Learning rate for soft q          | Default to 1e-3, when
           | ``_rate_q``                                    | network.                          | model.value_network
           |                                                |                                   | is True.
        8  | ``learn.learning``  float       3e-4           | Learning rate for policy          | Default to 1e-3, when
           | ``_rate_policy``                               | network.                          | model.value_network
           |                                                |                                   | is True.
        9  | ``learn.learning``  float       3e-4           | Learning rate for value           | Default to None when
           | ``_rate_value``                                | network.                          | model.value_network
           |                                                |                                   | is False.
        10 | ``learn.alpha``     float       0.2            | Entropy regularization            | alpha is initiali-
           |                                                | coefficient.                      | zation for auto
           |                                                |                                   | `alpha`, when
           |                                                |                                   | auto_alpha is True
        11 | ``learn.repara_``   bool        True           | Determine whether to use          |
           | ``meterization``                               | reparameterization trick.         |
        12 | ``learn.``          bool        False          | Determine whether to use          | Temperature parameter
           | ``auto_alpha``                                 | auto temperature parameter        | determines the
           |                                                | `alpha`.                          | relative importance
           |                                                |                                   | of the entropy term
           |                                                |                                   | against the reward.
        13 | ``learn.-``         bool        False          | Determine whether to ignore       | Use ignore_done only
           | ``ignore_done``                                | done flag.                        | in halfcheetah env.
        14 | ``learn.-``         float       0.005          | Used for soft update of the       | aka. Interpolation
           | ``target_theta``                               | target network.                   | factor in polyak aver
           |                                                |                                   | aging for target
           |                                                |                                   | networks.
        == ====================  ========    =============  =================================  =======================
+    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
-        type='sac',
-        # (bool) Whether to use cuda for network.
+        type='cql',
+        # (bool) Whether to use cuda for policy.
        cuda=False,
-        # (bool type) on_policy: Determine whether on-policy or off-policy.
+        # (bool) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
-        # Default False in SAC.
        on_policy=False,
-        multi_agent=False,
-        # (bool type) priority: Determine whether to use priority in buffer sample.
-        # Default False in SAC.
+        # (bool) priority: Determine whether to use priority in buffer sample.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples(randomly collected) in replay buffer when training starts.
-        # Default 10000 in SAC.
        random_collect_size=10000,
        model=dict(
            # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation.
# Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one . # Default to True. twin_critic=True, - - # (bool type) value_network: Determine whether to use value network as the - # original SAC paper (arXiv 1801.01290). - # using value_network needs to set learning_rate_value, learning_rate_q, - # and learning_rate_policy in `cfg.policy.learn`. - # Default to False. - # value_network=False, - # (str type) action_space: Use reparameterization trick for continous action action_space='reparameterization', - # (int) Hidden size for actor network head. actor_head_hidden_size=256, - # (int) Hidden size for critic network head. critic_head_hidden_size=256, ), + # learn_mode config learn=dict( - - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... update_per_collect=1, # (int) Minibatch size for gradient descent. batch_size=256, - - # (float type) learning_rate_q: Learning rate for soft q network. - # Default to 3e-4. - # Please set to 1e-3, when model.value_network is True. + # (float) learning_rate_q: Learning rate for soft q network. learning_rate_q=3e-4, - # (float type) learning_rate_policy: Learning rate for policy network. - # Default to 3e-4. - # Please set to 1e-3, when model.value_network is True. + # (float) learning_rate_policy: Learning rate for policy network. learning_rate_policy=3e-4, - # (float type) learning_rate_value: Learning rate for value network. - # `learning_rate_value` should be initialized, when model.value_network is True. - # Please set to 3e-4, when model.value_network is True. - learning_rate_value=3e-4, - - # (float type) learning_rate_alpha: Learning rate for auto temperature parameter `\alpha`. - # Default to 3e-4. + # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``. learning_rate_alpha=3e-4, - # (float type) target_theta: Used for soft update of the target network, + # (float) target_theta: Used for soft update of the target network, # aka. Interpolation factor in polyak averaging for target networks. - # Default to 0.005. target_theta=0.005, # (float) discount factor for the discounted sum of rewards, aka. gamma. discount_factor=0.99, - - # (float type) alpha: Entropy regularization coefficient. + # (float) alpha: Entropy regularization coefficient. # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`. # Default to 0.2. alpha=0.2, - - # (bool type) auto_alpha: Determine whether to use auto temperature parameter `\alpha` . + # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha` . # Temperature parameter determines the relative importance of the entropy term against the reward. # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. # Default to False. # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`. auto_alpha=True, - # (bool type) log_space: Determine whether to use auto `\alpha` in log space. + # (bool) log_space: Determine whether to use auto `\alpha` in log space. log_space=True, # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. 
@@ -162,46 +133,30 @@ class CQLPolicy(SACPolicy):
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
-            # (float) Weight uniform initialization range in the last output layer
+            # (float) Weight uniform initialization range in the last output layer.
            init_w=3e-3,
-            # (int) The numbers of action sample each at every state s from a uniform-at-random
+            # (int) The number of actions sampled uniformly at random at each state s.
            num_actions=10,
            # (bool) Whether use lagrange multiplier in q value loss.
            with_lagrange=False,
-            # (float) The threshold for difference in Q-values
+            # (float) The threshold for difference in Q-values.
            lagrange_thresh=-1,
            # (float) Loss weight for conservative item.
            min_q_weight=1.0,
            # (bool) Whether to use entropy in target q.
            with_q_entropy=False,
        ),
-        collect=dict(
-            # (int) Cut trajectories into pieces with length "unroll_len".
-            unroll_len=1,
-        ),
-        eval=dict(),
-        other=dict(
-            replay_buffer=dict(
-                # (int type) replay_buffer_size: Max size of replay buffer.
-                replay_buffer_size=1000000,
-                # (int type) max_use: Max use times of one data in the buffer.
-                # Data will be removed once used for too many times.
-                # Default to infinite.
-                # max_use=256,
-            ),
-        ),
+        eval=dict(),  # for compatibility
    )

    def _init_learn(self) -> None:
-        r"""
+        """
        Overview:
            Learn mode init method. Called by ``self.__init__``.
-            Init q, value and policy's optimizers, algorithm config, main and target models.
+            Init q and policy's optimizers, algorithm config, main and target models.
        """
-        # Init
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
-        self._value_network = False
        self._twin_critic = self._cfg.model.twin_critic
        self._num_actions = self._cfg.learn.num_actions
@@ -235,11 +190,6 @@ def _init_learn(self) -> None:
            self._model.critic_head[-1].last.bias.data.uniform_(-init_w, init_w)

        # Optimizers
-        if self._value_network:
-            self._optimizer_value = Adam(
-                self._model.value_critic.parameters(),
-                lr=self._cfg.learn.learning_rate_value,
-            )
        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
@@ -292,7 +242,7 @@ def _init_learn(self) -> None:
        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
-        r"""
+        """
        Overview:
            Forward and backward function of learn mode.
Arguments: @@ -511,8 +461,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: **loss_dict } - def _get_policy_actions(self, data: Dict, num_actions=10, epsilon: float = 1e-6) -> List: - + def _get_policy_actions(self, data: Dict, num_actions: int = 10, epsilon: float = 1e-6) -> List: # evaluate to get action distribution obs = data['obs'] obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1]) @@ -528,7 +477,7 @@ def _get_policy_actions(self, data: Dict, num_actions=10, epsilon: float = 1e-6) return action, log_prob.view(-1, num_actions, 1) - def _get_q_value(self, data: Dict, keep=True) -> torch.Tensor: + def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor: new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value'] if self._twin_critic: new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value] @@ -540,16 +489,17 @@ def _get_q_value(self, data: Dict, keep=True) -> torch.Tensor: @POLICY_REGISTRY.register('discrete_cql') -class DiscreteCQLPolicy(DQNPolicy): +class DiscreteCQLPolicy(QRDQNPolicy): """ - Overview: - Policy class of discrete CQL algorithm in discrete environments. + Overview: + Policy class of discrete CQL algorithm in discrete action space environments. + Paper link: https://arxiv.org/abs/2006.04779. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). type='discrete_cql', - # (bool) Whether to use cuda for network. + # (bool) Whether to use cuda for policy. cuda=False, # (bool) Whether the RL algorithm is on-policy or off-policy. on_policy=False, @@ -559,53 +509,31 @@ class DiscreteCQLPolicy(DQNPolicy): discount_factor=0.97, # (int) N-step reward for target q_value estimation nstep=1, + # learn_mode config learn=dict( - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... update_per_collect=1, + # (int) Minibatch size for one gradient descent. batch_size=64, + # (float) Learning rate for soft q network. learning_rate=0.001, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== # (int) Frequence of target network update. target_update_freq=100, - # (bool) Whether ignore done(usually for max step termination env) + # (bool) Whether ignore done(usually for max step termination env). ignore_done=False, # (float) Loss weight for conservative item. min_q_weight=1.0, ), - # collect_mode config - collect=dict( - # (int) Cut trajectories into pieces with length "unroll_len". - unroll_len=1, - ), - eval=dict(), - # other config - other=dict( - # Epsilon greedy with decay. - eps=dict( - # (str) Decay type. Support ['exp', 'linear']. - type='exp', - start=0.95, - end=0.1, - # (int) Decay length(env step) - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=10000, ) - ), + eval=dict(), # for compatibility ) - def default_model(self) -> Tuple[str, List[str]]: - return 'qrdqn', ['ding.model.template.q_learning'] - def _init_learn(self) -> None: - r""" - Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config, main and target models. - """ + """ + Overview: + Learn mode init method. Called by ``self.__init__``. 
+ Init the optimizer, algorithm config, main and target models. + """ self._min_q_weight = self._cfg.learn.min_q_weight self._priority = self._cfg.priority # Optimizer @@ -627,14 +555,14 @@ def _init_learn(self) -> None: self._target_model.reset() def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" - Overview: - Forward and backward function of learn mode. - Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] - Returns: - - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. - """ + """ + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] + Returns: + - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + """ data = default_preprocess_learn( data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True ) @@ -701,70 +629,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: # '[histogram]action_distribution': data['action'], } - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _init_collect(self) -> None: - r""" - Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. - Enable the eps_greedy_sample - """ - self._unroll_len = self._cfg.collect.unroll_len - self._gamma = self._cfg.discount_factor # necessary for parallel - self._nstep = self._cfg.nstep # necessary for parallel - self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample') - self._collect_model.reset() - - def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + def _monitor_vars_learn(self) -> List[str]: """ Overview: - Forward computation graph of collect mode(collect training data), with eps_greedy for exploration. - Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step. + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \ - env and the constructing of transition. - ArgumentsKeys: - - necessary: ``obs`` - ReturnsKeys - - necessary: ``logit``, ``action`` + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 
""" - data_id = list(data.keys()) - data = default_collate(list(data.values())) - if self._cuda: - data = to_device(data, self._device) - self._collect_model.eval() - with torch.no_grad(): - output = self._collect_model.forward(data, eps=eps) - if self._cuda: - output = to_device(output, 'cpu') - output = default_decollate(output) - return {i: d for i, d in zip(data_id, output)} - - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - r""" - Overview: - Get the trajectory and the n step return data, then sample from the n_step return data - Arguments: - - data (:obj:`list`): The trajectory's cache - Returns: - - samples (:obj:`dict`): The training samples generated - """ - data = get_nstep_return_data(data, self._nstep, gamma=self._gamma) - return get_train_sample(data, self._unroll_len) - - def _monitor_vars_learn(self) -> List[str]: return ['cur_lr', 'total_loss', 'q_target', 'q_value'] diff --git a/ding/policy/dt.py b/ding/policy/dt.py index adef441820..c630c8949d 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -1,7 +1,4 @@ -"""The code is adapted from https://github.com/nikhilbarhate99/min-decision-transformer -""" - -from typing import List, Dict, Any, Tuple +from typing import List, Dict, Any, Tuple, Optional from collections import namedtuple import torch.nn.functional as F import torch @@ -15,10 +12,10 @@ @POLICY_REGISTRY.register('dt') class DTPolicy(Policy): - r""" + """ Overview: Policy class of Decision Transformer algorithm in discrete environments. - Paper link: https://arxiv.org/abs/2106.01345 + Paper link: https://arxiv.org/abs/2106.01345. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -42,10 +39,22 @@ class DTPolicy(Policy): ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ + ``ding.model.template.q_learning``. + """ return 'dt', ['ding.model.template.dt'] def _init_learn(self) -> None: - r""" + """ Overview: Learn mode init method. Called by ``self.__init__``. Init the optimizer, algorithm config, main and target models. @@ -84,13 +93,13 @@ def _init_learn(self) -> None: self.max_env_score = -1.0 def _forward_learn(self, data: list) -> Dict[str, Any]: - r""" - Overview: - Forward and backward function of learn mode. - Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] - Returns: - - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + """ + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] + Returns: + - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. """ self._learn_model.train() @@ -145,7 +154,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: } def _init_eval(self) -> None: - r""" + """ Overview: Evaluate mode init method. Called by ``self.__init__``, initialize eval_model. 
""" @@ -196,6 +205,22 @@ def _init_eval(self) -> None: ) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance, such as interacting with envs. \ + Forward means that the policy gets some input data (current obs/return-to-go and historical information) \ + from the envs and then returns the output data, such as the action to interact with the envs. \ + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs and \ + reward to calculate running return-to-go. The key of the dict is environment id and the value is the \ + corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + Decision Transformer will do different operations for different types of envs in evaluation. + """ # save and forward data_id = list(data.keys()) @@ -279,7 +304,17 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _reset_eval(self, data_id: List[int] = None) -> None: + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for eval mode when necessary, such as the historical info of transformer \ + for decision transformer. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes in evaluation in ``data_id`` will have different history. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + specified by ``data_id``. + """ # clean data if data_id is None: self.t = [0 for _ in range(self.eval_batch_size)] @@ -339,21 +374,14 @@ def _reset_eval(self, data_id: List[int] = None) -> None: self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - # 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None: - self._eval_model.load_state_dict(state_dict) - def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return ['cur_lr', 'action_loss'] def _init_collect(self) -> None: