Fine-tuning on the ChID data fails under Python 3.6.8, torch 1.7.1+cu110, CUDA 11.1 with RTX 3090 GPUs #10

Closed
zhenhao-huang opened this issue Jan 16, 2021 · 9 comments


@zhenhao-huang

Running the fine-tuning scripts for all three model sizes (finetune_chid_small.sh, finetune_chid_medium.sh, finetune_chid_large.sh) produces the following error:
0%| | 0/577157 [00:00<?, ?it/s]
Traceback (most recent call last):
File "finetune_chid.py", line 375, in
main()
File "finetune_chid.py", line 292, in main
output = model(**batch)
File "/home/klein/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/klein/anaconda3/lib/python3.6/site-packages/deepspeed/runtime/engine.py", line 854, in forward
loss = self.module(*inputs, **kwargs)
File "/home/klein/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/klein/Desktop/CPM/CPM-Finetune/model/distributed.py", line 78, in forward
return self.module(*inputs, **kwargs)
File "/home/klein/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/klein/Desktop/CPM/CPM-Finetune/fp16/fp16.py", line 65, in forward
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
File "/home/klein/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/klein/Desktop/CPM/CPM-Finetune/model/gpt2_modeling.py", line 97, in forward
transformer_output = self.transformer(embeddings, attention_mask)
File "/home/klein/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/klein/Desktop/CPM/CPM-Finetune/mpu/transformer.py", line 416, in forward
hidden_states = layer(hidden_states, attention_mask)
File "/home/klein/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/klein/Desktop/CPM/CPM-Finetune/mpu/transformer.py", line 294, in forward
mlp_output = self.mlp(layernorm_output)
File "/home/klein/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/klein/Desktop/CPM/CPM-Finetune/mpu/transformer.py", line 209, in forward
intermediate_parallel = gelu(intermediate_parallel)
File "/home/klein/Desktop/CPM/CPM-Finetune/mpu/transformer.py", line 166, in gelu
return gelu_impl(x)
RuntimeError: default_program(56): error: identifier "aten_mul_flat__1" is undefined

default_program(57): error: no operator "=" matches these operands
operand types are: half = float

default_program(61): error: identifier "aten_add_flat__1" is undefined

default_program(62): error: no operator "=" matches these operands
operand types are: half = float

4 errors detected in the compilation of "default_program".

nvrtc compilation failed:

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

template<typename T>
__device__ T maximum(T a, T b) {
return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
return isnan(a) ? a : (a < b ? a : b);
}

#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#if defined(__cplusplus)
struct __align__(2) __half {
__host__ __device__ __half() { }

protected:
unsigned short __x;
};

/* All intrinsic functions are only available to nvcc compilers */
#if defined(__CUDACC__)
/* Definitions of intrinsics */
__device__ __half __float2half(const float f) {
__half val;
asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
return val;
}

__device__ float __half2float(const __half h) {
float val;
asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
return val;
}
#endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __HALF_TO_US
#undef __HALF_TO_CUS

typedef __half half;

extern "C" __global__
void func_1(half* t0, half* aten_mul_flat, half* aten_add_flat, half* aten_tanh_flat, half* aten_add_flat_1, half* aten_mul_flat_1, half* aten_mul_flat_2, half* aten_mul_flat_3) {
{
float v = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
aten_mul_flat_1[512 * blockIdx.x + threadIdx.x] = __float2half(v * 0.04471499845385551f);
float aten_mul_flat = __half2float(aten_mul_flat_2[512 * blockIdx.x + threadIdx.x]);
aten_mul_flat__1 = __float2half(__half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]) * 0.7978845834732056f);
aten_mul_flat_2[512 * blockIdx.x + threadIdx.x] = aten_mul_flat;
float v_1 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
aten_mul_flat_3[512 * blockIdx.x + threadIdx.x] = __float2half(v_1 * 0.5f);
float aten_add_flat = __half2float(aten_add_flat_1[512 * blockIdx.x + threadIdx.x]);
aten_add_flat__1 = __float2half((__half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]) * 0.04471499845385551f) * __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]) + 1.f);
aten_add_flat_1[512 * blockIdx.x + threadIdx.x] = aten_add_flat;
float v_2 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
float v_3 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
float v_4 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
aten_tanh_flat[512 * blockIdx.x + threadIdx.x] = __float2half(tanhf((v_2 * 0.7978845834732056f) * ((v_3 * 0.04471499845385551f) * v_4 + 1.f)));
float v_5 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
float v_6 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
float v_7 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
aten_add_flat[512 * blockIdx.x + threadIdx.x] = __float2half((tanhf((v_5 * 0.7978845834732056f) * ((v_6 * 0.04471499845385551f) * v_7 + 1.f))) + 1.f);
float v_8 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
float v_9 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
float v_10 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
float v_11 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 5120 + 5120 * (((512 * blockIdx.x + threadIdx.x) / 5120) % 725)]);
aten_mul_flat[512 * blockIdx.x + threadIdx.x] = __float2half((v_8 * 0.5f) * ((tanhf((v_9 * 0.7978845834732056f) * ((v_10 * 0.04471499845385551f) * v_11 + 1.f))) + 1.f));
}
}
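
For context: the gelu at mpu/transformer.py line 166 in the trace is a TorchScript-scripted tanh-approximation GELU, which is what the JIT fuser tries to compile into the half-precision kernel above. A minimal sketch of that function, assuming it matches the Megatron-LM implementation the repo builds on (names and constants here come from Megatron and may differ slightly in CPM-Finetune):

import torch

@torch.jit.script
def gelu_impl(x):
    # tanh approximation of GELU; @torch.jit.script makes the JIT fuser
    # generate a fused CUDA kernel for this expression at runtime
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))

The constants 0.7978845834732056f and 0.04471499845385551f in the failing kernel are simply the float32 renderings of the values above, which is how the NVRTC error maps back to this function.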

@zhenhao-huang
Author

Solved. It was a PyTorch version issue; the Preview (Nightly) build, or any release after 1.7.1, fixes it.
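
For anyone who cannot upgrade right away, a workaround sketch (untested, function name taken from the trace): keep the gelu out of TorchScript so NVRTC is never invoked, or switch off the TensorExpr fuser:

import torch

# Option 1 (assumes editing mpu/transformer.py): define gelu without
# @torch.jit.script so no fused CUDA kernel is compiled at runtime.
def gelu_impl(x):
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))

# Option 2: turn off the TensorExpr fuser before the first forward pass.
# _jit_set_texpr_fuser_enabled is an internal, undocumented flag and may
# behave differently across PyTorch versions.
torch._C._jit_set_texpr_fuser_enabled(False)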

@keezen

keezen commented Jan 18, 2021

Solved. It was a PyTorch version issue; the Preview (Nightly) build, or any release after 1.7.1, fixes it.

What accuracy can you reach on ChID? And roughly what value does the loss converge to?

@zhenhao-huang
Author

@keezen With two 3090s the large model doesn't run for now; I'll try turning on cpu_offload and see whether it fits.
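
For reference, cpu_offload is enabled in the DeepSpeed JSON config rather than on the command line. A sketch of the relevant fragment, assuming a DeepSpeed release from around this time (the batch-size values are placeholders, and later DeepSpeed versions renamed these keys):

{
  "train_micro_batch_size_per_gpu": 2,
  "gradient_accumulation_steps": 1,
  "fp16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "cpu_offload": true
  }
}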

@keezen

keezen commented Jan 19, 2021

@keezen With two 3090s the large model doesn't run for now; I'll try turning on cpu_offload and see whether it fits.

OK. Fine-tuning on ChID with 4 V100s, my final accuracy was only 10%. If you make any progress, please share it here, thanks!

@LittleRedLynn

Hey, are you running this inside a container? If so, could you package up and share your environment? Thanks a lot.

@zhenhao-huang
Author

@LittleRedLynn I'm running it locally. The environment is the PyTorch preview build with cu110 and CUDA 11.1; apex has to be compiled separately, see NVIDIA/apex#988.
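
A quick way to confirm the local stack supports the 3090 (Ampere, sm_86); the values in the comments are what this particular setup should report, not guaranteed output:

import torch

print(torch.__version__)                     # expect a cu110/cu111 build here
print(torch.version.cuda)                    # CUDA version PyTorch was built against
print(torch.cuda.is_available())             # True if the driver/runtime pair works
print(torch.cuda.get_device_name(0))         # e.g. "GeForce RTX 3090"
print(torch.cuda.get_device_capability(0))   # (8, 6) for Ampere consumer cards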

@lulu51230

@keezen With two 3090s the large model doesn't run for now; I'll try turning on cpu_offload and see whether it fits.

OK. Fine-tuning on ChID with 4 V100s, my final accuracy was only 10%. If you make any progress, please share it here, thanks!

The low accuracy is probably due to the GPU scale and the batch_size. I ran into the same problem as you.

@lulu51230

@keezen With two 3090s the large model doesn't run for now; I'll try turning on cpu_offload and see whether it fits.

OK. Fine-tuning on ChID with 4 V100s, my final accuracy was only 10%. If you make any progress, please share it here, thanks!

The low accuracy is probably due to the GPU scale and the batch_size. I ran into the same problem as you.

That's because the existing default hyperparameters were all derived from training in their own experimental environment.

@Vladimir-Bayes

Thanks. After upgrading my PyTorch from 1.7.1 to 1.8.1, the problem was resolved.
