From c28164258589e5156bd8dbb7816be659517714a7 Mon Sep 17 00:00:00 2001
From: Alejandro Velez
Date: Mon, 13 May 2024 20:32:42 -0400
Subject: [PATCH] complete counterfactual benchmark

---
 tdc/benchmark_group/counterfactual_group.py | 128 ++++++++++++++++++++
 tdc/multi_pred/perturboutcome.py            |  89 ++++++++++++++
 tdc/test/test_benchmark.py                  |  33 ++++-
 3 files changed, 249 insertions(+), 1 deletion(-)
 create mode 100644 tdc/benchmark_group/counterfactual_group.py

diff --git a/tdc/benchmark_group/counterfactual_group.py b/tdc/benchmark_group/counterfactual_group.py
new file mode 100644
index 00000000..49efbc2e
--- /dev/null
+++ b/tdc/benchmark_group/counterfactual_group.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Author: TDC Team
+# License: MIT
+import os
+
+from .base_group import BenchmarkGroup
+from ..dataset_configs.config_map import scperturb_datasets, scperturb_gene_datasets
+
+
+class CounterfactualGroup(BenchmarkGroup):
+    """Create a Counterfactual Group class object for the single-cell counterfactual prediction (drug and gene perturbation) benchmarks.
+
+    Args:
+        path (str, optional): the path to store/retrieve the Counterfactual group datasets.
+    """
+    _DRUG_COLS = [
+        "ncounts", 'celltype', 'cell_line', 'cancer', 'disease', 'tissue_type',
+        'perturbation', 'perturbation_type', 'ngenes'
+    ]
+
+    _GENE_COLS = [
+        'UMI_count', 'cancer', 'cell_line', 'disease', 'guide_id', 'ncounts',
+        'ngenes', 'nperts', 'organism', 'percent_mito', 'percent_ribo',
+        'perturbation', 'perturbation_type', 'tissue_type'
+    ]
+
+    def __init__(self, path="./data", file_format="csv", is_drug=True):
+        """Create a Counterfactual prediction benchmark group class."""
+        self.name = "Counterfactual_Group"
+        self.path = os.path.join(path, self.name)
+        self.is_drug = is_drug
+        self.dataset_names = scperturb_gene_datasets if not self.is_drug else scperturb_datasets
+        self.file_format = file_format
+        self.split = None
+
+    def get_train_valid_split(self, dataset=None, only_seen=False):
+        """Parameters are included for compatibility; this benchmark uses a fixed train/test split."""
+        from ..multi_pred.perturboutcome import PerturbOutcome
+        if only_seen:
+            raise ValueError(
+                "Counterfactual does not currently support the 'only seen' split"
+            )
+        dataset = dataset or "scperturb_drug_AissaBenevolenskaya2021"
+        assert dataset in self.dataset_names, "{} dataset not in {}".format(
+            dataset, self.dataset_names)
+        data = PerturbOutcome(dataset)
+        self.split = data.get_split()
+        return self.split["train"], self.split["dev"]
+
+    def get_test(self):
+        if self.split is None:
+            self.get_train_valid_split()
+        return self.split["test"]
+
+    def evaluate(self, y_pred):
+        from sklearn.metrics import r2_score
+        y_true = self.get_test()
+        cols_to_drop = self._DRUG_COLS if self.is_drug else self._GENE_COLS
+        y_true = y_true.drop(cols_to_drop, axis=1)
+        y_pred = y_pred.drop(cols_to_drop, axis=1)
+        return r2_score(y_true, y_pred)
+
+    def evaluate_dev(
+        self, y_pred
+    ):  # TODO: under development; benchmark using cell line splits; will benchmark on the random split for now
+        from sklearn.metrics import r2_score
+        from numpy import average, std
+        assert type(
+            y_pred
+        ) == dict, "evaluate_dev() expects a dictionary with control and perturbation dataframes"
+        cols_to_drop = self._DRUG_COLS if self.is_drug else self._GENE_COLS
+        y_true = self.get_test()
+        # validate that the input predictions have the same cell lines and perturbations as the ground truth
+        assert len(y_pred["control"]) == len(
+            y_true["control"]
+        ), "input pred and ground truth differ in control row count; {} vs {}".format(
+            len(y_pred["control"]), len(y_true["control"]))
+        assert y_pred["control"].columns.equals(y_true["control"].columns), \
+            "Predictions do not match ground truth columns; lengths are:\n{}\n{}".\
+            format(len(y_pred["control"].columns), len(y_true["control"].columns))
+        assert len(y_pred["perturbations"]) == len(y_true["perturbations"]), \
+            "Perturbation lists do not match length; lengths are:\n{},\n{}".\
+            format(len(y_pred["perturbations"]), len(y_true["perturbations"]))
+        assert y_pred["perturbations"].columns.equals(y_true["perturbations"].columns), \
+            "Perturbation columns do not match; lengths are:\n{},\n{}".\
+            format(len(y_pred["perturbations"].columns), len(y_true["perturbations"].columns))
+        cell_lines = y_pred["control"]["cell_line"].unique()
+        assert set(cell_lines) == set(y_true["control"]["cell_line"].unique()), \
+            "Control cell lines do not match; lengths are:\n{}\n{}".\
+            format(len(cell_lines), len(y_true["control"]["cell_line"].unique()))
+        cell_lines_perturb = y_pred["perturbations"]["cell_line"].unique()
+        assert set(cell_lines_perturb) == set(y_true["perturbations"]["cell_line"].unique()), \
+            "Cell lines with perturbations do not match; lengths are:\n{}\n{}".\
+            format(len(cell_lines_perturb), len(y_true["perturbations"]["cell_line"].unique()))
+        assert set(cell_lines) == set(cell_lines_perturb), \
+            "Cell lines do not match; lengths are:\n{}\n{}".\
+            format(len(cell_lines), len(cell_lines_perturb))
+        r2vec = []
+        for line in cell_lines:
+            perturbations = y_pred["perturbations"][
+                y_pred["perturbations"]["cell_line"] ==
+                line]["perturbation"].unique()
+            for p in perturbations:
+                perturbs_pred = y_pred["perturbations"][
+                    (y_pred["perturbations"]["cell_line"] == line) &
+                    (y_pred["perturbations"]["perturbation"] == p)]
+                perturbs_true = y_true["perturbations"][
+                    (y_true["perturbations"]["cell_line"] == line) &
+                    (y_true["perturbations"]["perturbation"] == p)]
+                perturbs_pred = perturbs_pred.drop(cols_to_drop, axis=1)
+                perturbs_true = perturbs_true.drop(cols_to_drop, axis=1)
+                pred_mean = perturbs_pred.mean()
+                true_mean = perturbs_true.mean()
+                r2vec.append(r2_score(true_mean, pred_mean))
+        return {"mean_r2": average(r2vec), "std_r2": std(r2vec)}
+
+    def evaluate_many(self, preds):
+        from numpy import mean, std
+        if len(preds) < 5:
+            raise Exception(
+                "Run your model on at least 5 seeds to compare results and provide your outputs in preds."
+            )
+        out = dict()
+        results = [self.evaluate_dev(p) for p in preds]
+        out["mean_R^2"] = mean([x["mean_r2"] for x in results])
+        out["std_R^2"] = mean([x["std_r2"] for x in results])
+        out["seedstd_R^2"] = std([x["mean_r2"] for x in results])
+        return out

diff --git a/tdc/multi_pred/perturboutcome.py b/tdc/multi_pred/perturboutcome.py
index 94efc29b..4097dd19 100644
--- a/tdc/multi_pred/perturboutcome.py
+++ b/tdc/multi_pred/perturboutcome.py
@@ -5,6 +5,7 @@
 import warnings
 
 warnings.filterwarnings("ignore")
+import numpy as np
 import sys
 
 from ..utils import print_sys
@@ -24,3 +25,91 @@ def get_DE_genes(self):
 
     def get_dropout_genes(self):
         raise ValueError("TODO")
+
+    def get_split(self,
+                  ratios=[0.8, 0.1, 0.1],
+                  unseen=False,
+                  use_random=True,
+                  random_state=42):
+        """Obtain train/dev/test splits for each cell_line.
+        A counterfactual prediction model is trained on a single cell line and then evaluated on the same cell line
+        and against new cell lines.
+        TODO: also allow for splitting by unseen perturbations
+        TODO: allow for evaluating within the same cell line"""
+        # For now, we will ensure there are no unseen perturbations
+        if unseen:
+            raise ValueError(
+                "Unseen perturbation splits are not yet implemented!")
+        df = self.get_data()
+        if use_random:
+            # do a random split; otherwise the data is split by cell line below
+            from sklearn.model_selection import train_test_split
+            control = df[df["perturbation"] == "control"]
+            perturbs = df[df["perturbation"] != "control"]
+            train, tmp = train_test_split(perturbs,
+                                          test_size=ratios[1] + ratios[2],
+                                          random_state=random_state)
+            dev, test = train_test_split(tmp,
+                                         test_size=ratios[2] /
+                                         (ratios[1] + ratios[2]),
+                                         random_state=random_state)
+            return {
+                "control": control,
+                "train": train,
+                "dev": dev,
+                "test": test
+            }
+        cell_lines = df["cell_line"].unique()
+        perturbations = df["perturbation"].unique()
+        shuffled_cell_line_idx = np.random.permutation(len(cell_lines))
+        assert len(shuffled_cell_line_idx) == len(cell_lines)
+        assert len(shuffled_cell_line_idx) > 3
+
+        # Split the shuffled cell lines into three parts
+        train_end = int(ratios[0] * len(cell_lines))  # 80% for training by default
+        dev_end = train_end + int(
+            ratios[1] * len(cell_lines))  # 10% for development by default
+
+        train_cell_line = cell_lines[shuffled_cell_line_idx[:train_end]]
+        dev_cell_line = cell_lines[shuffled_cell_line_idx[train_end:dev_end]]
+        test_cell_line = cell_lines[shuffled_cell_line_idx[dev_end:]]
+
+        assert len(train_cell_line) > 0
+        assert len(dev_cell_line) > 0
+        assert len(test_cell_line) > 0
+        assert len(test_cell_line) > len(dev_cell_line)
+
+        train_control = df[(df["cell_line"].isin(train_cell_line)) &
+                           (df["perturbation"] == "control")]
+        train_perturbations = df[(df["cell_line"].isin(train_cell_line)) &
+                                 (df["perturbation"] != "control")]
+
+        assert len(train_control) > 0
+        assert len(train_perturbations) > 0
+        assert len(train_control) <= len(train_perturbations)
+
+        dev_control = df[(df["cell_line"].isin(dev_cell_line)) &
+                         (df["perturbation"] == "control")]
+        dev_perturbations = df[(df["cell_line"].isin(dev_cell_line)) &
+                               (df["perturbation"] != "control")]
+
+        test_control = df[(df["cell_line"].isin(test_cell_line)) &
+                          (df["perturbation"] == "control")]
+        test_perturbations = df[(df["cell_line"].isin(test_cell_line)) &
+                                (df["perturbation"] != "control")]
+
+        out = {}
+        out["train"] = {
+            "control": train_control,
+            "perturbations": train_perturbations
+        }
+        out["dev"] = {
+            "control": dev_control,
+            "perturbations": dev_perturbations
+        }
+        out["test"] = {
+            "control": test_control,
+            "perturbations": test_perturbations
+        }
+        # TODO: currently there is no inter-cell-line evaluation
+        return out

diff --git a/tdc/test/test_benchmark.py b/tdc/test/test_benchmark.py
index 2b448493..fe981c22 100644
--- a/tdc/test/test_benchmark.py
+++ b/tdc/test/test_benchmark.py
@@ -7,7 +7,7 @@
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 
-from tdc.benchmark_group import admet_group, scdti_group
+from tdc.benchmark_group import admet_group, scdti_group, counterfactual_group
 
 
 def is_classification(values):
@@ -85,6 +85,37 @@ def test_SCDTI_benchmark(self):
         assert len(many_results["f1"]
                   ) == 2  # should include mean and standard deviation
 
+    @unittest.skip(
+        "counterfactual test is taking up too much memory"
+    )  # FIXME: please run if making changes to the counterfactual benchmark or core code.
+    def test_counterfactual(self):
+        from tdc.multi_pred.perturboutcome import PerturbOutcome
+        from tdc.dataset_configs.config_map import scperturb_datasets, scperturb_gene_datasets
+
+        test_data = PerturbOutcome("scperturb_drug_AissaBenevolenskaya2021")
+        group = counterfactual_group.CounterfactualGroup()  # is drug
+        assert group.is_drug
+        assert set(group.dataset_names) == set(
+            scperturb_datasets
+        ), "loaded datasets should be scperturb drug, but were {} vs correct: {}".format(
+            group.dataset_names, scperturb_datasets)
+        ct = len(test_data.get_data())
+        train, val = group.get_train_valid_split()
+        test = group.get_test()
+        control = group.split["control"]
+        testct = len(train) + len(val) + len(test) + len(control)
+        assert ct == testct, "counts between the original data and the splits should match: original {} vs splits {}".format(
+            ct, testct)
+        # basic test on a perfect score
+        tst = test_data.get_split()["test"]
+        r2 = group.evaluate(tst)
+        assert r2 == 1, "comparing test to itself should give a perfect R^2 score, was {}".format(
+            r2)
+        # now just check that we can load the scPerturb gene data correctly
+        group_gene = counterfactual_group.CounterfactualGroup(is_drug=False)
+        assert not group_gene.is_drug
+        assert set(group_gene.dataset_names) == set(scperturb_gene_datasets)
+
 
 if __name__ == "__main__":
     unittest.main()
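
A minimal usage sketch of the new benchmark group, mirroring the calls exercised in test_counterfactual above. The dataset name is the default used by the benchmark; reusing the ground-truth test frame in place of a real model's predictions is only for illustration.

    from tdc.benchmark_group import counterfactual_group

    group = counterfactual_group.CounterfactualGroup()  # drug perturbations by default
    train, dev = group.get_train_valid_split(
        dataset="scperturb_drug_AissaBenevolenskaya2021")  # fixed random split for now
    test = group.get_test()

    # a real model would produce a dataframe of predicted post-perturbation profiles
    # with the same columns as `test`; reusing the ground truth here yields R^2 = 1
    y_pred = test.copy()
    print(group.evaluate(y_pred))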