Merge pull request #261 from mims-harvard/neurips_benchmarks
amva13 authored May 14, 2024
2 parents e115138 + c281642 commit 4afd866
Showing 3 changed files with 249 additions and 1 deletion.
128 changes: 128 additions & 0 deletions tdc/benchmark_group/counterfactual_group.py
@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT
import os

from .base_group import BenchmarkGroup
from ..dataset_configs.config_map import scperturb_datasets, scperturb_gene_datasets


class CounterfactualGroup(BenchmarkGroup):
"""Create Counterfactual Group Class object. This is for single-cell counterfactual prediction tasks (drug, gene) benchmark.
Args:
path (str, optional): the path to store/retrieve the Counterfactual group datasets.
"""
_DRUG_COLS = [
"ncounts", 'celltype', 'cell_line', 'cancer', 'disease', 'tissue_type',
'perturbation', 'perturbation_type', 'ngenes'
]

_GENE_COLS = [
'UMI_count', 'cancer', 'cell_line', 'disease', 'guide_id', 'ncounts',
'ngenes', 'nperts', 'organism', 'percent_mito', 'percent_ribo',
'perturbation', 'perturbation_type', 'tissue_type'
]

def __init__(self, path="./data", file_format="csv", is_drug=True):
"""Create a Counterfactual prediction benchmark group class."""
self.name = "Coutnerfactual_Group"
self.path = os.path.join(path, self.name)
self.is_drug = is_drug
self.dataset_names = scperturb_gene_datasets if not self.is_drug else scperturb_datasets
self.file_format = file_format
self.split = None

def get_train_valid_split(self, dataset=None, only_seen=False):
"""parameters included for compatibility. this benchmark has a fixed train/test split."""
from ..multi_pred.perturboutcome import PerturbOutcome
if only_seen:
raise ValueError(
"Counterfactual does not currently support the 'only seen' split"
)
dataset = dataset or "scperturb_drug_AissaBenevolenskaya2021"
assert dataset in self.dataset_names, "{} dataset not in {}".format(
dataset, self.dataset_names)
data = PerturbOutcome(dataset)
self.split = data.get_split()
return self.split["train"], self.split["dev"]

def get_test(self):
if self.split is None:
self.get_train_valid_split()
return self.split["test"]

def evaluate(self, y_pred):
from sklearn.metrics import r2_score
y_true = self.get_test()
cols_to_drop = self._DRUG_COLS if self.is_drug else self._GENE_COLS
y_true = y_true.drop(cols_to_drop, axis=1)
y_pred = y_pred.drop(cols_to_drop, axis=1)
return r2_score(y_true, y_pred)

    def evaluate_dev(
        self, y_pred
    ):  # TODO: under development; will benchmark using cell-line splits, but benchmarks on the random split for now
        from sklearn.metrics import r2_score
        from numpy import average, std
        assert isinstance(
            y_pred, dict
        ), "evaluate_dev() expects a dictionary with control and perturbation dataframes"
cols_to_drop = self._DRUG_COLS if self.is_drug else self._GENE_COLS
y_true = self.get_test()
# validate input predictions have the same cell lines and perturbations as ground truth
assert len(y_pred["control"]) == len(
y_true["control"]
), "input pred and ground truth defer in control row ct; {} vs {}".format(
len(y_pred["control"]), len(y_true["control"]))
assert y_pred["control"].columns == y_true["control"].columns, \
"Predictions do not match ground truth columns; lengths are:\n{}\n{}".\
format(len(y_pred["control"].columns), len(y_true["control"].columns))
assert len(y_pred["perturbations"]) == len(y_true["perturbations"]), \
"Perturbation lists do not match length; lengths are:\n{},\n{}".\
format(len(y_pred["perturbation"]), len(y_true["perturbation"]))
assert y_pred["perturbations"].columns == y_true["perturbations"].columns, \
"Perturbation columns do not match; lengths are:\n{},\n{}".\
format(len(y_pred["perturbations"].columns), len(y_true["perturbations"].columns))
cell_lines = y_pred["control"]["cell_line"].unique()
assert set(cell_lines) == set(y_true["control"]["cell_line"].unique()), \
"Control lines do not match; lengths are:\n{}\n{}".\
format(len(cell_lines),len(y_true["control"]["cell_line"].unique()))
cell_lines_perturb = y_pred["perturbations"]["cell_line"].unique()
assert set(cell_lines_perturb) == set(y_true["perturbations"]["cell_line"].unique()), \
"Cell lines with perturbations do not match; lengths are:\n{}\n{}".\
format(len(cell_lines_perturb),len(y_true["perturbations"]["cell_line"].unique()))
assert set(cell_lines) == set(cell_lines_perturb), \
"Cell lines do not match; lengths are:\n{}\n{}".\
format(len(cell_lines),len(cell_lines_perturb))
r2vec = []
for line in cell_lines:
perturbations = y_pred["perturbations"][
y_pred["perturbations"]["cell_line"] ==
line]["perturbation"].unique()
for p in perturbations:
                # use element-wise '&'; chaining Series comparisons with 'and' raises a ValueError
                perturbs_pred = y_pred["perturbations"][
                    (y_pred["perturbations"]["cell_line"] == line) &
                    (y_pred["perturbations"]["perturbation"] == p)]
                perturbs_true = y_true["perturbations"][
                    (y_true["perturbations"]["cell_line"] == line) &
                    (y_true["perturbations"]["perturbation"] == p)]
                # avoid in-place drop on a slice (SettingWithCopyWarning)
                perturbs_pred = perturbs_pred.drop(cols_to_drop, axis=1)
                perturbs_true = perturbs_true.drop(cols_to_drop, axis=1)
pred_mean = perturbs_pred.mean()
true_mean = perturbs_true.mean()
r2vec.append(r2_score(true_mean, pred_mean))
return {"mean_r2": average(r2vec), "std_r2": std(r2vec)}

def evaluate_many(self, preds):
from numpy import mean, std
if len(preds) < 5:
raise Exception(
"Run your model on at least 5 seeds to compare results and provide your outputs in preds."
)
        out = dict()
        # evaluate() returns a scalar R^2 per prediction set, so aggregate the scalars across seeds
        scores = [self.evaluate(p) for p in preds]
        out["mean_R^2"] = mean(scores)
        out["std_R^2"] = std(scores)
        return out
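
A minimal usage sketch of the CounterfactualGroup API added above (not part of the diff); train_and_predict is a hypothetical user-supplied model that returns a DataFrame shaped like the test split:

    from tdc.benchmark_group import counterfactual_group

    group = counterfactual_group.CounterfactualGroup()  # drug perturbations by default
    train, dev = group.get_train_valid_split()          # fixed random split of the default dataset
    test = group.get_test()

    # train_and_predict is a placeholder for the user's model; it should return
    # predictions with the same rows and columns as the held-out test DataFrame
    y_pred = train_and_predict(train, dev, test)
    print(group.evaluate(y_pred))

    # evaluate_many aggregates scalar R^2 scores across at least five seeded runs
    seed_preds = [train_and_predict(train, dev, test) for _ in range(5)]
    print(group.evaluate_many(seed_preds))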
89 changes: 89 additions & 0 deletions tdc/multi_pred/perturboutcome.py
@@ -5,6 +5,7 @@
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import sys

from ..utils import print_sys
@@ -24,3 +25,91 @@ def get_DE_genes(self):

def get_dropout_genes(self):
raise ValueError("TODO")

def get_split(self,
ratios=[0.8, 0.1, 0.1],
unseen=False,
use_random=True,
random_state=42):
"""obtain train/dev/test splits for each cell_line
counterfactual prediction model is trained on a single cell line and then evaluated on same cell line
and against new cell lines
TODO: also allow for splitting by unseen perturbations
TODO: allow for evaluating within the same cell line"""
# For now, we will ensure there are no unseen perturbations
if unseen:
raise ValueError(
"Unseen perturbation splits are not yet implemented!")
df = self.get_data()
if use_random:
# just do a random split, otherwise you'll split by cell line...
from sklearn.model_selection import train_test_split
control = df[df["perturbation"] == "control"]
perturbs = df[df["perturbation"] != "control"]
train, tmp = train_test_split(perturbs,
test_size=ratios[1] + ratios[2],
random_state=random_state)
test, dev = train_test_split(tmp,
test_size=ratios[2] /
(ratios[1] + ratios[2]),
random_state=random_state)
return {
"control": control,
"train": train,
"dev": dev,
"test": test
}
cell_lines = df["cell_line"].unique()
perturbations = df["perturbation"].unique()
        # seed the permutation so the cell-line split respects random_state
        rng = np.random.RandomState(random_state)
        shuffled_cell_line_idx = rng.permutation(len(cell_lines))
assert len(shuffled_cell_line_idx) == len(cell_lines)
assert len(shuffled_cell_line_idx) > 3

# Split indices into three parts
        train_end = int(ratios[0] * len(cell_lines))  # by default, 80% of cell lines for training
        dev_end = train_end + int(
            ratios[1] * len(cell_lines))  # next 10% for development

        # map the shuffled indices back to cell line names so the .isin() filters below match
        train_cell_line = cell_lines[shuffled_cell_line_idx[:train_end]]
        dev_cell_line = cell_lines[shuffled_cell_line_idx[train_end:dev_end]]
        test_cell_line = cell_lines[shuffled_cell_line_idx[dev_end:]]

assert len(train_cell_line) > 0
assert len(dev_cell_line) > 0
assert len(test_cell_line) > 0
assert len(test_cell_line) > len(dev_cell_line)

train_control = df[(df["cell_line"].isin(train_cell_line)) &
(df["perturbation"] == "control")]
train_perturbations = df[(df["cell_line"].isin(train_cell_line)) &
(df["perturbation"] != "control")]

assert len(train_control) > 0
assert len(train_perturbations) > 0
assert len(train_control) <= len(train_perturbations)

dev_control = df[(df["cell_line"].isin(dev_cell_line)) &
(df["perturbation"] == "control")]
dev_perturbations = df[(df["cell_line"].isin(dev_cell_line)) &
(df["perturbation"] != "control")]

test_control = df[(df["cell_line"].isin(test_cell_line)) &
(df["perturbation"] == "control")]
test_perturbations = df[(df["cell_line"].isin(test_cell_line)) &
(df["perturbation"] != "control")]

out = {}
out["train"] = {
"control": train_control,
"perturbations": train_perturbations
}
out["dev"] = {
"control": dev_control,
"perturbations": dev_perturbations
}
out["test"] = {
"control": test_control,
"perturbations": test_perturbations
}
# TODO: currently, there will be no inter-cell-line evaluation
return out
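
A short sketch (not part of the diff) of how the two split modes returned by get_split might be consumed; the dataset name is the default used by the benchmark group above:

    from tdc.multi_pred.perturboutcome import PerturbOutcome

    data = PerturbOutcome("scperturb_drug_AissaBenevolenskaya2021")

    # random split (default): flat DataFrames plus the pooled control cells
    split = data.get_split(use_random=True)
    train, dev, test = split["train"], split["dev"], split["test"]
    control = split["control"]

    # cell-line split: each partition holds separate control and perturbation DataFrames
    cl_split = data.get_split(use_random=False)
    for name in ("train", "dev", "test"):
        part = cl_split[name]
        print(name, part["control"].shape, part["perturbations"].shape)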
33 changes: 32 additions & 1 deletion tdc/test/test_benchmark.py
@@ -7,7 +7,7 @@

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from tdc.benchmark_group import admet_group, scdti_group
from tdc.benchmark_group import admet_group, scdti_group, counterfactual_group


def is_classification(values):
@@ -85,6 +85,37 @@ def test_SCDTI_benchmark(self):
assert len(many_results["f1"]
) == 2 # should include mean and standard deviation

@unittest.skip(
"counterfactual test is taking up too much memory"
) #FIXME: please run if making changes to counterfactual benchmark or core code.
def test_counterfactual(self):
from tdc.multi_pred.perturboutcome import PerturbOutcome
from tdc.dataset_configs.config_map import scperturb_datasets, scperturb_gene_datasets

test_data = PerturbOutcome("scperturb_drug_AissaBenevolenskaya2021")
group = counterfactual_group.CounterfactualGroup() # is drug
assert group.is_drug
assert set(group.dataset_names) == set(
scperturb_datasets
), "loaded datasets should be scperturb drug, but were {} vs correct: {}".format(
group.dataset_names, scperturb_datasets)
ct = len(test_data.get_data())
train, val = group.get_train_valid_split()
test = group.get_test()
control = group.split["control"]
testct = len(train) + len(val) + len(test) + len(control)
assert ct == testct, "counts between original data and the 3 splits should match: original {} vs splits {}".format(
ct, testct)
# basic test on perfect score
tst = test_data.get_split()["test"]
r2 = group.evaluate(tst)
assert r2 == 1, "comparing test to itself should have perfect R^2 score, was {}".format(
r2)
# now just check we can load sc perturb gene correctly
group_gene = counterfactual_group.CounterfactualGroup(is_drug=False)
assert not group_gene.is_drug
assert set(group_gene.dataset_names) == set(scperturb_gene_datasets)


if __name__ == "__main__":
unittest.main()
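
The counterfactual test above is skipped by default for memory reasons; a possible way to run this test module locally after removing the @unittest.skip decorator (a sketch using the standard unittest loader, not a project-specific runner):

    import unittest

    # discover and run only the benchmark tests in this file
    suite = unittest.TestLoader().discover("tdc/test", pattern="test_benchmark.py")
    unittest.TextTestRunner(verbosity=2).run(suite)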
