-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_lgb.py
53 lines (38 loc) · 1.48 KB
/
run_lgb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pandas as pd
import numpy as np
import argparse
from module.backtest import BackTest
from module.data import prepare_dataset
from module.gbdt_model import GBDTModel
def parse_args(argv=None):
    """Parse the command-line options for the GBDT training / back-test script.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            callers behave exactly as before; passing an explicit list lets
            the parser be driven from tests or notebooks.

    Returns:
        argparse.Namespace with the attributes ``num_leaves``, ``loss``,
        ``num_round``, ``early_stopping_round`` and ``data``.

    Raises:
        SystemExit: raised by argparse when ``--data`` is missing or an option
            value fails to parse.
    """
    parser = argparse.ArgumentParser()
    # model settings
    parser.add_argument('--num-leaves', type=int, default=31)
    parser.add_argument('--loss', type=str, default='mse')
    # training settings
    parser.add_argument('--num-round', type=int, default=500)
    parser.add_argument('--early-stopping-round', type=int, default=50)
    # data settings
    parser.add_argument('--data', type=str, required=True, help='data source path')
    return parser.parse_args(argv)
def _shift_yyyymmdd(day, days):
    """Return the YYYYMMDD integer lying ``days`` calendar days from ``day``.

    Plain integer arithmetic such as ``20200105 - 7`` yields an invalid date
    (20200098) once it crosses a month boundary; going through ``datetime``
    keeps the result a real calendar date.
    """
    from datetime import datetime, timedelta
    shifted = datetime.strptime(str(day), '%Y%m%d') + timedelta(days=days)
    return int(shifted.strftime('%Y%m%d'))


def main():
    """Train a GBDT model on fixed date segments and back-test its predictions.

    Parses the CLI options, builds the dataset from ``args.data``, and trains
    with 2015-2019 as the training segment and 2020 as the validation
    segment. It then predicts on the dataset's ``'test'`` split, back-tests
    those predictions over the test window, and prints the results as a
    DataFrame.
    """
    args = parse_args()
    dataset = prepare_dataset(args.data)
    model = GBDTModel(args)

    # Segment bounds are YYYYMMDD integers. The test window is defined once so
    # that test_seg and the back-test below cannot drift apart.
    test_start, test_end = 20210101, 20220915
    # Train/valid segments stop 7 calendar days before year end; presumably an
    # embargo so labels near a segment edge don't leak into the next segment.
    # TODO confirm against the label horizon.
    gap_days = 7
    model.train(
        dataset,
        train_seg=slice(20150101, _shift_yyyymmdd(20191231, -gap_days)),
        valid_seg=slice(20200101, _shift_yyyymmdd(20201231, -gap_days)),
        test_seg=slice(test_start, test_end),
    )
    pred = model.predict(dataset.get_data_split('test'))

    print('Start backtesting ...')
    # Back-test over the same window as test_seg. The list [1, 2, 5, 10] and
    # the trailing 10 are passed through to BackTest unchanged; their meaning
    # is defined in module.backtest.
    backtester = BackTest(test_start, test_end, args.data, [1, 2, 5, 10], 10)
    results = backtester.alpha_backtest(pred, alpha_shifted=False, plot=True)
    print(pd.DataFrame.from_dict(results))
if __name__ == '__main__':
    # Run the train-and-back-test pipeline only when executed as a script,
    # not when this module is imported.
    main()