train.py
from model import Model
import numpy as np
import os
import torch
from torchvision.datasets import MNIST
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 256

    # download=True fetches MNIST on first run; without it, a fresh checkout
    # raises a RuntimeError because the dataset files are missing.
    train_dataset = MNIST(root='./train', train=True, transform=ToTensor(), download=True)
    test_dataset = MNIST(root='./test', train=False, transform=ToTensor(), download=True)
    # Shuffle the training set each epoch; test-set order does not matter.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    model = Model().to(device)
    sgd = SGD(model.parameters(), lr=1e-1)
    loss_fn = CrossEntropyLoss()
    all_epoch = 100
    prev_acc = 0

    for current_epoch in range(all_epoch):
        # Training pass over the full train set.
        model.train()
        for train_x, train_label in train_loader:
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            loss.backward()
            sgd.step()

        # Evaluation pass: count correct predictions over the test set.
        all_correct_num = 0
        all_sample_num = 0
        model.eval()
        with torch.no_grad():
            for test_x, test_label in test_loader:
                test_x = test_x.to(device)
                test_label = test_label.to(device)
                predict_y = torch.argmax(model(test_x.float()), dim=-1)
                current_correct_num = predict_y == test_label
                all_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)
                all_sample_num += current_correct_num.shape[0]
        acc = all_correct_num / all_sample_num
        print('accuracy: {:.3f}'.format(acc), flush=True)

        # Checkpoint the whole module (architecture + weights) after each epoch.
        if not os.path.isdir("models"):
            os.mkdir("models")
        torch.save(model, 'models/mnist_{:.3f}.pkl'.format(acc))

        # Stop early once test accuracy plateaus between consecutive epochs.
        if np.abs(acc - prev_acc) < 1e-4:
            break
        prev_acc = acc

    print("Model finished training")