diff --git a/tdc/benchmark_group/counterfactual_group.py b/tdc/benchmark_group/counterfactual_group.py
index 1fdcea6e..35ac9d51 100644
--- a/tdc/benchmark_group/counterfactual_group.py
+++ b/tdc/benchmark_group/counterfactual_group.py
@@ -3,6 +3,7 @@
 # License: MIT
 import numpy as np
 import os
+import pandas as pd
 
 from .base_group import BenchmarkGroup
 from ..dataset_configs.config_map import scperturb_datasets, scperturb_gene_datasets
@@ -34,20 +35,28 @@ def __init__(self, path="./data", file_format="csv", is_drug=True):
         self.file_format = file_format
         self.split = None
 
-    def get_train_valid_split(self, dataset=None, only_seen=False):
+    def get_train_valid_split(self,
+                              dataset=None,
+                              split_to_unseen=False,
+                              remove_unseen=True):
         """parameters included for compatibility. this benchmark has a fixed train/test split."""
         from ..multi_pred.perturboutcome import PerturbOutcome
         dataset = dataset or "scperturb_drug_AissaBenevolenskaya2021"
         assert dataset in self.dataset_names, "{} dataset not in {}".format(
             dataset, self.dataset_names)
         data = PerturbOutcome(dataset)
-        self.split = data.get_split()
+        self.split = data.get_split(unseen=split_to_unseen,
+                                    remove_unseen=remove_unseen)
         cell_lines = list(self.split.keys())
+        self.split["adj"] = 0
         for line in cell_lines:
             print("processing benchmark line", line)
             for split, df in self.split[line].items():
                 if split not in self.split:
                     self.split[split] = {}
+                elif split == "adj":
+                    self.split["adj"] += df
+                    continue
                 self.split[split][line] = df
             print("done with line", line)
         return self.split["train"], self.split["dev"]
@@ -74,7 +83,20 @@ def evaluate(self, y_pred):
             categorical_cols = y_pred[cell_line].select_dtypes(
                 include=['object', 'category']).columns
             y_pred[cell_line] = y_pred[cell_line].drop(columns=categorical_cols)
-            r2 = r2_score(df.mean(), y_pred[cell_line].mean())
+            mdf = df.mean()
+            mpred = y_pred[cell_line].mean()
+            if len(mdf) != len(mpred):
+                raise Exception(
+                    "lengths of true and predicted mean vectors differ in cell line {} with {} vs {}"
+                    .format(cell_line, len(mdf), len(mpred)))
+            elif pd.isna(mdf.values).any():
+                raise Exception(
+                    "ground truth mean contains {} nan values".format(
+                        mdf.isna().sum()))
+            elif pd.isna(mpred.values).any():
+                raise Exception("prediction mean contains {} nan values".format(
+                    mpred.isna().sum()))
+            r2 = r2_score(mdf, mpred)
             r2vec.append(r2)
         return np.mean(r2vec)
 
diff --git a/tdc/multi_pred/perturboutcome.py b/tdc/multi_pred/perturboutcome.py
index 024831f4..f636c8a1 100644
--- a/tdc/multi_pred/perturboutcome.py
+++ b/tdc/multi_pred/perturboutcome.py
@@ -62,7 +62,8 @@ def get_dropout_genes(self):
     def get_cellline_split(self,
                            ratios=[0.8, 0.1, 0.1],
                            random_state=42,
-                           split_to_unseen=False):
+                           split_to_unseen=False,
+                           remove_unseen=True):
         df = self.get_data()
         print("got data grouping by cell line")
         cell_line_groups = df.groupby("cell_line")
@@ -82,11 +83,22 @@ def get_cellline_split(self,
                     test_size=ratios[2] / (ratios[1] + ratios[2]),
                     random_state=random_state)
 
+                filter_test = test["perturbation"].isin(train["perturbation"])
+                filter_dev = dev["perturbation"].isin(train["perturbation"])
+                adj = 0
+                if remove_unseen:
+                    lbef = len(test), len(dev)
+                    test = test[~filter_test]
+                    dev = dev[~filter_dev]
+                    laft = len(test), len(dev)
+                    adj = sum(lbef) - sum(laft)
+                # TODO: filters might dilute test/dev significantly ...
                 cell_line_splits[cell_line] = {
                     "control": control,
                     "train": train,
                     "test": test,
-                    "dev": dev
+                    "dev": dev,
+                    "adj": adj,
                 }
             else:
                 perturbs = cell_line_group["perturbation"].unique()
@@ -257,12 +269,12 @@ def get_split(self,
                   use_random=False,
                   random_state=42,
                   train_val_gene_set_size=0.75,
-                  combo_seen2_train_frac=0.75):
+                  combo_seen2_train_frac=0.75,
+                  remove_unseen=True):
         """obtain train/dev/test splits for each cell_line
         counterfactual prediction model is trained on a single cell line and then evaluated on same cell line
         and against new cell lines
-        TODO: also allow for splitting by unseen perturbations
-        TODO: allow for evaluating within the same cell line"""
+        """
 
         if self.is_gene:
             ## use gene perturbation data split
@@ -317,7 +329,8 @@ def map_name(x):
         if not use_random:
             return self.get_cellline_split(split_to_unseen=unseen,
                                            ratios=ratios,
-                                           random_state=random_state)
+                                           random_state=random_state,
+                                           remove_unseen=remove_unseen)
         df = self.get_data()
         # just do a random split, otherwise you'll split by cell line...
         control = df[df["perturbation"] == "control"]
diff --git a/tdc/test/test_benchmark.py b/tdc/test/test_benchmark.py
index a8c2540a..7285dca0 100644
--- a/tdc/test/test_benchmark.py
+++ b/tdc/test/test_benchmark.py
@@ -103,19 +103,20 @@ def test_counterfactual(self):
         print("getting test data via dataloader")
         ct = len(test_data.get_data())
         print("getting splits")
-        train, val = group.get_train_valid_split()
+        train, val = group.get_train_valid_split(remove_unseen=False)
        test = group.get_test()
         print("got splits; checking counts")
         trainct = sum(len(x) for _, x in train.items())
         valct = sum(len(x) for _, x in val.items())
         testct = sum(len(x) for _, x in test.items())
         controlct = sum(len(x) for _, x in group.split["control"].items())
-        totalct = trainct + valct + testct + controlct
+        adjct = group.split["adj"]
+        totalct = trainct + valct + testct + controlct + adjct
         assert ct == totalct, "counts between original data and the 3 splits should match: original {} vs splits {}".format(
             ct, totalct)
         # basic test on perfect score
         print("benchmark - generating identical test set")
-        tst = test_data.get_split()
+        tst = test_data.get_split(remove_unseen=False)
         tstdict = {}
         for line, splits in tst.items():
             tstdict[line] = splits["test"]
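
For reference, a minimal usage sketch of the new remove_unseen flag and the "adj" bookkeeping, mirroring the updated test; the CounterfactualGroup class name and import path are assumed from the touched module and are not shown in this diff:

    # Sketch only: class name/import path assumed, not confirmed by the patch.
    from tdc.benchmark_group.counterfactual_group import CounterfactualGroup

    # Defaults to the scperturb drug benchmark (scperturb_drug_AissaBenevolenskaya2021).
    group = CounterfactualGroup(path="./data")

    # remove_unseen=False keeps dev/test perturbations that never appear in train,
    # so no rows are dropped and group.split["adj"] stays 0.
    train, val = group.get_train_valid_split(remove_unseen=False)
    test = group.get_test()

    # With remove_unseen=True (the default), "adj" totals the dev/test rows that
    # were filtered out, so split sizes still reconcile with the raw data:
    adjct = group.split["adj"]
    total = (sum(len(x) for x in train.values()) +
             sum(len(x) for x in val.values()) +
             sum(len(x) for x in test.values()) +
             sum(len(x) for x in group.split["control"].values()) +
             adjct)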