import numpy as np
from scipy.interpolate import griddata, interp1d

# The max training tokens (D) used for each model family. We use the model with the
# middle-size vocabulary, 16384, to compute the total parameter count when naming a file.
# For example, for the file named 'tiny_LLaMA_0000050M-V0004096IsoFLOP', the
# non-vocabulary parameter count is 50_000_000 - 2*16384*d, where d is the embedding
# dim, and the vocabulary actually used in training is 4096.
# Per-family max training tokens, ordered by model size (parallels max_D_dict).
max_D_list = [1.2*10**9, 3.0*10**9, 5.3*10**9, 11.7*10**9, 27.7*10**9,
              54.8*10**9, 165.5*10**9]
max_D_dict = {'50M': 1.2*10**9, '110M': 3.0*10**9, '176M': 5.3*10**9, '335M': 11.7*10**9,
              '682M': 27.7*10**9, '1197M': 54.8*10**9, '2975M': 165.5*10**9}
embed_dim_dict = {'50M': 512, '110M': 768, '176M': 768, '335M': 1024,
                  '682M': 1536, '1197M': 2048, '2975M': 3200}
model_size_dict = {'50M': 50*10**6, '110M': 110*10**6, '176M': 176*10**6, '335M': 335*10**6,
                   '682M': 682*10**6, '1197M': 1197*10**6, '2975M': 2975*10**6}
steps_for_1epoch_dict = {'50M': 1200, '110M': 3000, '176M': 5300, '335M': 11700,
                         '682M': 27700, '1197M': 54800, '2975M': 165500}
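
# Illustrative check of the naming convention above: for the '50M' family,
# d = embed_dim_dict['50M'] = 512, so 'tiny_LLaMA_0000050M-V0004096IsoFLOP'
# has 50_000_000 - 2*16384*512 = 33_222_784 non-vocabulary parameters and is
# trained with a 4096-entry vocabulary.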

def _tokens_per_char(V):
    # Fitted quadratic in log(V): the tokenizer's tokens-per-character ratio as a
    # function of vocabulary size (capped at 200k entries).
    logv = np.log(np.minimum(V, 200_000))
    return 0.00639222*logv**2 - 0.15811069*logv + 1.20470122

def D_to_H(D, V=16384):
    """Convert a training-token count D into training characters H at vocabulary size V."""
    return D / _tokens_per_char(V)

def H_to_D(H, V=16384):
    """Convert a training-character count H into training tokens D at vocabulary size V."""
    return H * _tokens_per_char(V)
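
# Example: the two conversions are exact inverses, so
#   H_to_D(D_to_H(1.2e9, V=16384), V=16384) == 1.2e9.
# At V=16384 the fitted ratio is ~0.272 tokens per character, so 1.2e9
# training tokens correspond to roughly 4.4e9 characters.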

def generate_interpolation_log(values, num=8, var=0):
    """num log-spaced points spanning [min*(1-var), max*(1+var)] of values."""
    return np.logspace(np.log10(values.min()*(1-var)), np.log10(values.max()*(1+var)), num)

def generate_interpolation_linear(values, num=8, var=0):
    """num linearly spaced points spanning [min*(1-var), max*(1+var)] of values."""
    return np.linspace(values.min()*(1-var), values.max()*(1+var), num)
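
# Example: generate_interpolation_log(np.array([1e8, 1e10]), num=3) returns
# [1e8, 1e9, 1e10]; passing var=0.1 would first widen the range by 10% at
# each end.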

# (upper bound on Nnv, embedding dim) pairs, in ascending order.
Nnv_d_table = [
    (50_000_000, 512), (200_000_000, 768), (500_000_000, 1024),
    (1_000_000_000, 1536), (2_000_000_000, 2048), (5_000_000_000, 3200),
    (10_000_000_000, 4096), (20_000_000_000, 5120), (50_000_000_000, 6048),
    (100_000_000_000, 8192), (200_000_000_000, 12288),
    (500_000_000_000, 16384), (1_000_000_000_000, 20480),
]

def Nnv_to_d(Nnv):
    """Map a non-vocabulary parameter count Nnv to the embedding dimension d
    used for models of that scale."""
    for bound, d in Nnv_d_table:
        if Nnv <= bound:
            return float(d)
    return 24576.0

def func_flops(Nnv, H, V):
    """Training FLOPs ~= 6*N*D for H training characters at vocabulary size V,
    where N = Nnv + V*d counts the non-vocabulary and unembedding parameters
    (the input-embedding lookup costs no FLOPs) and D = H * tokens-per-char."""
    d = Nnv_to_d(Nnv)
    return 6*(Nnv + V*d)*H*_tokens_per_char(V)
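
# Example: Nnv = 1e9 gives d = 1536, so with V = 16384 the unembedding adds
# V*d ~ 2.5e7 parameters and func_flops(1e9, H, 16384) ~= 6*(1e9 + 2.5e7)*0.272*H.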

def interpolate(Nnv_data, H_data, V_data, flops_data, L_values, num_model, num_v, num_eval):
    """Densify the measured (Nnv, H, V) -> loss surface by interpolation.

    Inputs are flat arrays of length num_model*num_v*num_eval, ordered so they
    reshape to (num_model, num_v, num_eval). Returns flat arrays of the
    interpolated (Nnv, H, V, flops, loss) points. (flops_data is accepted for
    symmetry with the other arrays but is not used.)
    """
    reshape_Nnv = np.reshape(Nnv_data, (num_model, num_v, num_eval))
    reshape_H = np.reshape(H_data, (num_model, num_v, num_eval))
    reshape_V = np.reshape(V_data, (num_model, num_v, num_eval))
    reshape_L = np.reshape(L_values, (num_model, num_v, num_eval))
    interpolated_Nnv, interpolated_H, interpolated_V = [], [], []
    interpolated_flops, interpolated_loss = [], []
    # Pass 1: for each vocabulary size, interpolate the loss over the (Nnv, H)
    # plane with 2-D cubic interpolation, evaluated at 50 paired (Nnv, H)
    # points (an elementwise zip of the two log grids, not a full mesh).
    new_Nnv = generate_interpolation_log(np.unique(Nnv_data), 50)
    new_H = generate_interpolation_log(np.unique(H_data), 50)
    for vid in range(num_v):
        cur_V = reshape_V[0, vid, 0]
        cur_all_Nnv = reshape_Nnv[:, vid, :].ravel()
        cur_all_H = reshape_H[:, vid, :].ravel()
        cur_all_L = reshape_L[:, vid, :].ravel()
        new_points = np.array(list(zip(new_Nnv, new_H)))
        points = np.array(list(zip(cur_all_Nnv, cur_all_H)))
        new_L = griddata(points, cur_all_L, new_points, method='cubic')
        for nnv, h, l in zip(new_Nnv, new_H, new_L):
            if np.isnan(l):  # point fell outside the convex hull of the data
                continue
            f = func_flops(nnv, h, cur_V)
            interpolated_Nnv.append(nnv)
            interpolated_H.append(h)
            interpolated_V.append(cur_V)
            interpolated_flops.append(f)
            interpolated_loss.append(l)
    # Pass 2: for each model size, interpolate the final-step loss over V with
    # 1-D quadratic interpolation, pairing each new V with the H that the
    # family's max token budget corresponds to at that vocabulary size.
    new_V = generate_interpolation_log(np.unique(V_data), 20)
    for modelid in range(num_model):
        cur_Nnv = reshape_Nnv[modelid, 0, 0]
        cur_all_V = reshape_V[modelid, :, 0]
        cur_all_L = reshape_L[modelid, :, -1]
        interpolation_function = interp1d(cur_all_V, cur_all_L, kind='quadratic', fill_value='extrapolate')
        new_L = interpolation_function(new_V)
        new_H = np.array([D_to_H(max_D_list[modelid], V=i) for i in new_V])
        for v, h, l in zip(new_V, new_H, new_L):
            if np.isnan(l):
                continue
            f = func_flops(cur_Nnv, h, v)
            interpolated_Nnv.append(cur_Nnv)
            interpolated_H.append(h)
            interpolated_V.append(v)
            interpolated_flops.append(f)
            interpolated_loss.append(l)
    interpolated_Nnv, interpolated_H, interpolated_V, interpolated_flops, interpolated_loss = \
        np.array(interpolated_Nnv), np.array(interpolated_H), np.array(interpolated_V), \
        np.array(interpolated_flops), np.array(interpolated_loss)
    return interpolated_Nnv, interpolated_H, interpolated_V, interpolated_flops, interpolated_loss
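
# Usage sketch (hypothetical shapes): with 7 model families, 4 vocabulary
# sizes and 5 evaluation points per run, the inputs are flat arrays of
# length 7*4*5 = 140 ordered as (model, vocab, eval):
#   Nnv, H, V, flops, loss = interpolate(Nnv_data, H_data, V_data,
#                                        flops_data, L_values,
#                                        num_model=7, num_v=4, num_eval=5)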

def merge_nearest_flops(flops_data, num_bin=10):
    '''
    flops_data is an ascending sequence. Divide its values into num_bin
    log-spaced bins from minimum to maximum and replace every value with the
    mean of its bin, so near-identical FLOP budgets are merged.
    '''
    # Digitize against the num_bin-1 interior edges so indices fall in
    # 0..num_bin-1 for every point, including the maximum (digitizing against
    # the upper edges would assign the maximum the out-of-range index num_bin).
    edges = np.logspace(np.log10(flops_data.min()), np.log10(flops_data.max()), num_bin + 1)
    indices = np.digitize(flops_data, edges[1:-1])
    flops_data_binned = np.zeros_like(flops_data)
    for i in range(num_bin):
        bin_mask = indices == i
        if np.any(bin_mask):
            flops_data_binned[bin_mask] = flops_data[bin_mask].mean()
    return flops_data_binned
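
# Example: merge_nearest_flops(np.array([1., 2., 100.]), num_bin=2) splits the
# log range [1, 100] at the interior edge 10, so the bins are {1., 2.} and
# {100.} and the result is [1.5, 1.5, 100.].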

def relative_mse(actual, predicted):
    """Mean squared error of the predictions, normalized by the squared mean
    of the actual values."""
    errors = actual - predicted
    mse = np.mean(np.square(errors))
    mean_squared = np.mean(actual)**2
    return mse / mean_squared
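
# Example: relative_mse(np.array([1.0, 3.0]), np.array([1.0, 2.0])) is the
# MSE 0.5 divided by mean(actual)**2 = 4.0, i.e. 0.125.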

def remove_outlier(flops, Nnvopt, Nvopt, Hopt, best_K_set, best_alpha_set):
    """Keep only the points whose optimal Nnv, Nv and H all lie within 40%
    relative deviation of the fitted power laws exp(K) * flops**alpha."""
    flops, Nnvopt, Nvopt, Hopt = np.array(flops), np.array(Nnvopt), np.array(Nvopt), np.array(Hopt)
    kept_idx = []
    for idx, f in enumerate(flops):
        # Power-law predictions for Nnv, Nv and H at this FLOP budget.
        ypred0 = np.exp(best_K_set[0]) * f**best_alpha_set[0]
        ypred1 = np.exp(best_K_set[1]) * f**best_alpha_set[1]
        ypred2 = np.exp(best_K_set[2]) * f**best_alpha_set[2]
        if abs(Nnvopt[idx] - ypred0)/ypred0 < 0.4 and \
           abs(Nvopt[idx] - ypred1)/ypred1 < 0.4 and \
           abs(Hopt[idx] - ypred2)/ypred2 < 0.4:
            kept_idx.append(idx)
    kept_idx = np.array(kept_idx)
    return flops[kept_idx], Nnvopt[kept_idx], Nvopt[kept_idx], Hopt[kept_idx]
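
# Usage sketch (hypothetical fit values): best_K_set / best_alpha_set hold the
# log-space intercepts and exponents of the three power laws fitted for the
# optimal Nnv, Nv and H as functions of FLOPs, in that order, e.g.
#   remove_outlier(flops, Nnvopt, Nvopt, Hopt,
#                  best_K_set=[-2.1, -4.3, 1.9], best_alpha_set=[0.5, 0.4, 0.5])
# returns only the points within 40% of all three fitted curves.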