Commit

polish(pu): polish reanalyze in buffer
puyuan1996 committed Sep 26, 2024
1 parent dba9ca7 commit d5fff6d
Showing 6 changed files with 110 additions and 761 deletions.
3 changes: 1 addition & 2 deletions lzero/entry/train_unizero_reanalyze.py
@@ -57,8 +57,7 @@ def train_unizero_reanalyze(
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"

# Import the correct GameBuffer class based on the policy type
# game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
game_buffer_classes = {'unizero': 'UniZeroReGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}

GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])
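For reference, a minimal standalone sketch of how this dynamic import resolves the buffer class. The mapping mirrors the lines above; the concrete policy type and the final constructor call are illustrative assumptions, not part of this diff:

# Sketch: select the game-buffer class by policy type and import it from lzero.mcts.
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
policy_type = 'unizero'  # assumed value; in the entry this comes from create_cfg.policy.type
module = __import__('lzero.mcts', fromlist=[game_buffer_classes[policy_type]])
GameBuffer = getattr(module, game_buffer_classes[policy_type])
# The entry then instantiates it with the policy config, e.g. replay_buffer = GameBuffer(cfg.policy).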
1 change: 0 additions & 1 deletion lzero/mcts/buffer/__init__.py
@@ -4,7 +4,6 @@
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_sampled_muzero import SampledMuZeroGameBuffer
from .game_buffer_sampled_unizero import SampledUniZeroGameBuffer
from .game_buffer_rezero_uz import UniZeroReGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
from .game_buffer_rezero_mz import ReZeroMZGameBuffer
31 changes: 7 additions & 24 deletions lzero/mcts/buffer/game_buffer.py
@@ -125,9 +125,6 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
probs /= probs.sum()

# sample according to transition index
# TODO(pu): replace=True
# print(f"num transitions is {num_of_transitions}")
# print(f"length of probs is {len(probs)}")
batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False)

if self._cfg.reanalyze_outdated is True:
@@ -147,9 +144,7 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:

game_segment_list.append(game_segment)
# pos_in_game_segment_list.append(pos_in_game_segment)

# pos_in_game_segment_list.append(max(pos_in_game_segment, self._cfg.game_segment_length - self._cfg.num_unroll_steps))
# TODO
# TODO: check
if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item()
pos_in_game_segment_list.append(pos_in_game_segment)
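The position check above guards the unroll window: if a sampled transition sits too close to the end of its game segment, fewer than num_unroll_steps transitions remain to unroll, so the position is re-drawn uniformly from the valid range. A small standalone sketch with illustrative values (the config values here are assumptions):

import numpy as np

game_segment_length = 400   # illustrative cfg.game_segment_length
num_unroll_steps = 10       # illustrative cfg.num_unroll_steps
pos_in_game_segment = 395   # a sampled position near the segment end

# Re-draw uniformly from [0, game_segment_length - num_unroll_steps] when the
# sampled position would not leave room for a full unroll.
if pos_in_game_segment > game_segment_length - num_unroll_steps:
    pos_in_game_segment = np.random.choice(game_segment_length - num_unroll_steps + 1, 1).item()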
@@ -160,7 +155,7 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)
return orig_data

def _sample_orig_reanalyze_data_uz(self, batch_size: int) -> Tuple:
def _sample_orig_reanalyze_batch_data(self, batch_size: int) -> Tuple:
"""
Overview:
sample orig_data that contains:
@@ -176,16 +171,14 @@ def _sample_orig_reanalyze_data_uz(self, batch_size: int) -> Tuple:
assert self._beta > 0
train_sample_num = (self.get_num_of_transitions()//self._cfg.num_unroll_steps)

# TODO: only select from the first 3/4 of the samples
valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition)
# TODO: dynamically adjust the decay rate; assume it should be inversely proportional to valid_sample_num
base_decay_rate = 5 # base decay rate, can be set empirically
decay_rate = base_decay_rate / valid_sample_num # the decay rate shrinks as the number of samples grows
# Build exponentially decaying weights (restricted to the first 3/4 of the samples)
base_decay_rate = 5
# decay rate becomes smaller as the number of samples increases
decay_rate = base_decay_rate / valid_sample_num
# Generate exponentially decaying weights (only for the first 3/4 of the samples)
weights = np.exp(-decay_rate * np.arange(valid_sample_num))
# Normalize the weights into a probability distribution
# Normalize the weights to a probability distribution
probabilities = weights / np.sum(weights)
# Sample according to the probability distribution (restricted to the first 3/4)
batch_index_list = np.random.choice(valid_sample_num, batch_size, replace=False, p=probabilities)

if self._cfg.reanalyze_outdated is True:
@@ -203,12 +196,6 @@ def _sample_orig_reanalyze_data_uz(self, batch_size: int) -> Tuple:
game_segment_list.append(game_segment)
pos_in_game_segment_list.append(pos_in_game_segment)

# pos_in_game_segment_list.append(max(pos_in_game_segment, self._cfg.game_segment_length - self._cfg.num_unroll_steps))
# # TODO
# if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps:
# pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item()
# pos_in_game_segment_list.append(pos_in_game_segment)


make_time = [time.time() for _ in range(len(batch_index_list))]

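Read in isolation, the sampling scheme above draws reanalyze batches with exponentially decaying probabilities over the first reanalyze_partition fraction of sample indices. A standalone sketch, with illustrative values for the buffer size, unroll steps, and partition:

import numpy as np

num_transitions = 20000        # illustrative number of transitions in the buffer
num_unroll_steps = 10          # illustrative cfg.num_unroll_steps
reanalyze_partition = 0.75     # illustrative cfg.reanalyze_partition
batch_size = 64

train_sample_num = num_transitions // num_unroll_steps
valid_sample_num = int(train_sample_num * reanalyze_partition)

base_decay_rate = 5
decay_rate = base_decay_rate / valid_sample_num            # flatter decay for larger buffers
weights = np.exp(-decay_rate * np.arange(valid_sample_num))
probabilities = weights / np.sum(weights)                  # earlier indices receive higher probability
batch_index_list = np.random.choice(valid_sample_num, batch_size, replace=False, p=probabilities)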
@@ -250,10 +237,6 @@ def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple:

game_segment_list.append(game_segment)
pos_in_game_segment_list.append(pos_in_game_segment)
# TODO
# if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps:
# pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item()
# pos_in_game_segment_list.append(pos_in_game_segment)

make_time = [time.time() for _ in range(len(batch_index_list))]

7 changes: 3 additions & 4 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -88,11 +88,11 @@ def reanalyze_buffer(
policy._target_model.eval()
self.policy = policy
# obtain the current_batch and prepare target context
policy_re_context = self._make_batch_for_reanalyze(batch_size, 1)
policy_re_context = self._make_batch_for_reanalyze(batch_size)
# target policy
self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model)

def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]:
"""
Overview:
first sample orig_data through ``_sample_orig_data()``,
@@ -103,12 +103,11 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) ->
current_batch: the inputs of batch
Arguments:
- batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
- reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
Returns:
- context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
"""
# obtain the batch context from replay buffer
orig_data = self._sample_orig_reanalyze_data_uz(batch_size)
orig_data = self._sample_orig_reanalyze_batch_data(batch_size)
game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
batch_size = len(batch_index_list)
# obtain the context of reanalyzed policy targets
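Per the snippet above, the reanalyze path now always reanalyzes the whole sampled minibatch, so the reanalyze_ratio argument is dropped. A hypothetical usage sketch, assuming replay_buffer and policy are built by the train entry and that reanalyze_buffer keeps the (batch_size, policy) shape shown here:

# Hypothetical usage: reanalyze one fixed-size minibatch inside the buffer.
reanalyze_batch_size = 64  # assumed config value
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
# Internally this samples via _sample_orig_reanalyze_batch_data(batch_size) and then
# recomputes the policy targets with policy._target_model.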
