From 5c025426a1883a33c83e79c495e59a133b9525fa Mon Sep 17 00:00:00 2001
From: HarryXuancy <52876902+HarryXuancy@users.noreply.github.com>
Date: Mon, 16 Oct 2023 22:35:14 +0800
Subject: [PATCH] feature(xcy): add muzero config for connect4 (#107)

* polish(xcy):add muzero config for connect4

* polish(xcy):adjusting parameters in sp_mode
---
 .../config/connect4_muzero_bot_mode_config.py | 83 +++++++++++++++++++
 .../config/connect4_muzero_sp_mode_config.py  | 83 +++++++++++++++++++
 2 files changed, 166 insertions(+)
 create mode 100644 zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py
 create mode 100644 zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py

diff --git a/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py
new file mode 100644
index 000000000..2372453ba
--- /dev/null
+++ b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 5
+num_simulations = 50
+update_per_collect = 50
+reanalyze_ratio = 0.
+batch_size = 256
+max_env_step = int(5e5)
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+connect4_muzero_config = dict(
+    exp_name=
+    'data_mz_ctree/connect4_botmode_rulebot_seed0',
+    env=dict(
+        battle_mode='play_with_bot_mode',
+        bot_action_type='rule',
+        channel_last=True,
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        manager=dict(shared_memory=False, ),
+    ),
+    policy=dict(
+        model=dict(
+            observation_shape=(3, 6, 7),
+            action_space_size=7,
+            image_channel=3,
+            num_res_blocks=1,
+            num_channels=64,
+            support_scale=300,
+            reward_support_size=601,
+            value_support_size=601,
+        ),
+        cuda=True,
+        env_type='board_games',
+        game_segment_length=int(6 * 7 / 2),  # for battle_mode='play_with_bot_mode'
+        update_per_collect=update_per_collect,
+        batch_size=batch_size,
+        optim_type='Adam',
+        lr_piecewise_constant_decay=False,
+        learning_rate=0.003,
+        grad_clip_value=0.5,
+        num_simulations=num_simulations,
+        reanalyze_ratio=reanalyze_ratio,
+        # NOTE: In board games, td_steps is set large so that the value target is the final outcome.
+        td_steps=int(6 * 7 / 2),  # for battle_mode='play_with_bot_mode'
+        # NOTE: In board games, discount_factor is set to 1.
+        discount_factor=1,
+        n_episode=n_episode,
+        eval_freq=int(2e3),
+        replay_buffer_size=int(1e5),
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+    ),
+)
+connect4_muzero_config = EasyDict(connect4_muzero_config)
+main_config = connect4_muzero_config
+
+connect4_muzero_create_config = dict(
+    env=dict(
+        type='connect4',
+        import_names=['zoo.board_games.connect4.envs.connect4_env'],
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(
+        type='muzero',
+        import_names=['lzero.policy.muzero'],
+    ),
+)
+connect4_muzero_create_config = EasyDict(connect4_muzero_create_config)
+create_config = connect4_muzero_create_config
+
+if __name__ == "__main__":
+    from lzero.entry import train_muzero
+    # NOTE: seed is kept equal to the ``seed0`` suffix of exp_name so the experiment name describes the run.
+    train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py
new file mode 100644
index 000000000..0e4bf5d8a
--- /dev/null
+++ b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 5
+num_simulations = 50
+update_per_collect = 50
+reanalyze_ratio = 0.
+batch_size = 256
+max_env_step = int(5e5)
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+connect4_muzero_config = dict(
+    exp_name=
+    'data_mz_ctree/connect4_spmode_rulebot_seed0',
+    env=dict(
+        battle_mode='self_play_mode',
+        bot_action_type='rule',
+        channel_last=True,
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        manager=dict(shared_memory=False, ),
+    ),
+    policy=dict(
+        model=dict(
+            observation_shape=(3, 6, 7),
+            action_space_size=7,
+            image_channel=3,
+            num_res_blocks=1,
+            num_channels=64,
+            support_scale=300,
+            reward_support_size=601,
+            value_support_size=601,
+        ),
+        cuda=True,
+        env_type='board_games',
+        game_segment_length=int(6 * 7),  # for battle_mode='self_play_mode'
+        update_per_collect=update_per_collect,
+        batch_size=batch_size,
+        optim_type='Adam',
+        lr_piecewise_constant_decay=False,
+        learning_rate=0.003,
+        grad_clip_value=0.5,
+        num_simulations=num_simulations,
+        reanalyze_ratio=reanalyze_ratio,
+        # NOTE: In board games, td_steps is set large so that the value target is the final outcome.
+        td_steps=int(6 * 7),  # for battle_mode='self_play_mode'
+        # NOTE: In board games, discount_factor is set to 1.
+        discount_factor=1,
+        n_episode=n_episode,
+        eval_freq=int(2e3),
+        replay_buffer_size=int(1e5),
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+    ),
+)
+connect4_muzero_config = EasyDict(connect4_muzero_config)
+main_config = connect4_muzero_config
+
+connect4_muzero_create_config = dict(
+    env=dict(
+        type='connect4',
+        import_names=['zoo.board_games.connect4.envs.connect4_env'],
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(
+        type='muzero',
+        import_names=['lzero.policy.muzero'],
+    ),
+)
+connect4_muzero_create_config = EasyDict(connect4_muzero_create_config)
+create_config = connect4_muzero_create_config
+
+if __name__ == "__main__":
+    from lzero.entry import train_muzero
+    # NOTE: seed is kept equal to the ``seed0`` suffix of exp_name so the experiment name describes the run.
+    train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)