-
Notifications
You must be signed in to change notification settings - Fork 34
/
train_GAugM.py
92 lines (84 loc) · 3.36 KB
/
train_GAugM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import copy
import json
import pickle
import argparse
import numpy as np
import scipy.sparse as sp
import torch
from models.GCN_dgl import GCN
from models.GAT_dgl import GAT
from models.GSAGE_dgl import GraphSAGE
from models.JKNet_dgl import JKNet
def sample_graph_det(adj_orig, A_pred, remove_pct, add_pct):
    """Deterministically edit a graph using predicted edge scores.

    The lowest-scoring ``remove_pct`` percent of existing edges are removed,
    and the highest-scoring non-edges are added, where the number added is
    ``add_pct`` percent of the original edge count.

    Args:
        adj_orig: square scipy sparse adjacency matrix. Only its strict upper
            triangle is read, so self-loops and edge weights are not carried
            into the edited graph.
        A_pred: dense (n, n) array-like of edge scores (probabilities or
            logits); a higher score means "more likely an edge". It is never
            modified.
        remove_pct: percentage of existing edges to drop; 0 disables removal.
        add_pct: number of edges to add, as a percentage of the existing edge
            count; 0 disables addition.

    Returns:
        ``copy.deepcopy(adj_orig)`` when both percentages are 0; otherwise a
        symmetric ``scipy.sparse.csr_matrix`` with unit weights.
    """
    if remove_pct == 0 and add_pct == 0:
        return copy.deepcopy(adj_orig)

    scores = np.asarray(A_pred)
    orig_upper = sp.triu(adj_orig, 1)
    # (n_edges, 2) array of (row, col) pairs with row < col.
    edges = np.asarray(orig_upper.nonzero()).T
    n_edges = len(edges)

    edges_pred = edges
    if remove_pct:
        # Clamp so rounding or out-of-range percentages can never produce an
        # invalid kth index for argpartition.
        n_remove = min(max(int(n_edges * remove_pct / 100), 0), n_edges)
        if n_remove == n_edges:
            edges_pred = edges[:0]
        elif n_remove > 0:
            pos_probs = scores[edges[:, 0], edges[:, 1]]
            # Positions (into `edges`) of the n_remove lowest-scoring edges.
            e_index_2b_remove = np.argpartition(pos_probs, n_remove)[:n_remove]
            mask = np.ones(n_edges, dtype=bool)
            mask[e_index_2b_remove] = False
            edges_pred = edges[mask]

    if add_pct:
        n_nodes = scores.shape[0]
        # Only strict-upper-triangle non-edges are valid candidates.
        n_candidates = n_nodes * (n_nodes - 1) // 2 - n_edges
        n_add = min(max(int(n_edges * add_pct / 100), 0), n_candidates)
        # Guard n_add == 0: argpartition(..., -0)[-0:] would return EVERY index.
        if n_add > 0:
            # np.array copies, so A_pred itself is never modified.
            A_probs = np.array(scores)
            if not np.issubdtype(A_probs.dtype, np.floating):
                A_probs = A_probs.astype(np.float64)
            # Exclude the lower triangle (incl. diagonal) and existing edges
            # with -inf rather than 0: with logits, 0 would outrank real
            # candidates whose scores are negative.
            A_probs[np.tril_indices(n_nodes)] = -np.inf
            A_probs[edges[:, 0], edges[:, 1]] = -np.inf
            flat_index = np.argpartition(A_probs.reshape(-1), -n_add)[-n_add:]
            # Exact integer decoding of flat indices into (row, col).
            rows, cols = np.unravel_index(flat_index, A_probs.shape)
            new_edges = np.column_stack((rows, cols))
            edges_pred = np.concatenate((edges_pred, new_edges), axis=0)

    adj_pred = sp.csr_matrix(
        (np.ones(len(edges_pred)), (edges_pred[:, 0], edges_pred[:, 1])),
        shape=adj_orig.shape,
    )
    # Mirror the upper-triangular edge list to obtain an undirected graph.
    return adj_pred + adj_pred.T
def _load_pickle(path):
    """Load and return the pickled object stored at ``path``, closing the file.

    NOTE(review): pickle can execute arbitrary code on load; only use this on
    data files produced by this project.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='single')
    parser.add_argument('--dataset', type=str, default='cora')
    parser.add_argument('--gnn', type=str, default='gcn')
    parser.add_argument('--gpu', type=str, default='0')
    # NOTE: --i is accepted for CLI compatibility but is not used below; the
    # edge-probability file index is read from best_parameters.json instead.
    parser.add_argument('--i', type=str, default='2')
    args = parser.parse_args()

    # Resolve the model class up front so an unknown --gnn fails immediately
    # with a usage error instead of a NameError after all data is loaded.
    gnn_classes = {'gcn': GCN, 'gat': GAT, 'gsage': GraphSAGE, 'jknet': JKNet}
    if args.gnn not in gnn_classes:
        parser.error(f"unknown --gnn {args.gnn!r}; choose from {sorted(gnn_classes)}")
    GNN = gnn_classes[args.gnn]

    if args.gpu == '-1':
        gpu = -1
    else:
        # Expose only the requested device; it is then addressed as device 0.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        gpu = 0

    tvt_nids = _load_pickle(f'data/graphs/{args.dataset}_tvt_nids.pkl')
    adj_orig = _load_pickle(f'data/graphs/{args.dataset}_adj.pkl')
    features = _load_pickle(f'data/graphs/{args.dataset}_features.pkl')
    labels = _load_pickle(f'data/graphs/{args.dataset}_labels.pkl')
    if sp.issparse(features):
        features = torch.FloatTensor(features.toarray())

    with open('best_parameters.json', 'r') as f:
        params_all = json.load(f)
    params = params_all['GAugM'][args.dataset][args.gnn]
    i = params['i']
    A_pred = _load_pickle(f'data/edge_probabilities/{args.dataset}_graph_{i}_logits.pkl')
    adj_pred = sample_graph_det(adj_orig, A_pred, params['rm_pct'], params['add_pct'])

    # Train independently initialised models on the edited graph and report
    # the mean and standard deviation of the first value returned by fit().
    n_runs = 30
    accs = []
    for _ in range(n_runs):
        model = GNN(adj_pred, adj_pred, features, labels, tvt_nids,
                    print_progress=False, cuda=gpu, epochs=200)
        acc, _, _ = model.fit()
        accs.append(acc)
    print(f'Micro F1: {np.mean(accs):.6f}, std: {np.std(accs):.6f}')