Hossein Jafarinia, Alireza Alipanah, Danial Hamdi, Saeed Razavi, Nahal Mirzaie, Mohammad Hossein Rohban
[arXiv] [Project Page] [Demo] [BibTex]
PyTorch implementation of the Multiple Instance Learning framework described in the paper Snuffy: Efficient Whole Slide Image Classifier (accepted at ECCV 2024).
Snuffy is a novel MIL-pooling method based on sparse transformers, designed to address the computational challenges in Whole Slide Image (WSI) classification for digital pathology. Our approach mitigates performance loss with limited pre-training and enables continual few-shot pre-training as a competitive option.
Key features:
- Tailored sparsity pattern for pathology
- Theoretically proven universal approximator with tight probabilistic sharp bounds
- Superior WSI and patch-level accuracies on CAMELYON16 and TCGA Lung cancer datasets
This repository provides a complete, runnable implementation of the Snuffy framework, including code for the FROC metric, which is unique among WSI classification frameworks to the best of our knowledge.
- Slide Patching: WSIs are divided into manageable patches.
- Self-Supervised Learning: An SSL method is trained on the patches to create an embedder.
- Feature Extraction: The embedder computes features (embeddings) for each slide.
- MIL Training: The Snuffy MIL framework is applied to the computed features.
Each step in this pipeline can be executed independently, with intermediate results available for download to facilitate continued processing.
- Operating System: Ubuntu 20.04 LTS (or compatible Linux distribution)
- Python Version: 3.8 or later
- GPU: Recommended for faster processing (CUDA-compatible)
- Disk Space: Ensure you have sufficient disk space for dataset downloads and processing, especially if you intend to work with raw slides rather than pre-computed embeddings. Raw slide data can be very large.
- Hardware: The MIL training code can run on both GPU and CPU. For optimal performance, a GPU is strongly recommended.
- AWS CLI: To download the CAMELYON16 dataset's raw whole-slide images, you'll need the AWS CLI. Install it with:
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
unzip awscliv2.zip
./aws/install
- GDC Client (for downloading the TCGA dataset): This is automatically downloaded and installed when you use the `download_tcga_lung.sh` script.
- OpenSlide: Necessary if you intend to patch the slides yourself using the `deepzoom_tiler_camelyon16.py` or `deepzoom_tiler_tcga_lung_cancer.py` scripts. Install OpenSlide with:
# Update package list and install OpenSlide
apt-get update
apt-get install openslide-tools
- The ASAP package is required for calculating the FROC metric. Install ASAP and its `multiresolutionimageinterface` Python package as follows:
# Download and install ASAP
wget https://github.com/computationalpathologygroup/ASAP/releases/download/ASAP-2.1/ASAP-2.1-py38-Ubuntu2004.deb
apt-get install -f "./ASAP-2.1-py38-Ubuntu2004.deb"
- Required Python packages can be installed with:
# Install Python packages from requirements.txt
pip install -r requirements.txt
Note: The `requirements.txt` file includes the specific package versions used and verified in our experiments. However, newer versions available in your environment may also be compatible.
- MAE with Adapter: Refer to the MAE repository for installation instructions. Important: If you are using PyTorch 1.8+, follow the instructions in the MAE repository to fix a compatibility issue with the `timm` module, or run the following script to apply the fix:

chmod +x requirements_timm_patch.sh
./requirements_timm_patch.sh

Note that we have also included a modified version of `timm` to support the adapter functionality.
- List and Download Dataset: Run the following commands to list and download the CAMELYON16 dataset:

aws s3 ls --no-sign-request s3://camelyon-dataset/CAMELYON16/ --recursive
aws s3 cp --no-sign-request s3://camelyon-dataset/CAMELYON16/ raw_data/camelyon16 --recursive
- Directory Structure: After downloading, your `raw_data/camelyon16` directory should look like this:

camelyon16
|-- README.md
|-- annotations
|-- background_tissue
|-- checksums.md5
|-- evaluation
|-- images
|-- license.txt
|-- masks
`-- pathology-tissue-background-segmentation.json
- Organize Files: Use the provided script to copy the necessary files into the `datasets/camelyon16` directory. If space is limited, modify the script to move the files instead of copying them.

python move_camelyon16_tifs.py
- Final Directory Structure:

datasets/camelyon16
|-- annotations
|   |-- test_001.xml
|   |-- tumor_001.xml
|   |-- ...
|-- masks
|   |-- normal_001_mask.tif
|   |-- test_001_mask.tif
|   |-- tumor_001_mask.tif
|   |-- ...
|-- 0_normal
|   |-- normal_004.tif
|   |-- test_018.tif
|   |-- ...
|-- 1_tumor
|   |-- test_046.tif
|   |-- tumor_075.tif
|   |-- ...
|-- reference.csv
|-- n_shot_dataset_maker.py
|-- train_validation_test_reverse_camelyon.py
`-- train_validation_test_splitter_camelyon.py
To download the TCGA Lung Cancer dataset, run the following script. It downloads the slides listed in the LUAD and LUSC manifests to the `datasets/tcga/{luad, lusc}` directory. Each slide is stored in its own directory, named according to its ID in the manifest.
chmod +x download_tcga_lung.sh
./download_tcga_lung.sh
Download the MIL datasets (sourced from the DSMIL project) and unzip them into the datasets/ directory.
wget https://uwmadison.box.com/shared/static/arvv7f1k8c2m8e2hugqltxgt9zbbpbh2.zip -O mil-dataset.zip
unzip mil-dataset.zip -d datasets/
This script processes TIFF slides located in `datasets/camelyon16/{0_normal, 1_tumor}/`. For each slide, it creates a directory at `datasets/camelyon16/single/{0_normal, 1_tumor}/{slide_name}` and saves the extracted patches there as JPEG images.
python deepzoom_tiler_camelyon16.py
This script processes SVS slides in `datasets/tcga/{lusc, luad}/` and saves the extracted patches in `datasets/tcga/single/{lusc, luad}/{slide_name}` as JPEG images.
python deepzoom_tiler_tcga_lung_cancer.py
Both scripts accept additional arguments controlling the patching process; refer to the argument definitions in each script for details. A sketch of the underlying patching approach follows.
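For reference, the patching idea can be sketched with OpenSlide's Deep Zoom interface. This is illustrative only: the tile size, level choice, and output naming below are assumptions, not the tilers' actual defaults.

```python
# Illustrative patching sketch using OpenSlide's Deep Zoom interface.
# Tile size, level, and file naming here are assumptions for this example.
from openslide import OpenSlide
from openslide.deepzoom import DeepZoomGenerator

slide = OpenSlide("datasets/camelyon16/1_tumor/tumor_001.tif")
dz = DeepZoomGenerator(slide, tile_size=224, overlap=0, limit_bounds=True)

level = dz.level_count - 1          # highest-resolution Deep Zoom level
cols, rows = dz.level_tiles[level]  # tile grid size at this level
for row in range(rows):
    for col in range(cols):
        tile = dz.get_tile(level, (col, row))  # PIL.Image patch
        tile.convert("RGB").save(f"tile_{col}_{row}.jpeg")
```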
To split the CAMELYON16 dataset:
cd datasets/camelyon16
python train_validation_test_splitter_camelyon.py
This script reorganizes the directory structure from:
datasets/camelyon16/single/{0_normal, 1_tumor}
to:
datasets/camelyon16/single/fold1/{train, validation, test}/{0_normal, 1_tumor}
The official CAMELYON16 test set is used for testing, while the remaining data is randomly split into training and validation sets with an 80/20 ratio. You can adjust the fold number directly in the script.
To reverse the CAMELYON16 split:
cd datasets/camelyon16
python train_validation_test_reverse_camelyon.py
The processed and shuffled datasets are saved with filenames that reflect the dataset name, fold count, and split ratio.
The `fold_generator.py` script creates K-Fold cross-validation splits for the TCGA data, ensuring that a single patient's slides are not divided across multiple splits. It uses the `patients.csv` reference file and stores the fold information in `datasets/tcga/folds/fold_{i}.csv`. A minimal sketch of this patient-level grouping follows the commands below.
To run the K-Fold split:
cd datasets/tcga
python fold_generator.py
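The following is a minimal sketch of the patient-level grouping idea, not `fold_generator.py` itself; the `patient_id`/`slide_id` column names and the fold count are assumptions.

```python
# Sketch: assign whole patients to folds so that one patient's slides
# never end up in different folds. Column names are assumptions.
import csv
import random
from collections import defaultdict

NUM_FOLDS = 5  # assumption; adjust to match fold_generator.py

# Group slide IDs by patient using the patients.csv reference file.
slides_by_patient = defaultdict(list)
with open("patients.csv") as f:
    for row in csv.DictReader(f):
        slides_by_patient[row["patient_id"]].append(row["slide_id"])

# Shuffle patients (not slides) and deal them round-robin into folds.
patients = list(slides_by_patient)
random.shuffle(patients)
for i in range(NUM_FOLDS):
    fold_patients = patients[i::NUM_FOLDS]
    with open(f"fold_{i}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["patient_id", "slide_id"])
        for p in fold_patients:
            for s in slides_by_patient[p]:
                writer.writerow([p, s])
```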
After generating folds, use the `train_validation_test_splitter_tcga.py` script to organize the directories according to a selected fold:
python train_validation_test_splitter_tcga.py
This script reorganizes the directory structure from:
datasets/tcga/single/{0_luad, 1_lusc}
to:
datasets/tcga/single/fold{i}/{train, validation, test}/{0_luad, 1_lusc}
To reverse the TCGA split and restore the original directory structure:
cd datasets/tcga
python train_validation_test_reverse_tcga.py
The `mil_cross_validation.py` script loads and processes the MIL datasets downloaded in the previous step (Musk1, Musk2, Elephant) into a format compatible with Snuffy. It then performs cross-validation, ensuring that each fold contains both negative and positive bags; a sketch of this stratification idea follows the commands below.
cd datasets/mil_dataset
# python mil_cross_validation.py --dataset [Musk1, Musk2, Elephant] --num_folds [10] --train_valid_ratio [0.2]
python mil_cross_validation.py --dataset Musk1
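The stratification idea can be sketched as follows. This is an illustrative reimplementation, not the repo's exact logic, and it assumes each class has at least `num_folds` bags.

```python
# Sketch of a stratified bag-level split: dealing positive and negative
# bags into folds separately guarantees every fold contains both classes
# (provided each class has at least num_folds bags).
import random

def make_folds(bag_labels, num_folds=10, seed=0):
    """bag_labels: one 0/1 label per bag; returns a list of index folds."""
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(bag_labels) if y == 1]
    neg = [i for i, y in enumerate(bag_labels) if y == 0]
    rng.shuffle(pos)
    rng.shuffle(neg)
    folds = [[] for _ in range(num_folds)]
    for k, i in enumerate(pos):   # deal positives round-robin
        folds[k % num_folds].append(i)
    for k, i in enumerate(neg):   # then negatives round-robin
        folds[k % num_folds].append(i)
    return folds

print(make_folds([1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0], num_folds=3))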
To create a 50-shot patch dataset (a dataset containing at most n patches per WSI):
cd datasets/camelyon16
python n_shot_dataset_maker.py --shots=50
This will create a new folder named `single/fold1_50shot` based on the dataset in `single/fold1`. In this new folder, each slide will have at most 50 patches (or all of its patches if the original number is fewer than 50). A sketch of the sampling logic follows the TCGA command below.
For TCGA:
cd datasets/tcga
python n_shot_dataset_maker_tcga.py --shots 5
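The sampling logic amounts to the following sketch (directory layout per the splitter above; the `.jpeg` extension and copy-vs-move behavior are assumptions, and the actual `n_shot_dataset_maker.py` may differ):

```python
# Sketch: for each slide directory, keep a random sample of at most
# `shots` patch images. Extension and copy behavior are assumptions.
import random
import shutil
from pathlib import Path

def make_n_shot(src_fold: Path, dst_fold: Path, shots: int = 50, seed: int = 0):
    for slide_dir in src_fold.glob("*/*/*"):  # {split}/{class}/{slide}
        if not slide_dir.is_dir():
            continue
        patches = sorted(slide_dir.glob("*.jpeg"))
        keep = patches if len(patches) <= shots else random.Random(seed).sample(patches, shots)
        out_dir = dst_fold / slide_dir.relative_to(src_fold)
        out_dir.mkdir(parents=True, exist_ok=True)
        for p in keep:
            shutil.copy(p, out_dir / p.name)

make_n_shot(Path("single/fold1"), Path("single/fold1_50shot"), shots=50)
```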
| Method | Instructions | Embedder Weights | Embeddings |
|---|---|---|---|
| SimCLR (From Scratch) | Refer to DSMIL | Weights | Embeddings |
| DINO (From Scratch) | Refer to DINO (and use a ViT-S/16) | Weights | Embeddings |
| DINO (with Adapter) | Refer to the DINO with Adapter section | Weights | Embeddings |
| MAE (with Adapter) | Refer to the MAE with Adapter section | Weights | Embeddings |
Download the DINO ImageNet-1K pretrained ViT-S/8 full weights:
wget https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_full_checkpoint.pth
Continue pretraining with DINO Adapter:
python dino_adapter/main_dino_adapter.py \
--adapter_ffn_scalar=10 \
--arch=vit_small \
--batch_size_per_gpu=16 \
--clip_grad=3 \
--data_path_train=datasets/camelyon16/single/fold1_50shot/train \
--data_path_valid=datasets/camelyon16/single/fold1_50shot/validation \
--epochs=100 \
--ffn_num=32 \
--freeze_last_layer=0 \
--full_checkpoint=dino_deitsmall8_pretrain_full_checkpoint.pth \
--lr__warmup_epochs__minlr="[0.0005, 10, 1e-06]" \
--momentum_teacher=0.9995 \
--norm_last_layer=False \
--output_dir=out \
--patch_size=8 \
--random_head=1 \
--teacher_temp__warmup_teacher_temp_epochs="[0.04, 0]" \
--warmup_teacher_temp=0.04 \
--weight_decay__weight_decay_end="[0.04, 0.4]"
Download the MAE ImageNet-1K pretrained ViT-Base full weights:
wget https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base_full.pth
Continue pretraining with MAE Adapter:
torchrun main_pretrain_adapter.py \
--accum_iter=1 \
--adapter_ffn_scalar=1 \
--blr__min_lr__warmup_epochs="[0.001, 0, 40]" \
--data_path=datasets/camelyon16/single/fold1_200shot \
--epochs=400 \
--full_checkpoint=mae_pretrain_vit_base_full.pth \
--norm_pix_loss=0 \
--train_linears__linears_from_scratch="[1, 1]"
The `compute_feats.py` script extracts features (embeddings) from a dataset using a specified embedder model. It processes the dataset and saves the cleaned embedder weights, the feature vectors, and the corresponding labels.
The dataset is expected to follow this directory structure:
datasets/
└── {dataset_name}/
├── single/
│ └── {fold}/
│ ├── train/
│ ├── validation/
│ └── test/
└── tile_label.csv
- `{dataset_name}`: The name of your dataset.
- `{fold}`: The specific fold of data (e.g., fold1, fold2, ...).
- `train/`, `validation/`, `test/`: Directories containing the patches for training, validation, and testing, respectively.
- `tile_label.csv`: CSV file containing the labels for the patches (if available), created by `deepzoom_tiler`.
The script saves the outputs in the following directory structure:
embeddings/
└── {embedder}_{version_name}/
    └── {dataset_name}/
        ├── embedder.pth
        ├── {train, test, validation}/
        │   ├── {0_normal, 1_tumor}.csv
        │   └── {0_normal, 1_tumor}/
        │       └── {slide_name}.csv
        └── {dataset_name}.csv
- `{embedder}`: The name of the embedder model used (e.g., SimCLR).
- `{version_name}`: The version name of the embedder model.
- `{dataset_name}`: The name of the dataset.
- `embedder.pth`: The cleaned embedder weights.
- `{slide_name}.csv`: CSV file containing features `[feature_0, ..., feature_511, position, label]` for each slide. Each row corresponds to a patch from the slide.
- `{split}/{class_name}.csv`: CSV file containing `[bag_path, bag_label]` for each class in each split (train/validation/test).
- `{dataset_name}.csv`: CSV file containing `[bag_path, bag_label]` for the whole dataset.
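As an example of consuming this layout, a per-slide CSV can be loaded into a bag of features like this (the path below is hypothetical, and the feature count depends on your embedder's output size):

```python
# Load one bag (all patches of one slide) from a per-slide embedding CSV.
# Column names follow the [feature_0, ..., feature_511, position, label]
# description above; adjust to your embedder.
import pandas as pd

df = pd.read_csv("embeddings/SimCLR_dsmil_simclr/camelyon16/train/0_normal/normal_001.csv")
feature_cols = [c for c in df.columns if c.startswith("feature_")]
bag = df[feature_cols].to_numpy()   # shape: (num_patches, feats_size)
print(bag.shape, df["label"].unique())
```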
SimCLR:
python compute_feats.py \
--backbone=resnet18 \
--norm_layer=instance \
--weights=embedders/dsmil_simclr.pth \
--embedder=SimCLR \
--version_name=dsmil_simclr
DINO (from scratch):
python compute_feats.py \
--embedder=DINO \
--num_classes=2048 \
--backbone=vit_small \
--weights=embedders/dino_scratch.pth \
--version_name=dino_scratch
DINO (with Adapter):
python compute_feats.py \
--embedder=DINO \
--num_classes=2048 \
--backbone=vit_small \
--patch_size=8 \
--weights=embedders/dino_adapter.pth \
--ffn_num=32 \
--adapter_ffn_scalar=10 \
--version_name=dino_adapter \
--use_adapter \
--transform 1
MAE (with Adapter):
python compute_feats.py \
--embedder=MAE \
--num_classes=512 \
--backbone=mae_vit_base_patch16 \
--weights=embedders/mae_adapter.pth \
--ffn_num=64 \
--adapter_ffn_scalar=1 \
--version_name=mae_adapter \
--use_adapter \
--transform 1
SimCLR on TCGA:
python compute_feats.py \
--backbone=resnet18 \
--dataset=tcga \
--norm_layer=instance \
--weights=embedders/dsmil_simclr_tcga.pth \
--embedder=SimCLR \
--version_name=dsmil_simclr
Training with DINO (from scratch) embeddings:
python train.py \
--activation=relu \
--arch=snuffy \
--betas="[0.9, 0.999]" \
--big_lambda=900 \
--dataset=camelyon16 \
--embedding=DINO_dino_scratch \
--encoder_dropout=0.1 \
--feats_size=384 \
--l2normed_embeddings=1 \
--lr=0.02 \
--num_epochs=200 \
--num_heads=4 \
--optimizer=adamw \
--random_patch_share=0.7777777777777778 \
--scheduler=cosine \
--single_weight__lr_multiplier=1 \
--soft_average=0 \
--weight_decay=0.05 \
--weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"
Training with DINO (with Adapter) embeddings:
python train.py \
--activation=relu \
--arch=snuffy \
--betas="[0.9, 0.999]" \
--big_lambda=500 \
--dataset=camelyon16 \
--embedding=DINO_dino_adapter \
--encoder_dropout=0.1 \
--feats_size=384 \
--l2normed_embeddings=1 \
--lr=0.02 \
--num_epochs=200 \
--num_heads=4 \
--optimizer=adamw \
--random_patch_share=0.5 \
--scheduler=cosine \
--single_weight__lr_multiplier=1 \
--soft_average=1 \
--weight_decay=0.05 \
--weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"
Training with MAE (with Adapter) embeddings:
python train.py \
--activation=relu \
--arch=snuffy \
--betas="[0.9, 0.999]" \
--big_lambda=500 \
--dataset=camelyon16 \
--embedding=MAE_mae_adapter \
--encoder_dropout=0 \
--feats_size=768 \
--l2normed_embeddings=0 \
--lr=0.02 \
--num_epochs=200 \
--num_heads=4 \
--optimizer=adamw \
--random_patch_share=0.5 \
--scheduler=cosine \
--single_weight__lr_multiplier=1 \
--soft_average=1 \
--weight_decay=0.05 \
--weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"
`--feats_size` must match the size of the features produced in Feature Extraction. `--random_patch_share * --big_lambda` gives the number of randomly sampled patches; the rest of the `--big_lambda` patch budget is filled with top patches (see the worked example below).
For TCGA, use `--arch=snuffy_multiclass`.
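A quick worked example of the patch budget implied by the flags above:

```python
# Patch budget arithmetic for the DINO-from-scratch run above.
big_lambda = 900                          # total patches kept per bag
random_patch_share = 0.7777777777777778   # = 7/9
num_random = round(big_lambda * random_patch_share)  # 700 random patches
num_top = big_lambda - num_random                    # 200 top patches
print(num_random, num_top)  # 700 200
```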
python train.py \
--arch=snuffy \
--dataset=musk1 \
--num_heads=2 \
--cv_num_folds 10 \
--cv_valid_ratio 0.2 \
--cv_current_fold 1
- Feature Size: Automatically set based on the dataset ('musk1' and 'musk2': 166, 'elephant': 230); no manual adjustment is needed.
- MultiHeadAttention: Ensure the feature size is divisible by the number of heads (see the check below).
- Cross-Validation: Use `mil_cross_validation.py` to generate a shuffle file (`{dataset_file_name}_{num_folds}folds_{valid_ratio}split.pkl`, e.g. `musk1_10folds_0.2split.pkl`). Match `args.cv_num_folds` and `args.cv_valid_ratio` in this script to read the file correctly, and set the fold to train with `args.cv_current_fold`.
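A minimal sanity check for the divisibility constraint (values shown are the Musk1 settings from the command above):

```python
# feats_size must split evenly across attention heads.
feats_size, num_heads = 166, 2
assert feats_size % num_heads == 0, "feats_size must be divisible by num_heads"
print(f"head dimension: {feats_size // num_heads}")
```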
In the figure below, the black line outlines the tumor area. The model's attention is represented by a color overlay, where red indicates the highest attention and blue indicates the lowest. As shown, the model effectively highlights the tumor regions.
To create heatmaps similar to the one shown above, run the following command:
python roi.py \
--batch_size=512 \
--num_workers=24 \
--embedder_weights=embedders/clean/camelyon16/SimCLR/embedder.pth \
--aggregator_weights=aggregators/snuffy_simclr_dsmil.pth \
--thres_tumor=0.75959325 \
--num_heads=2 \
--encoder_dropout=0.2 \
--k=900 \
--random_patch_share=0.7777777777777778 \
--activation=gelu \
--depth=5
The script requires the following inputs:
- `--embedder_weights`: Path to the embedder weights file.
- `--aggregator_weights`: Path to the aggregator weights file.
- Ground truth masks located in `datasets/camelyon16/masks/`.
- Raw TIFF slides located in `datasets/camelyon16/1_tumor/`.
- Names and labels of slides located in `datasets/camelyon16/reference.csv`.
For each slide, the script generates the following outputs:
- Heatmaps saved in `roi_output/{slide_name}/cmaps/`, where:
  - `jet_slide.png` is the raw slide.
  - `jet.png` is the slide with the attention map overlay and the ground truth tumor region outlined in black.
By default, the script processes 3 slides from the CAMELYON16 test set, but you can customize the slides to process by modifying the script. Additionally, reducing the DPI setting can speed up processing.
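For illustration, the overlay idea itself can be sketched in a few lines of matplotlib. This is not the repo's `roi.py`; the thumbnail, attention grid, and upsampling factor below are stand-ins.

```python
# Toy sketch of the overlay: paint per-patch attention onto a slide-sized
# grid and blend it over a thumbnail with the 'jet' colormap
# (red = high attention, blue = low).
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical inputs: a slide thumbnail and one attention score per patch.
thumbnail = np.random.rand(224, 224, 3)  # stand-in for the slide image
attention = np.random.rand(14, 14)       # stand-in for patch scores

# Upsample the 14x14 patch grid to 224x224 (nearest-neighbor).
heat = np.kron(attention, np.ones((16, 16)))

plt.imshow(thumbnail)
plt.imshow(heat, cmap="jet", alpha=0.4)  # attention overlay
plt.axis("off")
plt.savefig("jet.png", bbox_inches="tight")
```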
You can download the aggregator used for creating the figure above from here.
This codebase is built upon the work of DSMIL, DINO, and MAE. We extend our gratitude to the authors for their valuable contributions.
If you find our work helpful for your research, please consider giving a star to this repository and citing the following BibTeX entry.
@misc{jafarinia2024snuffyefficientslideimage,
title={Snuffy: Efficient Whole Slide Image Classifier},
author={Hossein Jafarinia and Alireza Alipanah and Danial Hamdi and Saeed Razavi and Nahal Mirzaie and Mohammad Hossein Rohban},
year={2024},
eprint={2408.08258},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2408.08258},
}