-
Notifications
You must be signed in to change notification settings - Fork 0
/
epoch_fun.py
94 lines (59 loc) · 2.22 KB
/
epoch_fun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
'''
Author: Zehui Lin
Date: 2021-01-04 22:18:30
LastEditors: Zehui Lin
LastEditTime: 2021-01-05 11:48:49
Description: Per-epoch training, validation and inference loops (train_fn, validate_fn, inference_fn).
'''
import torch
import numpy as np
import torch.nn as nn
from utils import AverageMeter
from init import CFG_class
# Module-level configuration instance shared by the loops below; they read
# CFG.gradient_accumulation_steps and CFG.max_grad_norm from it.
CFG = CFG_class()
def train_fn(train_loader, model, optimizer, scheduler, device, ema):
    """Run one pass over ``train_loader`` and return the tracked training loss.

    Every batch is scored with ``nn.BCEWithLogitsLoss``. Gradients are
    accumulated over ``CFG.gradient_accumulation_steps`` consecutive batches,
    clipped once to ``CFG.max_grad_norm`` and then applied with a single
    optimizer step, followed by a scheduler step and a gradient reset.

    Args:
        train_loader: Iterable yielding ``(cont_x, cate_x, y)`` tensor triples.
        model: Module called as ``model(cont_x, cate_x)``; its output is
            treated as logits.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: Learning-rate scheduler stepped right after each
            optimizer step.
        device: Device every batch tensor is moved to.
        ema: Optional object exposing ``update(parameters)``; it is called
            after each optimizer step and skipped entirely when ``None``.

    Returns:
        ``losses.avg`` of the ``AverageMeter`` that was fed each batch's
        unscaled loss together with its batch size (presumably the
        sample-weighted mean loss -- confirm against ``utils.AverageMeter``).
    """
    losses = AverageMeter()
    # Stateless criterion: build it once instead of once per batch.
    criterion = nn.BCEWithLogitsLoss()
    accum_steps = CFG.gradient_accumulation_steps

    model.train()

    for step, (cont_x, cate_x, y) in enumerate(train_loader):
        cont_x, cate_x, y = cont_x.to(device), cate_x.to(device), y.to(device)
        batch_size = cont_x.size(0)

        pred = model(cont_x, cate_x)
        loss = criterion(pred, y)
        # Record the loss before it is rescaled for accumulation.
        losses.update(loss.item(), batch_size)

        if accum_steps > 1:
            # Scale so the accumulated gradient is an average, not a sum.
            loss = loss / accum_steps
        loss.backward()

        if (step + 1) % accum_steps == 0:
            # Clip the fully accumulated gradient exactly once, right before
            # it is applied; clipping after every micro-batch would distort
            # the partial sums.
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if ema is not None:
                ema.update(model.parameters())
        # NOTE(review): when the number of batches is not a multiple of
        # accum_steps, the trailing micro-batches are never applied here and
        # their gradients stay in .grad after this function returns.

    return losses.avg
def validate_fn(valid_loader, model, device):
    """Evaluate ``model`` on ``valid_loader`` without tracking gradients.

    Args:
        valid_loader: Iterable yielding ``(cont_x, cate_x, y)`` tensor triples.
        model: Module called as ``model(cont_x, cate_x)``; its output is
            treated as logits.
        device: Device every batch tensor is moved to.

    Returns:
        A tuple ``(avg_loss, val_preds)``. ``avg_loss`` is ``losses.avg`` of
        the ``AverageMeter`` fed with each batch's ``BCEWithLogitsLoss`` value
        and batch size. ``val_preds`` is a NumPy array of sigmoid
        probabilities, concatenated along the first axis in loader order.

    Raises:
        ValueError: Raised by ``np.concatenate`` if the loader yields no
            batches.
    """
    losses = AverageMeter()
    # Stateless criterion: build it once instead of once per batch.
    criterion = nn.BCEWithLogitsLoss()

    model.eval()
    val_preds = []

    with torch.no_grad():
        for cont_x, cate_x, y in valid_loader:
            cont_x, cate_x, y = cont_x.to(device), cate_x.to(device), y.to(device)
            batch_size = cont_x.size(0)

            pred = model(cont_x, cate_x)
            loss = criterion(pred, y)
            losses.update(loss.item(), batch_size)
            val_preds.append(pred.sigmoid().detach().cpu().numpy())
            # No gradient-accumulation rescaling here: the loss is only
            # reported during validation, never back-propagated.

    val_preds = np.concatenate(val_preds)
    return losses.avg, val_preds
def inference_fn(test_loader, model, device):
    """Predict sigmoid probabilities for every batch in ``test_loader``.

    Args:
        test_loader: Iterable yielding ``(cont_x, cate_x)`` tensor pairs.
        model: Module called as ``model(cont_x, cate_x)``; its output is
            treated as logits. It is switched to eval mode.
        device: Device every batch tensor is moved to.

    Returns:
        A NumPy array with the per-batch probabilities concatenated along
        the first axis, in loader order.
    """
    model.eval()
    batch_probs = []
    with torch.no_grad():
        for cont_x, cate_x in test_loader:
            logits = model(cont_x.to(device), cate_x.to(device))
            probs = logits.sigmoid().detach().cpu()
            batch_probs.append(probs.numpy())
    return np.concatenate(batch_probs)