polish(pu): polish comments in _sample_orig_reanalyze_batch
puyuan1996 committed Oct 16, 2024
1 parent ea16e48 commit b4ae014
128 changes: 54 additions & 74 deletions lzero/mcts/buffer/game_buffer.py
@@ -146,10 +146,12 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:

# print(f'len(game_segment)=:len(game_segment.action_segment): {len(game_segment)}')
# print(f'len(game_segment.obs_segment): {game_segment.obs_segment.shape[0]}')
# if pos_in_game_segment >= self._cfg.game_segment_length:
# pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()

# # TODO (0923, in testing): the child_visits beyond self._cfg.game_segment_length may not have been updated; to obtain more valid samples, should self._cfg.game_segment_length be made slightly larger?
# In the reanalysis phase, `pos_in_game_segment` should be a multiple of `num_unroll_steps`.
# Indices exceeding `game_segment_length` are padded with the next segment and are not updated
# in the current implementation. Therefore, we need to sample `pos_in_game_segment` within
# [0, game_segment_length - num_unroll_steps] to avoid padded data.
# TODO: Consider increasing `self._cfg.game_segment_length` to ensure sampling efficiency.
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
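For context, a minimal sketch of the resampling rule above, using hypothetical config values (`game_segment_length=400` and `num_unroll_steps=5` are illustrative, not the repository defaults):

import numpy as np

# Hypothetical config values, for illustration only.
game_segment_length = 400
num_unroll_steps = 5

# Start positions in [game_segment_length - num_unroll_steps, game_segment_length)
# would unroll into padded, never-updated data, so they are resampled uniformly
# from the valid range [0, game_segment_length - num_unroll_steps).
pos_in_game_segment = 398  # an invalid start position in this example
if pos_in_game_segment >= game_segment_length - num_unroll_steps:
    pos_in_game_segment = np.random.choice(game_segment_length - num_unroll_steps, 1).item()
assert 0 <= pos_in_game_segment < game_segment_length - num_unroll_steps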

@@ -164,49 +166,78 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple:
"""
Overview:
sample orig_data that contains:
game_segment_list: a list of game segments
pos_in_game_segment_list: transition index in game (relative index)
batch_index_list: the index of start transition of sampled minibatch in replay buffer
weights_list: the weight concerning the priority
make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
This function samples a batch of game segments for reanalysis from the replay buffer.
It uses priority sampling based on the `reanalyze_time` of each game segment, with segments
that have been reanalyzed more frequently receiving lower priority.
The function returns a tuple containing information about the sampled game segments,
including their positions within each segment and the time the batch was created.
Arguments:
- batch_size (:obj:`int`): batch size
- batch_size (:obj:`int`):
The number of samples to draw in this batch.
Returns:
- Tuple:
A tuple containing the following elements:
- game_segment_list: A list of the sampled game segments.
- pos_in_game_segment_list: A list of indices representing the position of each transition
within its corresponding game segment.
- batch_index_list: The indices of the sampled game segments in the replay buffer.
- make_time: A list of timestamps (set to `0` in this implementation) indicating when
the batch was created.
Key Details:
1. **Priority Sampling**:
Game segments are sampled based on a probability distribution calculated using
the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently
are less likely to be selected.
2. **Segment Slicing**:
Each selected game segment is sampled at regular intervals determined by the
`num_unroll_steps` parameter. Exactly `samples_per_segment` transitions are sampled
from each selected segment.
3. **Handling Extra Samples**:
If the `batch_size` is not perfectly divisible by the number of samples per segment,
additional segments are sampled to make up the difference.
4. **Reanalyze Time Update**:
The `reanalyze_time` attribute of each sampled game segment is incremented to reflect
that it has been selected for reanalysis again.
Raises:
- ValueError:
If the `game_segment_length` is too small to accommodate the `num_unroll_steps`.
"""
assert self._beta > 0
train_sample_num = len(self.game_segment_buffer)
assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be no more than 0.75."
valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition)

# Calculate how many times each game_segment can be sampled
# Calculate the number of samples per segment
samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps

# Ensure the batch_size can be distributed across multiple game_segments
# Ensure each game segment is long enough to yield at least one unroll window
if samples_per_segment == 0:
raise ValueError("The game segment length is too small for num_unroll_steps.")
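# Illustration with hypothetical values: game_segment_length=4 and
# num_unroll_steps=5 would give samples_per_segment = 4 // 5 == 0,
# which triggers the ValueError above.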

# Calculate the number of samples needed from each game_segment
# Calculate how many game segments are needed to assemble the batch
batch_size_per_segment = batch_size // samples_per_segment

# If batch_size is not evenly divisible, handle the remainder
# If the batch size is not evenly divisible, process the remainder part
extra_samples = batch_size % samples_per_segment

# Use the reanalyze_time in the game_segment_buffer to generate weights
# We use the reanalyze_time in the game_segment_buffer to generate weights
reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]])

# Compute weights: the larger the reanalyze_time, the smaller the weight (using exp(-reanalyze_time))
# Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time))
base_decay_rate = 100
decay_rate = base_decay_rate / valid_sample_num
weights = np.exp(-decay_rate * reanalyze_times)

# Normalize the weights into a probability distribution
# Normalize the weights to a probability distribution
probabilities = weights / np.sum(weights)

# Sample according to the newly generated probability distribution
# Sample game segments according to the probabilities
selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False,
p=probabilities)

# If there are extra samples to allocate, randomly select some game_segments and sample once more
# If there are extra samples to be allocated, randomly select some game segments and sample again
if extra_samples > 0:
extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=False, p=probabilities)
selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments))
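A standalone numeric sketch of the priority weighting above (the buffer size and `reanalyze_time` values are hypothetical, and only five segments' counts are shown):

import numpy as np

valid_sample_num = 1000  # hypothetical number of valid segments
reanalyze_times = np.array([0, 1, 2, 5, 10])  # hypothetical reanalyze counts

base_decay_rate = 100
decay_rate = base_decay_rate / valid_sample_num  # 0.1
weights = np.exp(-decay_rate * reanalyze_times)
probabilities = weights / np.sum(weights)
# probabilities ≈ [0.270, 0.245, 0.221, 0.164, 0.099]:
# segments reanalyzed more often receive a lower selection probability.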
@@ -219,10 +250,10 @@ def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple:
game_segment_idx -= self.base_idx
game_segment = self.game_segment_buffer[game_segment_idx]

# Update reanalyze_time, incrementing it only once
# Update reanalyze_time only once
game_segment.reanalyze_time += 1

# Sampling positions should be 0, 0 + num_unroll_steps, ... (integer multiples of num_unroll_steps)
# The sampling position should be 0, 0 + num_unroll_steps, ... (integer multiples of num_unroll_steps)
for i in range(samples_per_segment):
game_segment_list.append(game_segment)
pos_in_game_segment = i * self._cfg.num_unroll_steps
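The start positions generated by this loop are exactly the integer multiples of `num_unroll_steps`; a small sketch with hypothetical sizes:

# Hypothetical sizes: game_segment_length=400, num_unroll_steps=5.
num_unroll_steps = 5
samples_per_segment = 400 // num_unroll_steps  # 80
positions = [i * num_unroll_steps for i in range(samples_per_segment)]
assert positions[0] == 0 and positions[-1] == 395  # 0, 5, 10, ..., 395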
@@ -231,63 +262,12 @@ def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple:
pos_in_game_segment_list.append(pos_in_game_segment)
batch_index_list.append(game_segment_idx)

# Generate the batch creation time
# make_time = [time.time() for _ in range(len(batch_index_list))]
# Set the make_time for each sample (set to 0 for now, but can be the actual time if needed).
make_time = [0. for _ in range(len(batch_index_list))]

orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time)
return orig_data
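A usage sketch of the tuple returned above (the `buffer` object and `batch_size` value are hypothetical; the empty fourth element mirrors the `[]` placed into `orig_data`):

# Hypothetical usage; `buffer` is assumed to be an initialized game buffer.
game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time = \
    buffer._sample_orig_reanalyze_batch(batch_size=256)
assert weights_list == []  # this implementation returns an empty weights list
assert len(game_segment_list) == len(pos_in_game_segment_list) == len(make_time)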

# def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple:
# """
# Overview:
# sample orig_data that contains:
# game_segment_list: a list of game segments
# pos_in_game_segment_list: transition index in game (relative index)
# batch_index_list: the index of start transition of sampled minibatch in replay buffer
# weights_list: the weight concerning the priority
# make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
# Arguments:
# - batch_size (:obj:`int`): batch size
# - beta (:obj:`float`): the PER parameter used when calculating priorities
# """
# assert self._beta > 0
# train_sample_num = (self.get_num_of_transitions()//self._cfg.num_unroll_steps)
#
# valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition)
# base_decay_rate = 5
# # decay rate becomes smaller as the number of samples increases
# decay_rate = base_decay_rate / valid_sample_num
# # Generate exponentially decaying weights (only for the first 3/4 of the samples)
# weights = np.exp(-decay_rate * np.arange(valid_sample_num))
# # Normalize the weights to a probability distribution
# probabilities = weights / np.sum(weights)
# batch_index_list = np.random.choice(valid_sample_num, batch_size, replace=False, p=probabilities)
#
# if self._cfg.reanalyze_outdated is True:
# # NOTE: used in reanalyze part
# batch_index_list.sort()
#
# game_segment_list = []
# pos_in_game_segment_list = []
#
# for idx in batch_index_list:
# # Select the first step of each sequence of length <self._cfg.num_unroll_steps> as the starting position
# game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx*self._cfg.num_unroll_steps]
# game_segment_idx -= self.base_idx
# game_segment = self.game_segment_buffer[game_segment_idx]
# game_segment_list.append(game_segment)
# # TODO: check the correctness of the following code
# if pos_in_game_segment >= self._cfg.game_segment_length:
# pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
# pos_in_game_segment_list.append(pos_in_game_segment)
#
#
# make_time = [time.time() for _ in range(len(batch_index_list))]
#
# orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time)
# return orig_data
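For comparison, a minimal side-by-side sketch of the two weighting schemes (the sizes and `reanalyze_times` values are hypothetical): the commented-out variant above decays weights with the transition's position in the buffer, while the current version decays them with each segment's `reanalyze_time`.

import numpy as np

valid_sample_num = 8  # hypothetical

# Old scheme: weights decay with buffer index (base_decay_rate = 5),
# so entries earlier in the buffer are favored.
old_weights = np.exp(-(5 / valid_sample_num) * np.arange(valid_sample_num))

# New scheme: weights decay with each segment's reanalyze count
# (base_decay_rate = 100), so rarely-reanalyzed segments are favored.
reanalyze_times = np.array([3, 0, 1, 0, 2, 0, 1, 0])  # hypothetical
new_weights = np.exp(-(100 / valid_sample_num) * reanalyze_times)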

def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple:
"""
Overview:
