train.py
import argparse
import os

import torch
from d2l import torch as d2l

from model import transformer


def train(opt):
    """Train a Vision Transformer (ViT) on Fashion-MNIST and return the fitted model."""
    model = transformer.ViT(opt.img_size, opt.patch_size, opt.num_hiddens, opt.mlp_num_hiddens,
                            opt.num_heads, opt.num_blocks, opt.emb_dropout, opt.block_dropout, opt.lr)
    data = d2l.FashionMNIST(batch_size=8, resize=(opt.img_size, opt.img_size))
    trainer = d2l.Trainer(max_epochs=10, num_gpus=0)
    trainer.fit(model, data)
    return model


def save_model(model, path="weights/vit.pt"):
    """Save the model's weights, creating the parent directory if it does not exist."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(model.state_dict(), path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img-size', type=int, default=96, help="Input image size (images are resized to img_size x img_size)")
    parser.add_argument('--patch-size', type=int, default=16, help="Side length of each square patch")
    parser.add_argument('--num-hiddens', type=int, default=512, help="Hidden (embedding) dimension")
    parser.add_argument('--mlp-num-hiddens', type=int, default=2048, help="Hidden dimension of the MLP inside each block")
    parser.add_argument('--num-heads', type=int, default=8, help="Number of attention heads")
    parser.add_argument('--num-blocks', type=int, default=2, help="Number of transformer blocks")
    parser.add_argument('--emb-dropout', type=float, default=0.1, help="Dropout on the patch and positional embeddings")
    parser.add_argument('--block-dropout', type=float, default=0.1, help="Dropout inside each transformer block")
    parser.add_argument('--lr', type=float, default=0.1, help="Learning rate")
    opt = parser.parse_args()

    model = train(opt)
    save_model(model)
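Running the script, for example as python train.py --num-blocks 2 --lr 0.1, trains on Fashion-MNIST and writes the weights to weights/vit.pt by default. Below is a minimal sketch of how that checkpoint could later be reloaded for inference; it assumes transformer.ViT accepts the same constructor arguments used in train() above and that the weights were saved by save_model() to the default path.

import torch
from model import transformer

# Rebuild the architecture with the same hyperparameters used for training
# (the argument-parser defaults above), then load the saved state dict.
model = transformer.ViT(96, 16, 512, 2048, 8, 2, 0.1, 0.1, 0.1)
model.load_state_dict(torch.load("weights/vit.pt"))
model.eval()  # disable dropout for inference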