forked from malllabiisc/ASAP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
asap_pool.py
163 lines (119 loc) · 5.6 KB
/
asap_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import math
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_scatter import scatter_add, scatter_max
from torch_geometric.nn import GCNConv
from le_conv import LEConv
from torch_geometric.utils import add_remaining_self_loops, remove_self_loops, softmax
from torch_geometric.nn.pool.topk_pool import topk
from torch_sparse import coalesce
from torch_sparse import transpose
from torch_sparse import spspmm
# torch.set_num_threads(1)
def StAS(index_A, value_A, index_S, value_S, device, N, kN):
    r"""Return the edge weights of the pooled graph, computed as :math:`S^{T} A S`.

    All matrices are handled in sparse COO form (an index tensor plus a value
    tensor) using the ``torch_sparse`` primitives imported by this module.

    Args:
        index_A: COO indices of ``A``; treated as an ``N x N`` matrix.
        value_A: Values belonging to ``index_A``.
        index_S: COO indices of ``S``; treated as an ``N x kN`` matrix.
        value_S: Values belonging to ``index_S``.
        device: Not used by this function; kept so existing callers keep working.
        N: Number of rows/columns of ``A`` (and rows of ``S``).
        kN: Number of columns of ``S``.

    Returns:
        Tuple ``(index_E, value_E)``: the COO representation of the
        ``kN x kN`` product ``S^T (A S)``.
    """
    # Bring both operands into coalesced form before multiplying.
    adj_index, adj_value = coalesce(index_A, value_A, m=N, n=N)
    assign_index, assign_value = coalesce(index_S, value_S, m=N, n=kN)

    # S^T is a kN x N matrix; it only depends on S, so build it up front.
    assign_t_index, assign_t_value = transpose(assign_index, assign_value, N, kN)

    # Right half of the product: A S, an N x kN matrix.
    prod_index, prod_value = spspmm(
        adj_index, adj_value, assign_index, assign_value, N, N, kN
    )
    prod_index, prod_value = coalesce(prod_index, prod_value, m=N, n=kN)

    # Left half: S^T (A S). The product is computed with the tensors as given;
    # nothing is moved to or from `device` here.
    return spspmm(assign_t_index, assign_t_value, prod_index, prod_value, kN, N, kN)
def graph_connectivity(device, perm, edge_index, edge_weight, score, ratio, batch, N):
    r"""Compute the connectivity of the pooled graph by calling ``StAS``.

    A sparse assignment matrix ``S`` (``N x kN``) is built from the edges that
    originate at the selected nodes, weighted by ``score``; the new edges are
    then obtained as :math:`S^{T} A S`. Existing self loops are removed from
    the result and a self loop of weight 1 is added for each of the ``kN``
    pooled nodes.

    Args:
        device: Not used; kept for interface compatibility. Helper tensors are
            allocated on ``perm.device``.
        perm: 1-D long tensor with the indices of the selected nodes.
        edge_index: ``[2, E]`` edge index tensor of the input graph.
        edge_weight: Optional per-edge weights of the input graph; if ``None``
            every edge gets weight 1.
        score: Per-edge scores, indexed along dim 0 by the edge mask. Expected
            to hold one value per edge (``ASAP_Pooling.forward`` in this file
            passes a column of attention scores).
        ratio: Not used; kept for interface compatibility.
        batch: Not used; kept for interface compatibility.
        N: Number of nodes in the input graph.

    Returns:
        Tuple ``(index_E, value_E)`` with the edge index and edge weights of
        the pooled graph over ``kN = perm.size(0)`` nodes.
    """
    kN = perm.size(0)
    helper_device = perm.device

    # Mark edges whose source node (row 0) is a selected node. A boolean lookup
    # table gives the same mask as comparing every edge against every entry of
    # `perm`, without materialising a kN x E intermediate.
    is_selected = torch.zeros(N, dtype=torch.bool, device=helper_device)
    is_selected[perm] = True
    mask = is_selected[edge_index[0]]

    # Create S: row = member node (edge target), column = cluster (edge source).
    index_S = torch.stack([edge_index[1][mask], edge_index[0][mask]], dim=0)
    # Flatten explicitly: `.squeeze()` would turn a single selected edge into a
    # 0-dim tensor, which is not a valid value vector for the sparse ops.
    value_S = score[mask].detach().reshape(-1)

    # Relabel the cluster column for pooling, i.e. make S [N x kN]. The lookup
    # table lives on the same device as `perm` so indexing does not mix devices.
    n_idx = torch.zeros(N, dtype=torch.long, device=helper_device)
    n_idx[perm] = torch.arange(kN, device=helper_device)
    index_S[1] = n_idx[index_S[1]]

    # Create A.
    index_A = edge_index.clone()
    if edge_weight is None:
        value_A = value_S.new_ones(edge_index[0].size(0))
    else:
        value_A = edge_weight.clone()

    index_E, value_E = StAS(index_A, value_A, index_S, value_S, device, N, kN)

    # Replace whatever self loops the product produced with unit self loops.
    # The weight tensor is passed positionally so the call does not depend on
    # the keyword name used by the installed torch_geometric version.
    index_E, value_E = remove_self_loops(index_E, value_E)
    index_E, value_E = add_remaining_self_loops(
        index_E, value_E, fill_value=1, num_nodes=kN
    )
    return index_E, value_E
class ASAP_Pooling(torch.nn.Module):
    r"""ASAP pooling layer.

    For every node, the layer forms a cluster from its neighbourhood (self
    loop included), aggregates the member features with attention weights,
    scores each cluster with ``LEConv`` and keeps the top ``ratio`` fraction
    of clusters per graph. The connectivity of the pooled graph is computed by
    ``graph_connectivity``.

    Args:
        in_channels: Size of each input node feature vector.
        ratio: Pooling ratio handed to ``topk``.
        dropout_att: Dropout probability applied to the attention scores.
        negative_slope: Negative slope of the LeakyReLU applied to the raw
            attention scores.
    """

    def __init__(self, in_channels, ratio, dropout_att=0, negative_slope=0.2):
        super(ASAP_Pooling, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.negative_slope = negative_slope
        self.dropout_att = dropout_att
        self.lin_q = Linear(in_channels, in_channels)
        self.gat_att = Linear(2 * in_channels, 1)
        # gnn_score: uses LEConv to find cluster fitness scores
        self.gnn_score = LEConv(self.in_channels, 1)
        # gnn_intra_cluster: uses GCN to account for intra cluster properties,
        # e.g., edge-weights
        self.gnn_intra_cluster = GCNConv(self.in_channels, self.in_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of all sub-modules."""
        self.lin_q.reset_parameters()
        self.gat_att.reset_parameters()
        self.gnn_score.reset_parameters()
        self.gnn_intra_cluster.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None, batch=None):
        r"""Pool the input graph.

        Args:
            x: Node features; a 1-D tensor is treated as one feature per node.
            edge_index: ``[2, E]`` edge index tensor.
            edge_weight: Optional per-edge weights.
            batch: Optional graph assignment of every node; defaults to all
                zeros (a single graph).

        Returns:
            Tuple ``(x, edge_index, edge_weight, batch, perm)`` describing the
            pooled graph; ``perm`` holds the indices of the kept nodes.
        """
        # Total number of nodes in the batch. This is the value the scatter-sum
        # over `batch` used to yield, but as a plain int (no tensor round trip).
        N = x.size(0)
        if batch is None:
            batch = edge_index.new_zeros(N)
        # NxF
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        # Add self loops so every node belongs to its own cluster. The weights
        # are passed positionally so the call does not depend on the keyword
        # name used by the installed torch_geometric version.
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value=1, num_nodes=N
        )
        source, target = edge_index[0], edge_index[1]

        # ExF
        x_pool = self.gnn_intra_cluster(
            x=x, edge_index=edge_index, edge_weight=edge_weight
        )
        x_pool_j = x_pool[target]
        x_j = x[target]

        # ---Master query formation---
        # NxF
        X_q, _ = scatter_max(x_pool_j, source, dim=0, dim_size=N)
        # NxF
        M_q = self.lin_q(X_q)
        # ExF; index with the tensor itself rather than a Python list, which
        # would copy every index to the host first.
        M_q = M_q[source]

        score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, source, num_nodes=N)

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout_att, training=self.training)
        # ExF
        v_j = x_j * score.view(-1, 1)

        # ---Aggregation---
        # NxF
        out = scatter_add(v_j, source, dim=0, dim_size=N)

        # ---Cluster Selection
        # Nx1
        fitness = torch.sigmoid(
            self.gnn_score(x=out, edge_index=edge_index)
        ).view(-1)
        perm = topk(x=fitness, ratio=self.ratio, batch=batch)
        x = out[perm] * fitness[perm].view(-1, 1)

        # ---Maintaining Graph Connectivity
        batch = batch[perm]
        edge_index, edge_weight = graph_connectivity(
            device=x.device,
            perm=perm,
            edge_index=edge_index,
            edge_weight=edge_weight,
            score=score,
            ratio=self.ratio,
            batch=batch,
            N=N)

        return x, edge_index, edge_weight, batch, perm

    def __repr__(self):
        return '{}({}, ratio={})'.format(self.__class__.__name__, self.in_channels, self.ratio)