DCBM README

This repository implements the IEEE TPAMI paper The Decoupling Concept Bottleneck Model (DCBM). The intervention/rectification and vision-language model (VLM) components are still being refined and will be released soon. If you need the intervention/rectification part now, please use the previous version of this code.

Figure: DCBM pipeline. (A) DCBM for prediction and interpretation. (B) DCBM for human-machine interaction, including forward intervention and backward rectification.

Citation

If you find our paper or code useful in your research, please consider citing our work:

@article{zhang2024decoupling,
  author={Zhang, Rui and Du, Xingbo and Yan, Junchi and Zhang, Shihua},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 
  title={The Decoupling Concept Bottleneck Model}, 
  year={2024},
  pages={1-16},
  doi={10.1109/TPAMI.2024.3489597}}

To-do list

  • concept/label learning
  • a description of data preprocessing
  • forward intervention and backward rectification
  • VLM-based DCBM

Prerequisites

  • Install the dependencies with pip install -r requirements.txt.
  • The code was tested with torch 2.4.0+cu118 and pytorch-lightning 2.3.3. If your CUDA setup differs, install the torch and pytorch-lightning versions that match your system.
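To compare your environment against the tested versions, a small script using the standard-library importlib.metadata module is enough. This is a convenience sketch and is not part of the repository. The TESTED dictionary only restates the versions listed above, and check_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported as tested in this README (restated, not read from the repo).
TESTED = {"torch": "2.4.0+cu118", "pytorch-lightning": "2.3.3"}


def check_versions(tested=TESTED):
    """Return {package: (installed_version_or_None, tested_version)}."""
    report = {}
    for pkg, want in tested.items():
        try:
            have = version(pkg)  # raises if the distribution is not installed
        except PackageNotFoundError:
            have = None
        report[pkg] = (have, want)
    return report


if __name__ == "__main__":
    for pkg, (have, want) in check_versions().items():
        status = "matches" if have == want else "differs"
        print(f"{pkg}: installed={have}, tested={want} ({status})")
```

A mismatch does not necessarily mean the code will fail. It only marks where your setup departs from the tested one.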

Datasets

This repository uses three datasets: CUB, Derm7pt, and CelebA. They are available from the following sources:

CUB

Original Dataset, Processed Dataset.

Derm7pt

Original Dataset.

CelebA

Original Dataset.

Usage

Using CUB as an example, the repository currently provides the following commands:

Concept and Label Prediction

Train

python main.py -d CUB -seed 0

Inference

python main.py -d CUB_test -seed 0
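The training and inference commands differ only in the value passed to -d. Judging from the configs directory in the layout listed under Directory, each -d value appears to name a YAML file under configs/ (e.g. CUB_test would select configs/CUB_test.yaml). This mapping is an assumption and is not confirmed by the source. The sketch below shows how such flags could be parsed. parse_args and config_path are hypothetical helpers, not functions taken from main.py.

```python
import argparse
from pathlib import Path


def parse_args(argv=None):
    # Hypothetical reconstruction of the two flags shown above; the real
    # main.py may define more options or different defaults.
    parser = argparse.ArgumentParser(description="DCBM command-line sketch")
    parser.add_argument("-d", dest="dataset", required=True,
                        help="config name, e.g. CUB or CUB_test")
    parser.add_argument("-seed", type=int, default=0,
                        help="random seed")
    return parser.parse_args(argv)


def config_path(dataset, config_dir="configs"):
    # Assumption: "-d X" selects configs/X.yaml, matching the directory layout.
    return Path(config_dir) / f"{dataset}.yaml"


if __name__ == "__main__":
    args = parse_args()
    print(f"config={config_path(args.dataset)}, seed={args.seed}")
```

argparse accepts a multi-character single-dash option such as -seed, so the command lines above would parse unchanged under this sketch.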

To be continued...

Directory

|-- README.md
|-- main.py
|-- requirements.txt
|-- configs
    |-- CUB.yaml
    |-- CUB_test.yaml
    |-- Derm7pt.yaml
    |-- Derm7pt_test.yaml
    |-- celeba.yaml
    |-- celeba_test.yaml
|-- data
    |-- __init__.py
    |-- CUB.py
    |-- Derm7pt.py
    |-- celeba.py
    |-- data_interface.py
    |-- data_utils.py
|-- images
    |-- DCBM_pipeline.png
|-- models
    |-- __init__.py
    |-- dcbm.py
    |-- model_interface.py
    |-- template_model.py
|-- saves
    |-- celeba_imbalance.pth
|-- utils
    |-- __init__.py
    |-- analysis.py
    |-- base_utils.py
    |-- config.py