# test.py — 86 lines (74 loc), 2.28 KB
# (GitHub page chrome and line-number gutter from the scrape removed;
#  the actual source follows.)
# Standard library
import datetime
import gc
import os
from argparse import ArgumentParser

# Third-party
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

# Project-local
from CGAT.lightning_module import LightningModel

# Seed torch and numpy once at import time so test runs are reproducible.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)
def main(hparams):
    """Run the test loop for a model restored from a checkpoint.

    Args:
        hparams: parsed command-line arguments carrying the checkpoint
            path (``ckp``), the data paths and the GPU configuration,
            as defined by the parser under ``if __name__ == '__main__':``.

    Returns:
        None
    """
    # Restore the model, overriding the data-related hyper-parameters
    # stored in the checkpoint with the ones given on the command line.
    model = LightningModel.load_from_checkpoint(
        checkpoint_path=hparams.ckp,
        train=hparams.train,
        test=hparams.test,
        test_path=hparams.test_path,
        val_path=hparams.val_path,
        fea_path=hparams.fea_path,
    )
    # Contiguous range of GPU ids starting at --first-gpu.
    device_ids = [hparams.first_gpu + offset for offset in range(hparams.gpus)]
    trainer = pl.Trainer(gpus=device_ids)
    trainer.test(model)
if __name__ == '__main__':
    # Top-level parser for the testing script; model-specific options are
    # appended by LightningModel.add_model_specific_args below.
    # NOTE(review): main() reads hparams.train, but no --train flag is
    # defined here — presumably add_model_specific_args supplies it; confirm.
    parent_parser = ArgumentParser(add_help=False)
    parent_parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help='how many gpus'
    )
    parent_parser.add_argument(
        '--amp_optimization',
        type=str,
        default='00',
        help="mixed precision format, default 00 (32), 01 mixed, 02 closer to 16, should not be used during testing"
    )
    parent_parser.add_argument(
        '--first-gpu',
        type=int,
        default=0,
        help='gpu number to use [first_gpu, ..., first_gpu+gpus]'
    )
    parent_parser.add_argument(
        '--ckp',
        type=str,
        default='',
        help='ckp path, if left empty no checkpoint is used'
    )
    parent_parser.add_argument(
        '--hparams',
        type=str,
        default='',
        help='path for hparams of ckp if left empty no checkpoint is used'
    )
    parent_parser.add_argument("--test",
                               action="store_true",
                               help="whether to train or test"
                               )
    # Each LightningModule defines the arguments relevant to it.
    parser = LightningModel.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()
    print(hyperparams)
    main(hyperparams)