diff --git a/README.md b/README.md
index d6c5b6b..edbde67 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 This repository contains a Pytorch Geometric implementation of the [Gravity-Inspired Graph Autoencoders for Directed Link Prediction](https://doi.org/10.1145/3357384.3358023) paper.
 
-With resepect to the [original implementation](https://github.com/deezer/gravity_graph_autoencoders), it has the following limitations:
+With respect to the [original implementation](https://github.com/deezer/gravity_graph_autoencoders), it has the following limitations:
 
 - Only the Cora dataset has currently been tested (but support for the other datasets should be already there);
 - Only the Gravity-GAE, Gravity-VGAE, SourceTarget-GAE and SourceTarget-VGAE are implemented.
diff --git a/gravity_gae/Convolution.py b/gravity_gae/Convolution.py
index fa2aed0..1261a8c 100644
--- a/gravity_gae/Convolution.py
+++ b/gravity_gae/Convolution.py
@@ -2,13 +2,17 @@
 import torch_sparse
 from torch_geometric.nn import MessagePassing
 
-
-# This conv expects self loops to have already been added, and that the rows of the adjm are NOT normalized by the inverse out-degree +1. Such normalization will be done using the "mean" aggregation inherited from MessagePassing The adjm in question should be the transpose of the usual adjm.
 class Conv(MessagePassing):
+    """
+    This convolutional layer expects self-loops to have already been added, and
+    that the rows of the adjacency matrix are NOT normalized by the inverse out-degree +1.
+    Such normalization will be done using the "mean" aggregation inherited from MessagePassing.
+    The adjacency matrix in question should be the transpose of the usual adjacency matrix.
+    """
     def __init__(self, in_channels, out_channels):
-        super().__init__(aggr = "mean")
-        self.W = Linear(in_channels, out_channels, bias = False)
+        super().__init__(aggr="mean")
+        self.W = Linear(in_channels, out_channels, bias=False)
 
     def message_and_aggregate(self, adj_t, x):
         return torch_sparse.matmul(adj_t, x, reduce=self.aggr)
@@ -18,13 +22,9 @@ def message(self, x_j):
         return x_j
 
     def forward(self, x, edge_index):
         transformed = self.W(x)
-
-        transformed_aggregated_normalized = self.propagate(edge_index, x = transformed)
-
+        transformed_aggregated_normalized = self.propagate(edge_index, x=transformed)
         return transformed_aggregated_normalized
 
     def reset_parameters(self):
         super().reset_parameters()
         self.W.reset_parameters()
-
-
\ No newline at end of file
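
Illustrative usage sketch for the layer above (not part of the diff). It assumes `Conv` is importable from `gravity_gae.Convolution` and that `Linear` is imported from `torch.nn` in a line this hunk does not show; with a transposed, self-looped sparse adjacency, the "mean" aggregation divides each aggregated row by the number of neighbours it sums over (self-loop included), which is the normalization the docstring describes.

import torch
import torch_sparse
from torch_sparse import SparseTensor
from gravity_gae.Convolution import Conv  # hypothetical import path

edge_index = torch.tensor([[0, 1, 2, 0],  # directed edges i -> j
                           [1, 2, 0, 2]])
adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3)).t()  # transpose of the usual adjacency
adj_t = torch_sparse.fill_diag(adj_t, 1.0)                                 # self-loops added beforehand

conv = Conv(in_channels=8, out_channels=16)
out = conv(torch.randn(3, 8), adj_t)  # rows of adj_t @ (x W), averaged over the aggregated neighbours
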
diff --git a/gravity_gae/GNN.py b/gravity_gae/GNN.py
index d44540b..597efe4 100644
--- a/gravity_gae/GNN.py
+++ b/gravity_gae/GNN.py
@@ -7,7 +7,9 @@
 
 class LayerWrapper(Module):
-    def __init__(self, layer, normalization_before_activation = None, activation = None, normalization_after_activation =None, dropout_p = None, _add_remaining_self_loops = False, uses_sparse_representation = False ):
+    def __init__(self, layer, normalization_before_activation=None, activation=None,
+                 normalization_after_activation=None, dropout_p=None,
+                 _add_remaining_self_loops=False, uses_sparse_representation=False):
         super().__init__()
         self.activation = activation
         self.normalization_before_activation = normalization_before_activation
@@ -18,82 +20,71 @@ def __init__(self, layer, normalization_before_activation = None, activation = N
         self.uses_sparse_representation = uses_sparse_representation
 
     def forward(self, batch):
-
-
         new_batch = copy.copy(batch)
         if self._add_remaining_self_loops and not self.uses_sparse_representation:
-            new_batch.edge_index, _ = add_remaining_self_loops(new_batch.edge_index)
+            new_batch.edge_index, _ = add_remaining_self_loops(new_batch.edge_index)
         elif self._add_remaining_self_loops and self.uses_sparse_representation:
             new_batch.adj_t = torch_sparse.fill_diag(new_batch.adj_t, 2)
 
         if not self.uses_sparse_representation:
-            new_batch.x = self.layer(x = new_batch.x, edge_index = new_batch.edge_index)
+            new_batch.x = self.layer(x=new_batch.x, edge_index=new_batch.edge_index)
         else:
-
-            new_batch.x = self.layer(x = new_batch.x, edge_index = new_batch.adj_t)
+            new_batch.x = self.layer(x=new_batch.x, edge_index=new_batch.adj_t)
 
         if self.normalization_before_activation is not None:
             new_batch.x = self.normalization_before_activation(new_batch.x)
         if self.activation is not None:
             new_batch.x = self.activation(new_batch.x)
         if self.normalization_after_activation is not None:
-            new_batch.x = self.normalization_after_activation(new_batch.x)
+            new_batch.x = self.normalization_after_activation(new_batch.x)
         if self.dropout_p is not None:
-            new_batch.x = dropout(new_batch.x, p=self.dropout_p, training=self.training)
-
+            new_batch.x = dropout(new_batch.x, p=self.dropout_p, training=self.training)
         return new_batch
-
-
-
+
 class GNN_FB(Module):
-    def __init__(self, gnn_layers, preprocessing_layers = [], postprocessing_layers = []):
+    def __init__(self, gnn_layers, preprocessing_layers=[], postprocessing_layers=[]):
         super().__init__()
+        self.net = torch.nn.Sequential(*preprocessing_layers, *gnn_layers, *postprocessing_layers)
-        self.net = torch.nn.Sequential(*preprocessing_layers, *gnn_layers, *postprocessing_layers )
-
-
     def forward(self, batch):
-        return self.net(batch)
-
-
+        return self.net(batch)
 
 class LinkPropertyPredictorGravity(Module):
-    def __init__(self, l, EPS = 1e-2, CLAMP = None, train_l = True):
+    def __init__(self, l, EPS=1e-2, CLAMP=None, train_l=True):
         super().__init__()
         self.l_initialization = l
-        self.l = Parameter(torch.tensor([l]), requires_grad = train_l )
+        self.l = Parameter(torch.tensor([l]), requires_grad=train_l)
         self.EPS = EPS
         self.CLAMP = CLAMP
 
-    def forward(self, batch):
-        new_batch = copy.copy(batch)
+    def forward(self, batch):
+        new_batch = copy.copy(batch)
 
         if batch.edge_label_index in ["general", "biased", "bidirectional"]:
-            m_i = new_batch.x[:,-1].reshape(-1,1).expand((-1,new_batch.x.size(0))).t()
-            r = new_batch.x[:,:-1]
+            m_i = new_batch.x[:, -1].reshape(-1, 1).expand((-1, new_batch.x.size(0))).t()
+            r = new_batch.x[:, :-1]
 
-            norm = (r * r).sum(dim = 1, keepdim = True)
+            norm = (r * r).sum(dim=1, keepdim=True)
             r1r2 = torch.matmul(r, r.t())
-            r2 = norm - 2*r1r2 + norm.t()
+            r2 = norm - 2 * r1r2 + norm.t()
 
             logr2 = torch.log(r2 + self.EPS)
             if self.CLAMP is not None:
-                logr2 = logr2.clamp(min = -self.CLAMP, max = self.CLAMP)
+                logr2 = logr2.clamp(min=-self.CLAMP, max=self.CLAMP)
 
-            new_batch.x = (m_i - self.l * logr2).reshape(-1)
+            new_batch.x = (m_i - self.l * logr2).reshape(-1)
 
         else:
+            m_j = new_batch.x[new_batch.edge_label_index[1, :], -1]
-            m_j = new_batch.x[new_batch.edge_label_index[1,:],-1]
+            diff = new_batch.x[new_batch.edge_label_index[0, :], :-1] - new_batch.x[new_batch.edge_label_index[1, :], :-1]
-            diff = new_batch.x[new_batch.edge_label_index[0,:], :-1] - new_batch.x[new_batch.edge_label_index[1,:], :-1]
-
-            r2 = (diff * diff).sum(dim = 1)
+            r2 = (diff * diff).sum(dim=1)
 
             new_batch.x = m_j - self.l * torch.log(r2 + self.EPS)
 
         return new_batch
 
@@ -102,33 +93,28 @@ def reset_parameters(self):
         self.l.data = torch.tensor([self.l_initialization]).to(self.l.data.device)
 
-
-
 class LinkPropertyPredictorSourceTarget(Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, batch):
-        new_batch = copy.copy(batch)
+    def forward(self, batch):
+        new_batch = copy.copy(batch)
 
         hidden_dimension = batch.x.size(1)
-        half_dimension = int(hidden_dimension/2)
+        half_dimension = int(hidden_dimension / 2)
 
         if batch.edge_label_index in ["general", "biased", "bidirectional"] and self.training:
-
             source = batch.x[:, :half_dimension]
             target = batch.x[:, half_dimension:]
 
             new_batch.x = torch.matmul(source, target.t()).reshape(-1)
 
         else:
+            new_batch.x = (new_batch.x[new_batch.edge_label_index[0, :], :half_dimension] *
+                           new_batch.x[new_batch.edge_label_index[1, :], half_dimension:]).sum(dim=1).reshape(-1)
-            new_batch.x = (new_batch.x[new_batch.edge_label_index[0,:], :half_dimension] * new_batch.x[new_batch.edge_label_index[1,:], half_dimension:]).sum(dim = 1).reshape(-1)
-
         return new_batch
-
-
-
+
 class ParallelLayerWrapper(Module):
     def __init__(self, layerwrappers):
@@ -137,6 +123,3 @@ def __init__(self, layerwrappers):
 
     def forward(self, batch):
         return [layerwrapper(batch) for layerwrapper in self.layerwrappers]
-
-
-
\ No newline at end of file
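
For reference, a standalone sketch (not part of the diff) of the score that LinkPropertyPredictorGravity assigns to a directed pair (i, j): the last feature column holds the mass m, the remaining columns the position r, and the logit is m_j - l * log(||r_i - r_j||^2 + EPS), matching the pairwise branch above.

import torch

x = torch.randn(4, 5)        # 4 nodes: 4-dimensional position + 1 mass column
r, m = x[:, :-1], x[:, -1]
l, EPS = 1.0, 1e-2

i, j = 0, 2
r2 = ((r[i] - r[j]) ** 2).sum()
logit_ij = m[j] - l * torch.log(r2 + EPS)   # higher value -> edge i -> j more likely
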
diff --git a/gravity_gae/VGAE.py b/gravity_gae/VGAE.py
index a04bd56..142df09 100644
--- a/gravity_gae/VGAE.py
+++ b/gravity_gae/VGAE.py
@@ -1,12 +1,10 @@
-
 import torch
 import copy
 from torch.nn import Module
 
 class VGAE_Reparametrization(Module):
-    def __init__(self, MAX_LOGSTD = None, num_noise_samples = 100):
+    def __init__(self, MAX_LOGSTD=None, num_noise_samples=100):
         super().__init__()
-
         self.MAX_LOGSTD = MAX_LOGSTD
         self.num_noise_samples = num_noise_samples
 
@@ -14,16 +12,13 @@ def forward(self, mu_logstd):
         new_batch = copy.copy(mu_logstd[0])
         mu_batch, logstd_batch = mu_logstd
-
-
         mu, logstd = mu_batch.x, logstd_batch.x
 
         if self.MAX_LOGSTD is not None:
-            logstd = logstd.clamp(max = self.MAX_LOGSTD)
-
+            logstd = logstd.clamp(max=self.MAX_LOGSTD)
 
         if self.training:
-            new_batch.x = mu + torch.randn_like(logstd) * torch.exp(logstd)
+            new_batch.x = mu + torch.randn_like(logstd) * torch.exp(logstd)
         else:
             new_batch.x = mu
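
A minimal sketch (not part of the diff) of the sampling step VGAE_Reparametrization performs: during training a latent sample is drawn as mu + eps * exp(logstd); in evaluation mode the mean is returned directly.

import torch

mu = torch.zeros(5, 3)
logstd = torch.full((5, 3), -1.0)

z_train = mu + torch.randn_like(logstd) * torch.exp(logstd)  # stochastic latent used in training
z_eval = mu                                                  # deterministic latent used at test time
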
diff --git a/gravity_gae/custom_losses.py b/gravity_gae/custom_losses.py
index bb75b6b..e3ebc6e 100644
--- a/gravity_gae/custom_losses.py
+++ b/gravity_gae/custom_losses.py
@@ -5,19 +5,13 @@
 import copy
 from sklearn.metrics import average_precision_score, roc_auc_score
 
-
-def recon_loss(logits, ground_truths, EPS = 1e-10):
-
-    pos_mask = (ground_truths == 1.)
-
-    pos_loss = -torch.log( sigmoid(logits[pos_mask]) + EPS ).mean()
-
-    neg_loss = -torch.log(1. - sigmoid(logits[~pos_mask]) + EPS).mean()
+def recon_loss(logits, ground_truths, EPS=1e-10):
+    pos_mask = (ground_truths == 1.0)
+    pos_loss = -torch.log(sigmoid(logits[pos_mask]) + EPS).mean()
+    neg_loss = -torch.log(1.0 - sigmoid(logits[~pos_mask]) + EPS).mean()
     return pos_loss + neg_loss
 
-
-
 def hitsk(model, test_data_split, k):
     test_data_split_pos = copy.copy(test_data_split)
     test_data_split_pos.edge_label_index = test_data_split.pos_edge_label_index
@@ -25,30 +19,30 @@ def hitsk(model, test_data_split, k):
     test_data_split_neg = copy.copy(test_data_split)
     test_data_split_neg.edge_label_index = test_data_split.neg_edge_label_index
 
-    return compute_hitsk(model(test_data_split_pos).x, model(test_data_split_neg).x, k )
+    return compute_hitsk(model(test_data_split_pos).x, model(test_data_split_neg).x, k)
 
 def compute_hitsk(y_pred_pos, y_pred_neg, k):
-    tot = (y_pred_pos > torch.sort(y_pred_neg, descending = True)[0][k]).sum()
+    tot = (y_pred_pos > torch.sort(y_pred_neg, descending=True)[0][k]).sum()
     return tot / y_pred_pos.size(0)
 
-
-
 def average_precision(model, test_data):
-    return average_precision_score(test_data.edge_label.cpu().detach().numpy(), sigmoid(model(test_data).x).cpu().detach().numpy()) #avp.compute()
-
+    preds = sigmoid(model(test_data).x).cpu().detach().numpy()
+    labels = test_data.edge_label.cpu().detach().numpy()
+    return average_precision_score(labels, preds)
 
 def auc_loss(logits, ground_truths):
-    return 1. - roc_auc_score(ground_truths.cpu(), logits.x.cpu())
+    return 1.0 - roc_auc_score(ground_truths.cpu(), logits.x.cpu())
 
 def ap_loss(logits, ground_truths):
-    return 1. - average_precision_score(ground_truths.cpu().detach().numpy(), sigmoid(logits.x).cpu().detach().numpy())
+    preds = sigmoid(logits.x).cpu().detach().numpy()
+    labels = ground_truths.cpu().detach().numpy()
+    return 1.0 - average_precision_score(labels, preds)
 
 def losses_sum_closure(losses):
-
-    return lambda logits, ground_truths: np.sum([loss(logits, ground_truths) for loss in losses])
+    return lambda logits, ground_truths: np.sum([loss(logits, ground_truths) for loss in losses])
 
 class StandardLossWrapper(Module):
-    def __init__(self, norm, loss):
+    def __init__(self, norm, loss):
         super().__init__()
         self.loss = loss
         self.norm = norm
@@ -56,10 +50,9 @@ def __init__(self, norm, loss):
 
     def forward(self, batch, ground_truth):
         return self.norm * self.loss(batch.x, ground_truth)
 
-
-def kl_loss(mu , logstd):
-    return -0.5 * torch.mean(torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1))
-
+def kl_loss(mu, logstd):
+    kld = 1 + 2 * logstd - mu.pow(2) - logstd.exp().pow(2)
+    return -0.5 * torch.mean(torch.sum(kld, dim=1))
 
 class VGAELossWrapper(Module):
     def __init__(self, norm, loss):
@@ -67,10 +60,7 @@ def __init__(self, norm, loss):
         self.loss = loss
         self.norm = norm
 
     def forward(self, batch, ground_truth):
-
-        # x_nan_idxs = torch.where(torch.isnan(batch.x) == 1)[0].tolist()
-
-        return self.norm * self.loss(batch.x, ground_truth) + (0.5 / batch.x.size(0)) * kl_loss(batch.mu, batch.logstd)
-
+        loss_val = self.norm * self.loss(batch.x, ground_truth)
+        kl_val = (0.5 / batch.x.size(0)) * kl_loss(batch.mu, batch.logstd)
+        return loss_val + kl_val
diff --git a/gravity_gae/methods.py b/gravity_gae/methods.py
index f041dc1..e459b3c 100644
--- a/gravity_gae/methods.py
+++ b/gravity_gae/methods.py
@@ -18,5 +18,4 @@
 from preprocessing import *
 from input_data import *
 from models import *
-from data_loaders import *
-
+from data_loaders import *
\ No newline at end of file
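
A small usage sketch (not part of the diff) of the rewritten recon_loss, assuming custom_losses imports sigmoid from torch in lines not shown in this hunk: positives contribute -log(sigmoid(logit)) and negatives -log(1 - sigmoid(logit)), each averaged before being summed.

import torch
from gravity_gae.custom_losses import recon_loss  # hypothetical import path

logits = torch.tensor([2.0, -1.5, 0.3])
ground_truths = torch.tensor([1.0, 0.0, 1.0])
loss = recon_loss(logits, ground_truths)           # scalar: positive-term mean + negative-term mean
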
diff --git a/gravity_gae/utils.py b/gravity_gae/utils.py
index 5e3d777..dcd2aa8 100644
--- a/gravity_gae/utils.py
+++ b/gravity_gae/utils.py
@@ -1,4 +1,3 @@
-
 import numpy as np
 from math import log10, floor
 from pandas import DataFrame
@@ -6,53 +5,44 @@
 # Model utilities
 def reset_parameters(module):
     for layer in module.children():
-        # print(f"layer 1= {layer}")
-        if hasattr(layer, 'reset_parameters'):
-            print(f"resetting {layer}")
-            layer.reset_parameters()
-        elif len(list(layer.children())) > 0:
-            reset_parameters(layer)
-
+        if hasattr(layer, 'reset_parameters'):
+            print(f"resetting {layer}")
+            layer.reset_parameters()
+        elif len(list(layer.children())) > 0:
+            reset_parameters(layer)
 
 def print_model_parameters_names(model):
     for name, param in model.named_parameters():
         if param.requires_grad:
             print(name, param.data)
 
-
-
-
 def summarize_link_prediction_evaluation(performances):
     mean_std_dict = {}
     for metric in ['AUC', 'F1', 'hitsk', 'AP']:
-        vals = []
-        for run in performances:
-            vals.append(run[metric])
-
-        mean_std_dict[metric] = {"mean":np.nanmean( list(filter(None, vals)) ), "std": np.nanstd( list(filter(None, vals)) )}
-
+        vals = [run[metric] for run in performances]
+        filtered_vals = list(filter(None, vals))
+        mean_std_dict[metric] = {
+            "mean": np.nanmean(filtered_vals),
+            "std": np.nanstd(filtered_vals)
+        }
     return mean_std_dict
 
-
-
 def round_to_first_significative_digit(x):
     digit = -int(floor(log10(abs(x))))
     return digit, round(x, digit)
 
 def pretty_print_link_performance_evaluation(mean_std_dict, model_name):
     performances_strings = {}
-
-    for (metric,mean_std) in mean_std_dict.items():
+    for metric, mean_std in mean_std_dict.items():
         if np.isnan(mean_std["mean"]):
             performances_strings[metric] = str(None)
         elif mean_std["std"] == 0:
             digit, mean_rounded = round_to_first_significative_digit(mean_std["mean"])
-            performances_strings[metric] = str(mean_rounded) + " +- " + str(mean_std["std"])
+            performances_strings[metric] = f"{mean_rounded} +- {mean_std['std']}"
         else:
             digit, std_rounded = round_to_first_significative_digit(mean_std["std"])
             mean_rounded = round(mean_std["mean"], digit)
-            performances_strings[metric] = str(mean_rounded) + " +- " + str(std_rounded)
-
+            performances_strings[metric] = f"{mean_rounded} +- {std_rounded}"
 
-    df = DataFrame(performances_strings.values(), columns = [model_name], index = performances_strings.keys() )
+    df = DataFrame(performances_strings.values(), columns=[model_name], index=performances_strings.keys())
     return df.to_markdown(index=True)
\ No newline at end of file
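
A quick check (not part of the diff) of what round_to_first_significative_digit returns, assuming utils is importable as gravity_gae.utils: the position of the first significant figure and the value rounded to it, which is what the pretty-printer uses to format "mean +- std" strings.

from gravity_gae.utils import round_to_first_significative_digit  # hypothetical import path

digit, rounded = round_to_first_significative_digit(0.0234)
print(digit, rounded)  # 2 0.02 -> used to format results such as "0.85 +- 0.02"
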