# training_script_ppi.py (forked from gordicaleksa/pytorch-GAT)

import argparse
import os
import time

from sklearn.metrics import f1_score
import torch
import torch.nn as nn
from torch.optim import Adam

from models.definitions.GAT import GAT
from utils.data_loading import load_graph_data
from utils.constants import *
import utils.utils as utils
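

# main_loop below declares `global BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT, writer`, so these module-level
# names have to exist before training starts. A minimal sketch of that state, assuming the same setup as the
# repo's Cora training script (the writer is only written to when --enable_tensorboard is passed):
from torch.utils.tensorboard import SummaryWriter

BEST_VAL_PERF = 0
BEST_VAL_LOSS = 0
PATIENCE_CNT = 0

writer = SummaryWriter()  # (tensorboard) writer outputs to the ./runs/ directory by default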


# Simple decorator-like closure so that I don't have to pass arguments that don't change from epoch to epoch
def get_main_loop(config, gat, sigmoid_cross_entropy_loss, optimizer, patience_period, time_start):

    device = next(gat.parameters()).device  # fetch the device info from the model instead of passing it as a param

    def main_loop(phase, data_loader, epoch=0):
        global BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT, writer

        # Certain modules behave differently depending on whether we're training the model or not.
        # e.g. nn.Dropout - we only want to drop model weights during training.
        if phase == LoopPhase.TRAIN:
            gat.train()
        else:
            gat.eval()

        # Iterate over batches of graph data (2 graphs per batch was used in the original paper for the PPI dataset)
        # We merge them into a single graph with 2 connected components - that's the main idea. After that,
        # implementation #3 is agnostic to whether it's handling multiple graphs or a single one!
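        # E.g. (illustrative only): merging a 3-node graph A with a 2-node graph B means stacking their feature
        # matrices ((3, F) and (2, F) -> (5, F)) and shifting B's edge indices by 3, so an edge (0, 1) in B
        # becomes (3, 4) in the merged edge_index - the GAT layers never need to know there are 2 components.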
        for batch_idx, (node_features, gt_node_labels, edge_index) in enumerate(data_loader):
            # Push the batch onto the GPU - note: PPI is too big to load the whole dataset onto an average GPU,
            # it takes almost 8 GB of VRAM to train it on a GPU
            edge_index = edge_index.to(device)
            node_features = node_features.to(device)
            gt_node_labels = gt_node_labels.to(device)

            # I pack data into tuples because GAT uses nn.Sequential which expects this format
            graph_data = (node_features, edge_index)

            # Note: [0] just extracts the node_features part of the data (index 1 contains the edge_index)
            # shape = (N, C) where N is the number of nodes in the batch and C is the number of classes (121 for PPI)
            # GAT imp #3 is agnostic to the fact that we actually have multiple graphs
            # (it sees a single graph with multiple connected components)
            nodes_unnormalized_scores = gat(graph_data)[0]

            # Example: because PPI has 121 labels, let's make a simple toy example that shows how the loss works.
            # Let's say we have 3 labels instead and a single node's unnormalized (raw GAT output) scores are [-3, 0, 3]
            # What this loss does is first apply a sigmoid, so we end up with: [0.048, 0.5, 0.95],
            # then apply binary cross entropy across all of these and take the average - that's it!
            # So if the true classes were [0, 0, 1] the loss would be (-log(1-0.048) + -log(1-0.5) + -log(0.95))/3.
            # You can see that the logarithm takes 2 forms depending on whether the true label is 0 or 1,
            # either -log(1-x) or -log(x) respectively. Easy-peasy. <3
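            # Quick sanity check of the toy example above (illustrative only, not part of the training loop):
            #   nn.BCEWithLogitsLoss()(torch.tensor([-3., 0., 3.]), torch.tensor([0., 0., 1.]))  # ~0.263
            #   i.e. (-log(1-0.048) - log(0.5) - log(0.95)) / 3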
            loss = sigmoid_cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

            if phase == LoopPhase.TRAIN:
                optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph (.grad fields)
                loss.backward()  # compute the gradients for every trainable weight in the computational graph
                optimizer.step()  # apply the gradients to weights

            # Calculate the main metric - micro F1
            # Convert unnormalized scores into predictions. Explanation:
            # If the unnormalized score is bigger than 0 that means that the sigmoid would have a value higher than 0.5
            # (by sigmoid's definition) and thus we have predicted 1 for that label, otherwise we have predicted 0.
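            # E.g. raw scores [-3.0, 0.2, 3.0] -> sigmoid ~[0.05, 0.55, 0.95] -> predicted labels [0, 1, 1]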
            pred = (nodes_unnormalized_scores > 0).float().cpu().numpy()
            gt = gt_node_labels.cpu().numpy()
            micro_f1 = f1_score(gt, pred, average='micro')

            #
            # Logging
            #

            global_step = len(data_loader) * epoch + batch_idx
            if phase == LoopPhase.TRAIN:
                # Log metrics
                if config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss.item(), global_step)
                    writer.add_scalar('training_micro_f1', micro_f1, global_step)

                # Log to console
                if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                    print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | train micro-F1={micro_f1}.')

                # Save model checkpoint
                if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0 and batch_idx == 0:
                    ckpt_model_name = f'gat_{config["dataset_name"]}_ckpt_epoch_{epoch + 1}.pth'
                    config['test_perf'] = -1  # test perf not calculated yet, note: perf means the main metric (micro-F1) here
                    torch.save(utils.get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

            elif phase == LoopPhase.VAL:
                # Log metrics
                if config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss.item(), global_step)
                    writer.add_scalar('val_micro_f1', micro_f1, global_step)

                # Log to console
                if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                    print(f'GAT validation: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | val micro-F1={micro_f1}')

                # The "patience" logic - should we break out from the training loop? As long as either the validation
                # micro-F1 keeps going up or the val loss keeps going down, we won't stop.
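                # E.g. with the default patience_period=100: training stops once 100 consecutive validation batches
                # pass without either a new best micro-F1 or a new lowest val loss.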
                if micro_f1 > BEST_VAL_PERF or loss.item() < BEST_VAL_LOSS:
                    BEST_VAL_PERF = max(micro_f1, BEST_VAL_PERF)  # keep track of the best validation micro_f1 so far
                    BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)  # and the minimal loss
                    PATIENCE_CNT = 0  # reset the counter every time we encounter a new best micro_f1
                else:
                    PATIENCE_CNT += 1  # otherwise keep counting

                if PATIENCE_CNT >= patience_period:
                    raise Exception('Stopping the training, the universe has no more patience for this training.')

            else:
                return micro_f1  # in the case of the test phase we just report back the test micro_f1

    return main_loop  # return the main_loop closure


def train_gat_ppi(config):
    """
    Very similar to Cora's training script. The main differences are:
    1. Using dataloaders since we're dealing with an inductive setting - multiple graphs per batch
    2. Doing multi-label classification (BCEWithLogitsLoss) and reporting micro-F1 instead of accuracy
    3. Model architecture and hyperparams are a bit different (as reported in the GAT paper)
    """
    global BEST_VAL_PERF, BEST_VAL_LOSS

    # Checking whether you have a strong GPU. Since PPI training requires almost 8 GB of VRAM,
    # I've added the option to force the use of the CPU even if you have a GPU on your system (but it's too weak).
    device = torch.device("cuda" if torch.cuda.is_available() and not config['force_cpu'] else "cpu")

    # Step 1: prepare the data loaders
    data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device)

    # Step 2: prepare the model
    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
        layer_type=config['layer_type'],
        log_attention_weights=False  # no need to store attentions, used only in playground.py for visualizations
    ).to(device)

    # Step 3: Prepare other training related utilities (loss & optimizer and the decorator-like closure)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # The decorator-like closure makes things cleaner since there is a lot of redundancy between the train and val loops
    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        config['patience_period'],
        time.time())

    BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT = [0, 0, 0]  # reset vars used for early stopping

    # Step 4: Start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, data_loader=data_loader_train, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, data_loader=data_loader_val, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                print(str(e))
                break  # break out from the training loop

    # Step 5: Potentially test your model
    # Don't overfit to the test dataset - only when you've fine-tuned your model on the validation dataset should you
    # report your final loss and micro-F1 on the test dataset. Friends don't let friends overfit to the test data. <3
    if config['should_test']:
        micro_f1 = main_loop(phase=LoopPhase.TEST, data_loader=data_loader_test)
        config['test_perf'] = micro_f1

        print('*' * 50)
        print(f'Test micro-F1 = {micro_f1}')
    else:
        config['test_perf'] = -1

    # Save the latest GAT in the binaries directory
    torch.save(
        utils.get_training_state(config, gat),
        os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['dataset_name']))
    )


def get_training_args():
    parser = argparse.ArgumentParser()

    # Training related
    parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=200)
    parser.add_argument("--patience_period", type=int, help="number of epochs with no improvement on val before terminating", default=100)
    parser.add_argument("--lr", type=float, help="model learning rate", default=5e-3)
    parser.add_argument("--weight_decay", type=float, help="L2 regularization on model weights", default=0)
    parser.add_argument("--should_test", action='store_true', help='should test the model on the test dataset? (no by default)')
    parser.add_argument("--force_cpu", action='store_true', help='use CPU if your GPU is too small (no by default)')

    # Dataset related (note: we need the dataset name for metadata and related stuff, not for picking the dataset)
    parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='dataset to use for training', default=DatasetType.PPI.name)
    parser.add_argument("--batch_size", type=int, help='number of graphs in a batch', default=2)
    parser.add_argument("--should_visualize", action='store_true', help='should visualize the dataset? (no by default)')

    # Logging/debugging/checkpoint related (helps a lot with experimentation)
    parser.add_argument("--enable_tensorboard", action='store_true', help="enable tensorboard logging (no by default)")
    parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq (None for no logging)", default=10)
    parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq (None for no checkpointing)", default=5)
    args = parser.parse_args()

    # I'm leaving the hyperparam values as reported in the paper, but I experimented a bit and the comments suggest
    # how you can make GAT achieve an even higher micro-F1 or make it smaller
    gat_config = {
        # GNNs, contrary to CNNs, are often shallow (it ultimately depends on the graph properties)
        "num_of_layers": 3,  # PPI has ~42% of nodes with all-0 features - that's why 3 layers are useful
        "num_heads_per_layer": [4, 4, 6],  # other values may give even better results than the reported ones
        "num_features_per_layer": [PPI_NUM_INPUT_FEATURES, 256, 256, PPI_NUM_CLASSES],  # 64 would also give ~0.975 uF1!
        "add_skip_connection": True,  # the skip connection is very important! (keep it, otherwise micro-F1 is almost 0)
        "bias": True,  # bias doesn't matter that much
        "dropout": 0.0,  # dropout hurts the performance (best to keep it at 0)
        "layer_type": LayerType.IMP3  # the only implementation that supports the inductive setting
    }
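    # For reference: PPI_NUM_INPUT_FEATURES and PPI_NUM_CLASSES (from utils.constants) should correspond to
    # PPI's 50 node features and 121 labels, so the widths above read 50 -> 4*256 -> 4*256 -> 121
    # (the last layer's 6 heads are averaged rather than concatenated, as in the GAT paper).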

    # Wrapping training configuration into a dictionary
    training_config = dict()
    for arg in vars(args):
        training_config[arg] = getattr(args, arg)
    training_config['ppi_load_test_only'] = False  # load both train/val/test data loaders (don't change it)

    # Add additional config information
    training_config.update(gat_config)

    return training_config


if __name__ == '__main__':
    # Train the graph attention network (GAT)
    train_gat_ppi(get_training_args())
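
# Example usage (assuming the repo's standard layout and that load_graph_data can find/prepare the PPI dataset):
#   python training_script_ppi.py --should_test --enable_tensorboard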