Official implementation of "Knowledge Enhanced Conditional Imputation for Healthcare Time-series"
CSAI (Conditional Self-Attention Imputation) is a novel recurrent neural network architecture designed to handle complex missing data patterns in multivariate time series from electronic health records (EHRs).
Key features:
- Domain-informed temporal decay mechanism adapted to clinical data recording patterns
- Attention-based hidden state initialization for capturing long and short-range dependencies
- Non-uniform masking strategy that reflects real-world missingness patterns
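To make the temporal decay idea concrete, here is a minimal NumPy sketch of the generic decay used by GRU-D and BRITS style models: the time since a feature was last observed drives an exponential decay factor. This is not CSAI's domain-informed variant, which the paper describes. The function names, the unit time step, and the scalar `w` and `b` are illustrative assumptions, not part of this repository.

```python
import numpy as np


def time_gaps(mask: np.ndarray) -> np.ndarray:
    """Time since the last observation for each feature.

    mask has shape (T, D), with 1 = observed and 0 = missing.
    Assumes unit spacing between consecutive time steps (an assumption
    made for this sketch; real EHR data is irregularly sampled).
    """
    delta = np.zeros(mask.shape, dtype=float)
    for t in range(1, mask.shape[0]):
        # The gap resets to 1 after an observed step and keeps
        # accumulating while the previous step was missing.
        delta[t] = 1.0 + (1.0 - mask[t - 1]) * delta[t - 1]
    return delta


def decay_factor(delta: np.ndarray, w: float, b: float) -> np.ndarray:
    """Generic exponential decay: gamma = exp(-max(0, w * delta + b)).

    In a trained model w and b are learned parameters; here they are
    plain scalars chosen only for illustration.
    """
    return np.exp(-np.maximum(0.0, w * delta + b))


# One feature observed at steps 0 and 3, missing at steps 1 and 2.
mask = np.array([[1], [0], [0], [1]])
delta = time_gaps(mask)
gamma = decay_factor(delta, w=0.5, b=0.0)
```

A larger gap yields a smaller `gamma`, so older observations contribute less to the hidden state or to the imputed value.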
# Clone the repository
git clone https://github.com/LinglongQian/CSAI.git
cd CSAI
# Create conda environment
conda env create -f csai.yml
# Activate environment
conda activate csai
Or, using venv and pip:
# Clone the repository
git clone https://github.com/LinglongQian/CSAI.git
cd CSAI
# Create and activate a virtual environment
python -m venv csai
source csai/bin/activate
# Install dependencies
pip install -r requirements.txt
The implementation supports three healthcare benchmark datasets:
- PhysioNet Challenge 2012
- MIMIC-III
  - Requires credentialed access
- eICU
  - Available after registration
Place the downloaded data in the data/ directory, following this structure:
data/
├── physionet_raw/
├── mimic_59f_raw/
└── eicu_raw/
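Before preprocessing, it can help to confirm that the layout above is in place. The following is a small sketch using only the standard library. The helper `missing_raw_dirs` is hypothetical and is not shipped with this repository.

```python
from pathlib import Path

# Directory names taken from the layout shown above.
EXPECTED_DIRS = ("physionet_raw", "mimic_59f_raw", "eicu_raw")


def missing_raw_dirs(data_root):
    """Return the expected raw-data directories that do not exist yet."""
    root = Path(data_root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]


# Lists whichever datasets have not been downloaded into ./data yet.
print(missing_raw_dirs("./data"))
```

Only the datasets you intend to use need to be present, so a non-empty result is not necessarily an error.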
python data_process.py \
--data_dir ./data/physionet_raw \
--output_dir ./data/physionet \
--n_splits 5 \
--seed 3407
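The `--n_splits` and `--seed` flags suggest a seeded k-fold partition of the samples. The sketch below shows how such a split is commonly built with NumPy. It is an assumption about the general technique, not a copy of the logic in `data_process.py`, and the function name `kfold_indices` is hypothetical.

```python
import numpy as np


def kfold_indices(n_samples, n_splits=5, seed=3407):
    """Return a list of (train_idx, test_idx) pairs, one per fold.

    The sample order is shuffled once with a seeded generator, so the
    same seed always reproduces the same folds.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    folds = np.array_split(order, n_splits)  # fold sizes differ by at most 1
    splits = []
    for k in range(n_splits):
        test_idx = folds[k]
        train_idx = np.concatenate(
            [folds[j] for j in range(n_splits) if j != k]
        )
        splits.append((train_idx, test_idx))
    return splits


splits = kfold_indices(n_samples=103)
```

Each sample lands in exactly one test fold, and rerunning with the same seed gives identical folds, which is what makes results comparable across runs.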
# --task: I for imputation, C for classification
python main.py \
--dataset physionet \
--model_name CSAI \
--task I \
--gpu_id 0 \
--epoch 300 \
--lr 0.0005 \
--batchsize 64
Hardware:
--gpu_id GPU device ID
--seed Random seed for reproducibility
Model:
--model_name Model architecture [CSAI, Brits, GRUD, BVRIN, MRNN]
--hiddens Hidden layer size
--channels Number of channels
--step_channels Step channels for transformer
Training:
--task Task type [I: Imputation, C: Classification]
--epoch Number of training epochs
--lr Learning rate
--batchsize Batch size
--weight_decay Weight decay factor
Loss Weights:
--imputation_weight Weight for imputation loss
--classification_weight Weight for classification loss (task C)
--consistency_weight Weight for consistency loss
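The three loss weights presumably scale the individual loss terms before they are summed. The sketch below shows one way this could work, assuming a plain weighted sum in which the classification term is used only for task C. The function `combine_losses` and all numeric values are illustrative and are not taken from the code.

```python
def combine_losses(losses, weights, task="I"):
    """Weighted sum of loss terms (assumed form, for illustration only).

    losses and weights are dicts keyed by "imputation", "consistency"
    and, for task "C", "classification".
    """
    total = (
        weights["imputation"] * losses["imputation"]
        + weights["consistency"] * losses["consistency"]
    )
    if task == "C":
        # The classification term only applies to the classification task.
        total += weights["classification"] * losses["classification"]
    return total


# Hypothetical values, chosen only to make the arithmetic easy to follow.
weights = {"imputation": 0.5, "consistency": 0.25, "classification": 2.0}
losses = {"imputation": 2.0, "consistency": 4.0, "classification": 1.0}

imputation_total = combine_losses(losses, weights, task="I")
classification_total = combine_losses(losses, weights, task="C")
```

Raising one weight relative to the others shifts the optimization toward that objective, so the weights are worth tuning per dataset.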
Process and analyze training results:
python result_process.py \
--log_dir ./log \
--dataset physionet \
--key_pattern bets_valid \
--results_dir results
--log_dir Root directory containing experiment logs
--dataset Dataset name [physionet, mimic_59f, eicu]
--key_pattern Pattern to match in result keys
--results_dir Directory to save processed results
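As a rough illustration of what such post-processing involves, the sketch below collects per-fold metrics whose keys contain a pattern and reports the mean and standard deviation across folds. The log format, the key names, and the substring matching rule are assumptions made for this example. They do not describe the actual behavior of `result_process.py`.

```python
import statistics


def summarize(fold_results, key_pattern):
    """Aggregate matching metrics across folds.

    fold_results is a list with one dict per fold, mapping a result key
    to a float. Returns {key: (mean, sample standard deviation)}.
    """
    collected = {}
    for result in fold_results:
        for key, value in result.items():
            if key_pattern in key:  # simple substring match (assumed)
                collected.setdefault(key, []).append(value)
    return {
        key: (
            statistics.mean(values),
            statistics.stdev(values) if len(values) > 1 else 0.0,
        )
        for key, values in collected.items()
    }


# Hypothetical per-fold results; the key names are made up for this sketch.
fold_results = [
    {"valid_mae": 1.0, "test_mae": 2.0},
    {"valid_mae": 3.0, "test_mae": 4.0},
]
summary = summarize(fold_results, key_pattern="valid")
```

Here `summary` contains only the `valid_mae` entry, because the pattern filters out keys that do not match.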
@article{qian2023knowledge,
title={Knowledge Enhanced Conditional Imputation for Healthcare Time-series},
author={Qian, Linglong and Raj, Joseph Arul and Ellis, Hugh Logan and Zhang, Ao and Zhang, Yuezhou and Wang, Tao and Dobson, Richard JB and Ibrahim, Zina},
journal={arXiv preprint arXiv:2312.16713},
year={2023}
}
This project is licensed under the MIT License - see the LICENSE file for details.
This research was supported by [funding details from paper]. The codebase builds upon several excellent repositories: