Skip to content

Commit

Permalink
Merge pull request #113 from Wollents/main
Browse files Browse the repository at this point in the history
Add CARD
  • Loading branch information
kayzliu authored Nov 14, 2024
2 parents fa36957 + c7df01e commit c84dcca
Show file tree
Hide file tree
Showing 10 changed files with 638 additions and 2 deletions.
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ CoLA 2021 GNN+AE+SSL Yes [#Liu2021Anomaly]_
GUIDE 2021 GNN+AE Yes [#Yuan2021Higher]_
CONAD 2022 GNN+AE+SSL Yes [#Xu2022Contrastive]_
GADNR 2024 GNN+AE Yes [#Roy2024Gadnr]_
CARD 2024 GNN+SSL+AE Yes [#Wang2024Card]_
================== ===== =========== =========== ========================================


Expand Down Expand Up @@ -269,3 +270,5 @@ Reference
.. [#Xu2022Contrastive] Xu, Z., Huang, X., Zhao, Y., Dong, Y., and Li, J., 2022. Contrastive Attributed Network Anomaly Detection with Data Augmentation. In Proceedings of the 26th Pacific-Asia Conference on Knowledge Discovery and Data Mining (PAKDD).
.. [#Roy2024Gadnr] Roy, A., Shu, J., Li, J., Yang, C., Elshocht, O., Smeets, J. and Li, P., 2024. GAD-NR: Graph Anomaly Detection via Neighborhood Reconstruction. In Proceedings of the 17th ACM International Conference on Web Search and Data Mining (WSDM).
.. [#Wang2024Card] Wang Y., Wang X., He C., Chen X., Luo Z., Duan L., Zuo J., 2024. Community-Guided Contrastive Learning with Anomaly-Aware Reconstruction for Anomaly Detection on Attributed Networks. Database Systems for Advanced Applications (DASFAA).
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ CoLA 2021 GNN+AE+SSL Yes :class:`pygod.detector.CoLA
GUIDE 2021 GNN+AE Yes :class:`pygod.detector.GUIDE`
CONAD 2022 GNN+AE+SSL Yes :class:`pygod.detector.CONAD`
GADNR 2024 GNN+AE Yes :class:`pygod.detector.GADNR`
CARD 2024 GNN+SSL+AE Yes :class:`pygod.detector.CARD`
================== ===== =========== =========== ==============================================


Expand Down
1 change: 1 addition & 0 deletions docs/pygod.detector.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pygod.detector
~pygod.detector.AdONE
~pygod.detector.ANOMALOUS
~pygod.detector.AnomalyDAE
~pygod.detector.CARD
~pygod.detector.CoLA
~pygod.detector.CONAD
~pygod.detector.DMGD
Expand Down
1 change: 1 addition & 0 deletions docs/pygod.nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pygod.nn

~pygod.nn.AdONEBase
~pygod.nn.AnomalyDAEBase
~pygod.nn.CARDBase
~pygod.nn.CoLABase
~pygod.nn.DMGDBase
~pygod.nn.DOMINANTBase
Expand Down
8 changes: 8 additions & 0 deletions docs/zreferences.bib
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,11 @@ @article{bandyopadhyay2020integrating
year={2020},
publisher={IOS Press BV}
}

@inproceedings{wang2024card,
author = {Wang, Yang and Wang, Xinye and He, Chengxin and Chen, Xiaocong and Luo, Zhaohang and Duan, Lei and Zuo, Jie},
title = {Community-Guided Contrastive Learning with Anomaly-Aware Reconstruction for Anomaly Detection on Attributed Networks},
booktitle = {Database Systems for Advanced Applications - 29th International Conference},
pages = {199--209},
year = {2024}
}
3 changes: 2 additions & 1 deletion pygod/detector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .adone import AdONE
from .anomalous import ANOMALOUS
from .anomalydae import AnomalyDAE
from .card import CARD
from .cola import CoLA
from .conad import CONAD
from .dmgd import DMGD
Expand All @@ -19,7 +20,7 @@
from .scan import SCAN

__all__ = [
"Detector", "DeepDetector", "AdONE", "ANOMALOUS", "AnomalyDAE", "CoLA",
"Detector", "DeepDetector", "AdONE", "ANOMALOUS", "AnomalyDAE", "CARD", "CoLA",
"CONAD", "DMGD", "DOMINANT", "DONE", "GAAN", "GADNR", "GAE", "GUIDE",
"OCGNN", "ONE", "Radar", "SCAN"
]
191 changes: 191 additions & 0 deletions pygod/detector/card.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""Community-Guided Contrastive Learning with Anomaly-Aware Reconstruction for
Anomaly Detection on Attributed Networks.(CARD) """
# Author: Yang Wang<[email protected]>
# License: BSD 2 clause

import torch
from torch_geometric.nn import GCN

from .base import DeepDetector
from ..nn import CARDBase


class CARD(DeepDetector):
"""
Community-Guided Contrastive Learning with Anomaly-Aware Reconstruction for
Anomaly Detection on Attributed Networks.
CARD is a contrastive learning based method and utilizes mask reconstruction and community
information to make anomalies more distinct. This model is train with contrastive loss and
local and global attribute reconstruction loss. Random neighbor sampling instead of random walk
sampling is used to sample the subgraph corresponding to each node. Since random neighbor sampling
cannot accurately control the number of neighbors for each sampling, it may run slower compared to
the method implementation in the original paper.
See:cite:`Wang2024Card` for details.
Parameters
----------
hid_dim : int, optional
Hidden dimension of model. Default: ``64``.
num_layers : int, optional
Total number of layers in model. Default: ``2``.
dropout : float, optional
Dropout rate. Default: ``0.``.
weight_decay : float, optional
Weight decay (L2 penalty). Default: ``0.``.
act : callable activation function or None, optional
Activation function if not None.
Default: ``torch.nn.functional.relu``.
backbone : torch.nn.Module
The backbone of the deep detector implemented in PyG.
Default: ``torch_geometric.nn.GCN``.
contamination : float, optional
The amount of contamination of the dataset in (0., 0.5], i.e.,
the proportion of outliers in the dataset. Used when fitting to
define the threshold on the decision function. Default: ``0.1``.
lr : float, optional
Learning rate. Default: ``0.004``.
epoch : int, optional
Maximum number of training epoch. Default: ``100``.
gpu : int
GPU Index, -1 for using CPU. Default: ``-1``.
batch_size : int, optional
Minibatch size, 0 for full batch training. Default: ``0``.
num_neigh : int, optional
Number of neighbors in sampling, -1 for all neighbors.
Default: ``-1``.
subgraph_num_neigh: int, optional
Number of neighbors in subgraph sampling for each node, Values not exceeding 4 are recommended for efficiency.
Default: ``4``.
fp: float, optional
The balance parameter between the mask autoencoder module and contrastive learning.
Default: ``0.6``
gama: float, optional
The proportion of the local reconstruction in contrastive learning module.
Default: ``0.5``
alpha: float, optional
The proprotion of the community embedding in the conbine_encoder.
Default: ``0.1``
verbose : int, optional
Verbosity mode. Range in [0, 3]. Larger value for printing out
more log information. Default: ``0``.
save_emb : bool, optional
Whether to save the embedding. Default: ``False``.
compile_model : bool, optional
Whether to compile the model with ``torch_geometric.compile``.
Default: ``False``.
**kwargs
Other parameters for the backbone.
Attributes
----------
decision_score_ : torch.Tensor
The outlier scores of the training data. Outliers tend to have
higher scores. This value is available once the detector is
fitted.
threshold_ : float
The threshold is based on ``contamination``. It is the
:math:`N \\times` ``contamination`` most abnormal samples in
``decision_score_``. The threshold is calculated for generating
binary outlier labels.
label_ : torch.Tensor
The binary labels of the training data. 0 stands for inliers
and 1 for outliers. It is generated by applying
``threshold_`` on ``decision_score_``.
emb : torch.Tensor or tuple of torch.Tensor or None
The learned node hidden embeddings of shape
:math:`N \\times` ``hid_dim``. Only available when ``save_emb``
is ``True``. When the detector has not been fitted, ``emb`` is
``None``. When the detector has multiple embeddings,
``emb`` is a tuple of torch.Tensor.
"""

def __init__(self,
hid_dim=64,
num_layers=2,
dropout=0.,
weight_decay=0.,
act=torch.nn.functional.relu,
backbone=GCN,
contamination=0.1,
lr=4e-3,
epoch=100,
gpu=-1,
batch_size=0,
num_neigh=-1,
subgraph_num_neigh=4,
fp=0.6,
gama=0.5,
alpha=0.1,
verbose=0,
save_emb=False,
compile_model=False,
**kwargs):
super(CARD, self).__init__(hid_dim=hid_dim,
num_layers=num_layers,
dropout=dropout,
weight_decay=weight_decay,
act=act,
backbone=backbone,
contamination=contamination,
lr=lr,
epoch=epoch,
gpu=gpu,
batch_size=batch_size,
num_neigh=num_neigh,
verbose=verbose,
save_emb=save_emb,
compile_model=compile_model,
**kwargs)
self.subgraph_num_neigh = subgraph_num_neigh
self.fp = fp
self.gama = gama
self.alpha = alpha

def process_graph(self, data):
community_adj, self.diff_data = CARDBase.process_graph(data)
data.community_adj = community_adj.to(self.device)
self.diff_data = self.diff_data.to(self.device)
self.diff_data.community_adj = community_adj.to(self.device)

def init_model(self, **kwargs):
if self.save_emb:
self.emb = torch.zeros(self.num_nodes,
self.hid_dim)

return CARDBase(in_dim=self.in_dim,
subgraph_num_neigh=self.subgraph_num_neigh,
fp=self.fp,
gama=self.gama,
alpha=self.alpha,
hid_dim=self.hid_dim,
num_layers=self.num_layers,
dropout=self.dropout,
act=self.act,
backbone=self.backbone,
**kwargs).to(self.device)

def forward_model(self, data):
batch_size = data.batch_size

data.x = data.x.to(self.device)
data.edge_index = data.edge_index.to(self.device)

pos_logits, neg_logits, x_, local_x_ = self.model(data)
diff_pos_logits, diff_neg_logits, _, _ = self.model(
self.diff_data)

logits = torch.cat([pos_logits[:batch_size],
neg_logits[:batch_size]])
diff_logits = torch.cat([diff_pos_logits[:batch_size],
diff_neg_logits[:batch_size]])

con_label = torch.cat([torch.ones(batch_size),
torch.zeros(batch_size)]).to(self.device)

loss, score = self.model.loss_func(
logits, diff_logits, x_, local_x_, data.x, con_label)

return loss, score.detach().cpu()
3 changes: 2 additions & 1 deletion pygod/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .adone import AdONEBase
from .anomalydae import AnomalyDAEBase
from .card import CARDBase
from .cola import CoLABase
from .dmgd import DMGDBase
from .dominant import DOMINANTBase
Expand All @@ -15,6 +16,6 @@
from . import functional

__all__ = [
"AdONEBase", "AnomalyDAEBase", "CoLABase", "DMGDBase", "DOMINANTBase",
"AdONEBase", "AnomalyDAEBase", "CARDBase", "CoLABase", "DMGDBase", "DOMINANTBase",
"DONEBase", "GAANBase", "GADNRBase", "GAEBase", "GUIDEBase", "OCGNNBase"
]
Loading

0 comments on commit c84dcca

Please sign in to comment.