fix(pu): fix train_entry_time to the same in ddp train
puyuan1996 committed Dec 20, 2024
1 parent 72a4875 commit 99b361d
Showing 6 changed files with 107 additions and 192 deletions.
41 changes: 28 additions & 13 deletions lzero/entry/train_unizero.py
@@ -22,6 +22,7 @@
from lzero.worker import MuZeroCollector as Collector
from .utils import random_collect
import torch.distributed as dist
from ding.utils import set_pkg_seed, get_rank, get_world_size


def train_unizero(
@@ -138,8 +139,17 @@ def train_unizero(

batch_size = policy._cfg.batch_size

rank = dist.get_rank()
if cfg.policy.multi_gpu:
# Get the current world_size and rank
world_size = get_world_size()
rank = get_rank()
else:
world_size = 1
rank = 0

while True:
# torch.cuda.empty_cache()

# Log the replay buffer's memory usage
# logging.info(f"Training iteration {learner.train_iter}: logging replay buffer memory usage...")
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
@@ -197,25 +207,28 @@ def train_unizero(
replay_buffer.remove_oldest_data_to_fit()
# logging.info(f"训练迭代 {learner.train_iter}: replay buffer 更新完成!")

# Synchronize the readiness state of all ranks before training
try:
dist.barrier()
# logging.info(f'Rank {rank}: passed the pre-training synchronization barrier')
except Exception as e:
logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
break
if world_size > 1:
# Synchronize the readiness state of all ranks before training
try:
dist.barrier()
# logging.info(f'Rank {rank}: passed the pre-training synchronization barrier')
except Exception as e:
logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
break

# Check whether there is enough data to start training
if collector.envstep > cfg.policy.train_start_after_envsteps:
if cfg.policy.sample_type == 'episode':
data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
else:
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size

if not data_sufficient:
logging.warning(f"训练迭代 {learner.train_iter}: replay buffer 数据不足,继续收集数据...")
# NOTE: 注意ddp训练时,不同rank可能有的replay buffer 数据不足,导致有的没有进入训练阶段,从而通信超时,需要确保同时进入训练阶段
logging.warning(f"Rank {rank}: 训练迭代 {learner.train_iter}: replay buffer 数据不足,继续收集数据...")
continue
# logging.info(f"Rank {dist.get_rank()}, update_per_collect:{update_per_collect}, 训练迭代 {learner.train_iter}: replay buffer 数据充足,开始训练!")

logging.info(f"Rank {rank}, 训练迭代 {learner.train_iter}: 开始训练!")

# Run multiple training rounds
for i in range(update_per_collect):
@@ -230,7 +243,8 @@ def train_unizero(
log_vars = learner.train(train_data, collector.envstep)
if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
logging.info(f"Rank {dist.get_rank()}, 训练迭代 {learner.train_iter}: 训练完成!")

logging.info(f"Rank {rank}, 训练迭代 {learner.train_iter}: 训练完成!")

policy.recompute_pos_emb_diff_and_clear_cache()

@@ -240,6 +254,7 @@ def train_unizero(
break

learner.call_hook('after_run')
wandb.finish()
if cfg.policy.use_wandb:
wandb.finish()
logging.info("===== 训练完成 =====")
return policy
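
A minimal sketch of the rank-handling pattern this commit introduces (outside the diff, under the assumption that torch.distributed is already initialized whenever cfg.policy.multi_gpu is set, and with a hypothetical run_one_train_round callable standing in for the learner.train loop):

import logging
import torch.distributed as dist
from ding.utils import get_rank, get_world_size

def ddp_aware_loop(cfg, replay_buffer, batch_size, run_one_train_round):
    # Resolve rank/world_size once; fall back to single-process values
    # when DDP is not enabled, so the same loop works in both modes.
    if cfg.policy.multi_gpu:
        world_size, rank = get_world_size(), get_rank()
    else:
        world_size, rank = 1, 0

    while True:
        # Only synchronize when there is more than one rank; an unconditional
        # dist.barrier() would fail or hang in single-process runs.
        if world_size > 1:
            dist.barrier()

        # Data-sufficiency check mirroring the transition-count branch in the diff.
        if replay_buffer.get_num_of_transitions() <= batch_size:
            logging.warning(f"Rank {rank}: replay buffer not ready, keep collecting...")
            continue

        run_one_train_round()

As the NOTE in the diff points out, each rank still evaluates the sufficiency check on its own buffer, so ranks must be kept from diverging at this point, otherwise the collective calls inside training will time out.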
144 changes: 54 additions & 90 deletions lzero/model/common.py
@@ -17,8 +17,7 @@
from ding.torch_utils import MLP, ResBlock
from ding.utils import SequenceType
from ditk import logging
from openai import OpenAI
from transformers import AutoTokenizer


# use dataclass to make the output of network more convenient to use
@dataclass
@@ -273,121 +272,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
f"You should transform the observation shape to 64 or 96 in the env.")

return output


"""
使用vllm的BAAI/bge-base-en-v1.5 server, 模型的输入为id 需要先decode回string
使用 google-bert/bert-base-uncased , 模型的输入为id
"""
class HFLanguageRepresentationNetwork(nn.Module):
def __init__(
self,
url: str = 'BAAI/bge-base-en-v1.5',
embedding_size: int = 768,
group_size: int = 8,
api_base: str = "http://10.119.30.189:8081/v1",
api_key: str = "EMPTY"
):
def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
"""
初始化语言表示网络,使用 vLLM 的 API 服务获取嵌入。
初始化语言表示网络
参数:
- url (str): vLLM 服务的模型名称,默认为 'BAAI/bge-base-en-v1.5'。
- url (str): 预训练 Hugging Face 模型的地址,默认为 'google-bert/bert-base-uncased'。
- embedding_size (int): 输出嵌入的维度大小,默认为 768。
- group_size (int): SimNorm 的组大小,默认为 8。
- api_base (str): vLLM API 服务器的基本 URL,默认为 "http://10.119.30.189:8081/v1"。
- api_key (str): API 密钥,默认值为 "EMPTY"。
"""
super().__init__()
self.url = url
self.embedding_size = embedding_size
self.api_base = api_base
self.api_key = api_key
from transformers import AutoModel
# Load the pretrained Hugging Face model
self.model = AutoModel.from_pretrained(url)

# Initialize the OpenAI client to connect to the vLLM API server
self.client = OpenAI(
api_key=api_key,
base_url=api_base,
)

# Get the model ID
models = self.client.models.list()
self.model_id = models.data[0].id if models.data else url

# Initialize the linear projection layer (if needed)
# Set the embedding dimension; if the target dimension is not 768, add a linear layer to project the embedding up or down
self.embedding_size = embedding_size
if self.embedding_size != 768:
self.embed_head = nn.Linear(768, self.embedding_size)
else:
self.embed_head = None

# Initialize SimNorm
self.sim_norm = SimNorm(simnorm_dim=group_size)

# Initialize the tokenizer, used to decode token indices back into strings
self.tokenizer = AutoTokenizer.from_pretrained(url)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
"""
前向传播,获取输入序列的语言表示。
参数:
- x (torch.Tensor): 输入的张量,形状为 [batch_size, seq_len],类型为 torch.long。
- x (torch.Tensor): 输入的张量,通常是序列的 token 索引,形状为 [batch_size, seq_len]。
- no_grad (bool): 是否在无梯度模式下运行,默认为 True。
返回:
- torch.Tensor: 经过处理的语言嵌入向量,形状为 [batch_size, embedding_size]。
"""
with torch.no_grad():
x = x.long() # Ensure the input tensor is of long (int64) type
if no_grad:
# Disable gradient computation in no_grad mode to save GPU memory
with torch.no_grad():
x = x.long() # Ensure the input tensor is of long (int64) type
outputs = self.model(x) # Get the model output

# The model's last_hidden_state has shape [batch_size, seq_len, hidden_size]
# We usually take the vector for the [CLS] token, i.e. outputs.last_hidden_state[:, 0, :]
cls_embedding = outputs.last_hidden_state[:, 0, :]

# If the target embedding_size is not 768, apply a linear projection
if self.embedding_size == 768:
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
cls_embedding = self.embed_head(cls_embedding)

# # Check the index range
# min_idx = x.min().item()
# max_idx = x.max().item()
# # print(f"min_idx: {min_idx}, max_idx: {max_idx}")
# assert min_idx >= 0, "Negative token indices found."
# assert max_idx < self.tokenizer.vocab_size, f"Token index {max_idx} exceeds vocab size {self.tokenizer.vocab_size}."

# Decode token indices back into strings
# Assumes each sample's [CLS] token is at position 0
# Adjust according to the actual setup
batch_size = x.size(0)
sentences: List[str] = []
for i in range(batch_size):
# Decode into a string
tokens = x[i].tolist()
sentence = self.tokenizer.decode(tokens, skip_special_tokens=True)
sentences.append(sentence)

# Call vLLM's embeddings API
response = self.client.embeddings.create(
input=sentences,
model=self.model_id,
)
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

# Extract the embeddings and convert them to a tensor
embeddings = [data.embedding for data in response.data] # List[List[float]]
embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=x.device) # [batch_size, 768]
return cls_embedding
else:
# Outside no_grad mode, keep gradient computation enabled
x = x.long() # Ensure the input tensor is of long (int64) type
outputs = self.model(x)
cls_embedding = outputs.last_hidden_state[:, 0, :]

# If the target embedding_size is not 768, apply a linear projection
if self.embedding_size == 768:
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
cls_embedding = self.embed_head(cls_embedding)

# Project up or down if needed
if self.embed_head is not None:
embeddings_tensor = self.embed_head(embeddings_tensor) # [batch_size, embedding_size]
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

# Apply SimNorm
embeddings_tensor = self.sim_norm(embeddings_tensor) # [batch_size, embedding_size]
return cls_embedding

return embeddings_tensor

def __getstate__(self):
state = self.__dict__.copy()
# Remove objects that cannot be serialized
del state['client']
del state['tokenizer']
return state

def __setstate__(self, state):
self.__dict__.update(state)
# Re-initialize objects that cannot be serialized
self.client = OpenAI(
api_key=self.api_key,
base_url=self.api_base,
)
self.tokenizer = AutoTokenizer.from_pretrained(self.url)


class RepresentationNetworkUniZero(nn.Module):
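For reference, a hedged usage sketch of the new Hugging Face path in HFLanguageRepresentationNetwork (assuming the class is imported from lzero.model.common as in this diff; the tokenizer call is illustrative only, since the network itself consumes already-tokenized ids):

import torch
from transformers import AutoTokenizer
from lzero.model.common import HFLanguageRepresentationNetwork

net = HFLanguageRepresentationNetwork(url='google-bert/bert-base-uncased', embedding_size=768)
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')

# Tokenize a small batch of text observations into fixed-length id tensors.
batch = tokenizer(["go north", "open the mailbox"],
                  padding=True, truncation=True, return_tensors='pt')
obs_ids = batch['input_ids']  # [batch_size, seq_len], torch.long

emb = net(obs_ids)  # no_grad=True by default: the BERT backbone stays frozen
assert emb.shape == (obs_ids.shape[0], 768)  # SimNorm-normalized [CLS] embedding

emb_trainable = net(obs_ids, no_grad=False)  # gradients flow into the backbone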
76 changes: 0 additions & 76 deletions lzero/model/common_noserve_input_id.py
@@ -354,82 +354,6 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:



class HFLanguageRepresentationNetwork_backup(nn.Module):
def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
"""
初始化语言表示网络
参数:
- url (str): 预训练 Hugging Face 模型的地址,默认为 'google-bert/bert-base-uncased'。
- embedding_size (int): 输出嵌入的维度大小,默认为 768。
"""
super().__init__()
from transformers import AutoModel
# Load the pretrained Hugging Face model
self.model = AutoModel.from_pretrained(url)

# Set the embedding dimension; if the target dimension is not 768, add a linear layer to project the embedding up or down
self.embedding_size = embedding_size
if self.embedding_size != 768:
self.embed_head = nn.Linear(768, self.embedding_size)

self.sim_norm = SimNorm(simnorm_dim=group_size)

def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
"""
前向传播,获取输入序列的语言表示。
参数:
- x (torch.Tensor): 输入的张量,通常是序列的 token 索引,形状为 [batch_size, seq_len]。
- no_grad (bool): 是否在无梯度模式下运行,默认为 True。
返回:
- torch.Tensor: 经过处理的语言嵌入向量,形状为 [batch_size, embedding_size]。
"""
if no_grad:
# Disable gradient computation in no_grad mode to save GPU memory
with torch.no_grad():
x = x.long() # Ensure the input tensor is of long (int64) type
outputs = self.model(x) # Get the model output

# The model's last_hidden_state has shape [batch_size, seq_len, hidden_size]
# We usually take the vector for the [CLS] token, i.e. outputs.last_hidden_state[:, 0, :]
cls_embedding = outputs.last_hidden_state[:, 0, :]

# If the target embedding_size is not 768, apply a linear projection
if self.embedding_size == 768:
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
cls_embedding = self.embed_head(cls_embedding)

# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
# Outside no_grad mode, keep gradient computation enabled
x = x.long() # Ensure the input tensor is of long (int64) type
outputs = self.model(x)
cls_embedding = outputs.last_hidden_state[:, 0, :]

# If the target embedding_size is not 768, apply a linear projection
if self.embedding_size == 768:
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
cls_embedding = self.embed_head(cls_embedding)

# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding


class RepresentationNetworkUniZero(nn.Module):

def __init__(
2 changes: 2 additions & 0 deletions requirements.txt
@@ -8,3 +8,5 @@ pytest
line_profiler
xxhash
einops
openai==1.57.1
jericho

