Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(whl): add SIL policy #675

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions ding/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc import serial_pipeline_pc
from .serial_entry_sil import serial_pipeline_sil
137 changes: 137 additions & 0 deletions ding/entry/serial_entry_sil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
import pickle
import time
from copy import deepcopy
from functools import partial
from typing import Union, Optional, List, Any, Tuple

import numpy as np
import torch
from ditk import logging
from tensorboardX import SummaryWriter

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector, create_serial_evaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from .utils import random_collect


def serial_pipeline_sil(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
        dynamic_seed: Optional[bool] = True,
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry for self-imitation learning (SIL). Each iteration
        collects fresh on-policy data, pushes it into the replay buffer, and
        trains the learner on a mixture of the new data and
        ``cfg.policy.sil_update_per_collect`` mini-batches replayed from the
        buffer. On exit, a summary of the run is pickled to
        ``<exp_name>/result.pkl``.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
        - dynamic_seed (:obj:`Optional[bool]`): Whether to set dynamic seed for collector envs.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    # SIL policies are registered with a '_command' suffix (see
    # command_mode_policy_instance.py), so the commander variant is selected here.
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed)
    # Evaluator envs are seeded statically so evaluation episodes are reproducible.
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

    # Create worker components: learner, collector, evaluator, replay buffer, commander.
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    evaluator = create_serial_evaluator(
        cfg.policy.eval.evaluator,
        env=evaluator_env,
        policy=policy.eval_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
    )
    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')

    # Bug fix: ``stop`` and ``eval_info`` are read after the loop when the run
    # summary is pickled, but were previously only assigned inside the
    # evaluation branch. Initialize them so the pipeline does not raise
    # NameError when it terminates via max_env_step/max_train_iter before the
    # first evaluation.
    stop, eval_info = False, []

    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect data by default config n_sample/n_episode
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

        # SIL trains on the fresh on-policy batch plus replayed batches.
        tot_train_data = {'new_data': new_data, 'replay_data': []}

        # Learn policy from collected data
        for i in range(cfg.policy.sil_update_per_collect):
            # Learner will train ``sil_update_per_collect`` times in one iteration.
            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
            if train_data is None:
                # It is possible that replay buffer's data count is too few to
                # train ``sil_update_per_collect`` times.
                logging.warning(
                    "Replay buffer's data can only train for {} steps. ".format(i) +
                    "You can modify data collect config, e.g. increasing n_sample, n_episode."
                )
                break
            tot_train_data['replay_data'].append(train_data)
        learner.train(tot_train_data, collector.envstep)
        if learner.policy.get_attribute('priority'):
            replay_buffer.update(learner.priority_info)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    # Learner's after_run hook.
    learner.call_hook('after_run')
    # Persist a run summary for downstream analysis scripts.
    with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
        eval_value_raw = [d['eval_episode_return'] for d in eval_info]
        final_data = {
            'stop': stop,
            'env_step': collector.envstep,
            'train_iter': learner.train_iter,
            # Explicit NaN when no evaluation ever ran (avoids numpy's
            # empty-mean RuntimeWarning while keeping the same nan result).
            'eval_value': float(np.mean(eval_value_raw)) if eval_value_raw else float('nan'),
            'eval_value_raw': eval_value_raw,
            'finish_time': time.ctime(),
        }
        pickle.dump(final_data, f)
    return policy
1 change: 1 addition & 0 deletions ding/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from .pc import ProcedureCloningBFSPolicy

from .bcq import BCQPolicy
from .sil import SILA2CPolicy, SILPPOPolicy

# new-type policy
from .ppof import PPOFPolicy
11 changes: 11 additions & 0 deletions ding/policy/command_mode_policy_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .sql import SQLPolicy
from .bc import BehaviourCloningPolicy
from .ibc import IBCPolicy
from .sil import SILA2CPolicy, SILPPOPolicy

from .dqfd import DQFDPolicy
from .r2d3 import R2D3Policy
Expand Down Expand Up @@ -432,3 +433,13 @@ def _get_setting_learn(self, command_info: dict) -> dict:

def _get_setting_eval(self, command_info: dict) -> dict:
return {}


@POLICY_REGISTRY.register('sil_a2c_command')
class SILA2CCommandModePolicy(SILA2CPolicy, DummyCommandModePolicy):
    """
    Overview:
        Command-mode wrapper for ``SILA2CPolicy``: combines the SIL A2C policy with
        ``DummyCommandModePolicy``, and is registered as ``sil_a2c_command`` so the
        serial pipeline can select it by appending ``'_command'`` to the policy type.
    """
    pass


@POLICY_REGISTRY.register('sil_ppo_command')
class SILPPOCommandModePolicy(SILPPOPolicy, DummyCommandModePolicy):
    """
    Overview:
        Command-mode wrapper for ``SILPPOPolicy``: combines the SIL PPO policy with
        ``DummyCommandModePolicy``, and is registered as ``sil_ppo_command`` so the
        serial pipeline can select it by appending ``'_command'`` to the policy type.
    """
    pass
Loading