diff --git a/pygod/detector/card.py b/pygod/detector/card.py index d53fd37..6bfb5f7 100644 --- a/pygod/detector/card.py +++ b/pygod/detector/card.py @@ -65,6 +65,9 @@ class CARD(DeepDetector): gama: float, optional The proportion of the local reconstruction in contrastive learning module. Default: ``0.5`` + alpha: float, optional + The proprotion of the community embedding in the conbine_encoder. + Default: ``0.1`` verbose : int, optional Verbosity mode. Range in [0, 3]. Larger value for printing out more log information. Default: ``0``. @@ -115,6 +118,7 @@ def __init__(self, subgraph_num_neigh=4, fp=0.6, gama=0.5, + alpha=0.1, verbose=0, save_emb=False, compile_model=False, @@ -138,6 +142,7 @@ def __init__(self, self.subgraph_num_neigh = subgraph_num_neigh self.fp = fp self.gama = gama + self.alpha = alpha def process_graph(self, data): community_adj, self.diff_data = CARDBase.process_graph(data) @@ -151,14 +156,15 @@ def init_model(self, **kwargs): self.hid_dim) return CARDBase(in_dim=self.in_dim, + subgraph_num_neigh=self.subgraph_num_neigh, fp=self.fp, gama=self.gama, + alpha=self.alpha, hid_dim=self.hid_dim, num_layers=self.num_layers, dropout=self.dropout, act=self.act, backbone=self.backbone, - subgraph_num_neigh=self.subgraph_num_neigh, **kwargs).to(self.device) def forward_model(self, data): diff --git a/pygod/nn/card.py b/pygod/nn/card.py index a1b9a37..ce6a8ef 100644 --- a/pygod/nn/card.py +++ b/pygod/nn/card.py @@ -36,6 +36,9 @@ class CARDBase(nn.Module): gama: float, optional The proportion of the local reconstruction in contrastive learning module. Default: ``0.5`` + alpha: float, optional + The proprotion of the community embedding in the conbine_encoder. + Default: ``0.1`` hid_dim : int, optional Hidden dimension of model. Default: ``64``. num_layers : int, optional @@ -54,9 +57,9 @@ class CARDBase(nn.Module): def __init__(self, in_dim, + subgraph_num_neigh=4, fp=0.6, gama=0.4, - subgraph_num_neigh=4, alpha=0.1, hid_dim=64, num_layers=4, @@ -171,18 +174,18 @@ def loss_func(self, logits, diff_logits, x_, local_x_, x, con_label): Parameters ---------- - logits : _type_ - _description_ - diff_logits : _type_ - _description_ - x_ : _type_ - _description_ - local_x_ : _type_ - _description_ - x : _type_ - _description_ - con_label : _type_ - _description_ + logits : torch.Tensor + Discriminator logits of positive subgraphs batch. + diff_logits : torch.Tensor + Discriminator logits of negative subgraphs batch. + x_ : torch.Tensor + Global reconstructed attribute embeddings. + local_x_ : torch.Tensor + Local reconstructed attribute embeddings. + x : torch.Tensor + Input attribute embeddings. + con_label : torch.Tensor + Contrastive learning pseudo label Returns ------- @@ -254,13 +257,6 @@ def _train_subgraph_network(self, data): x = subgraph.x edge_index = subgraph.edge_index - # diff_subgraphs = NeighborLoader( - # self.diff, num_neighbors=[-1] * self.num_layers) - # diff_subgraph = diff_subgraphs([index]) - # diff_subgraph.x[0, :] = 0 - # diff_x = diff_subgraph.x.to(self.device) - # diff_edge_index = diff_subgraph.edge_index.to(self.device) - ori_emb = self.encoder(x, edge_index) community_emb = self.community_encoder(community_adj) combine_emb = self.combine_encoder(