diff --git a/tdc/benchmark_group/geneperturb_group.py b/tdc/benchmark_group/geneperturb_group.py index c19dc034..2b756adf 100644 --- a/tdc/benchmark_group/geneperturb_group.py +++ b/tdc/benchmark_group/geneperturb_group.py @@ -4,11 +4,11 @@ import numpy as np import os -from .base_group import BenchmarkGroup +from .counterfactual_group import CounterfactualGroup from ..dataset_configs.config_map import scperturb_datasets, scperturb_gene_datasets -class GenePerturbGroup(BenchmarkGroup): +class GenePerturbGroup(CounterfactualGroup): """Create GenePerturbGroup Group Class object. This is for single-cell gene perturbation prediction tasks benchmark. Args: @@ -25,9 +25,11 @@ def __init__(self, path="./data", file_format="csv"): """Create a GenePerturbGroup prediction benchmark group class.""" self.name = "GenePerturbGroup" self.path = os.path.join(path, self.name) - self.dataset_names = ["scperturb_gene_NormanWeissman2019", - "scperturb_gene_ReplogleWeissman2022_rpe1", - "scperturb_gene_ReplogleWeissman2022_k562_essential"] + self.dataset_names = [ + "scperturb_gene_NormanWeissman2019", + "scperturb_gene_ReplogleWeissman2022_rpe1", + "scperturb_gene_ReplogleWeissman2022_k562_essential" + ] self.file_format = file_format self.split = None @@ -39,43 +41,10 @@ def get_train_valid_split(self, dataset=None): dataset, self.dataset_names) data = PerturbOutcome(dataset) self.split = data.get_split() - + return self.split[0]["train"], self.split[0]["dev"] def get_test(self): if self.split is None: self.get_train_valid_split() return self.split[0]["test"] - - def evaluate(self, y_pred): - from sklearn.metrics import r2_score - y_true = self.get_test() - r2vec = [] - for cell_line, df in y_true.items(): - check = self._DRUG_COLS[0] if self.is_drug else self._GENE_COLS[0] - cols = self._DRUG_COLS if self.is_drug else self._GENE_COLS - if check in df.columns: - df.drop(cols, axis=1) - if check in y_pred[cell_line].columns: - y_pred[cell_line].drop(cols, axis=1) - categorical_cols = df.select_dtypes( - include=['object', 'category']).columns - df = df.drop(columns=categorical_cols) - categorical_cols = y_pred[cell_line].select_dtypes( - include=['object', 'category']).columns - y_pred[cell_line] = y_pred[cell_line].drop(columns=categorical_cols) - r2 = r2_score(df.mean(), y_pred[cell_line].mean()) - r2vec.append(r2) - return np.mean(r2vec) - - def evaluate_many(self, preds): - from numpy import mean, std - if len(preds) < 5: - raise Exception( - "Run your model on at least 5 seeds to compare results and provide your outputs in preds." - ) - out = dict() - preds = [self.evaluate(p) for p in preds] - out["mean_R^2"] = mean(preds) - out["std_R^2"] = std(preds) - return out diff --git a/tdc/multi_pred/perturboutcome.py b/tdc/multi_pred/perturboutcome.py index 2c246332..024831f4 100644 --- a/tdc/multi_pred/perturboutcome.py +++ b/tdc/multi_pred/perturboutcome.py @@ -13,6 +13,7 @@ from .single_cell import CellXGeneTemplate from ..dataset_configs.config_map import scperturb_gene_datasets + def parse_single_pert(i): a = i.split('+')[0] b = i.split('+')[1] @@ -22,9 +23,11 @@ def parse_single_pert(i): pert = a return pert + def parse_combo_pert(i): return i.split('+')[0], i.split('+')[1] - + + def parse_any_pert(p): if ('ctrl' in p) and (p != 'ctrl'): return [parse_single_pert(p)] @@ -32,6 +35,7 @@ def parse_any_pert(p): out = parse_combo_pert(p) return [out[0], out[1]] + class PerturbOutcome(CellXGeneTemplate): def __init__(self, name, path="./data", print_stats=False): @@ -112,30 +116,30 @@ def get_cellline_split(self, print("done with cell line", cell_line) return cell_line_splits - + def get_perts_from_genes(self, genes, pert_list, type_='both'): - """ + """ Returns all single/combo/both perturbations that include a gene """ - single_perts = [p for p in pert_list if ('ctrl' in p) and (p != 'ctrl')] - combo_perts = [p for p in pert_list if 'ctrl' not in p] - - perts = [] - - if type_ == 'single': - pert_candidate_list = single_perts - elif type_ == 'combo': - pert_candidate_list = combo_perts - elif type_ == 'both': - pert_candidate_list = pert_list - - for p in pert_candidate_list: - for g in genes: - if g in parse_any_pert(p): - perts.append(p) - break - return perts + single_perts = [p for p in pert_list if ('ctrl' in p) and (p != 'ctrl')] + combo_perts = [p for p in pert_list if 'ctrl' not in p] + + perts = [] + + if type_ == 'single': + pert_candidate_list = single_perts + elif type_ == 'combo': + pert_candidate_list = combo_perts + elif type_ == 'both': + pert_candidate_list = pert_list + + for p in pert_candidate_list: + for g in genes: + if g in parse_any_pert(p): + perts.append(p) + break + return perts def get_genes_from_perts(self, perts): """ @@ -149,86 +153,111 @@ def get_genes_from_perts(self, perts): gene_list = [g for g in gene_list if g != 'ctrl'] return np.unique(gene_list) - def get_simulation_split_single(self, pert_list, train_gene_set_size = 0.85, seed = 1): + def get_simulation_split_single(self, + pert_list, + train_gene_set_size=0.85, + seed=1): unique_pert_genes = self.get_genes_from_perts(pert_list) - + pert_train = [] pert_test = [] np.random.seed(seed=seed) - + ## a pre-specified list of genes - train_gene_candidates = np.random.choice(unique_pert_genes, - int(len(unique_pert_genes) * train_gene_set_size), replace = False) - + train_gene_candidates = np.random.choice( + unique_pert_genes, + int(len(unique_pert_genes) * train_gene_set_size), + replace=False) + ## ood genes - ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) + ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) + + pert_single_train = self.get_perts_from_genes(train_gene_candidates, + pert_list, 'single') + unseen_single = self.get_perts_from_genes(ood_genes, pert_list, + 'single') - pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single') - unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single') - #print(len(pert_single_train), len(unseen_single), len(pert_list)) #assert len(unseen_single) + len(pert_single_train) == len(pert_list) - - return pert_single_train, unseen_single, {'unseen_single': unseen_single} - def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, - combo_seen2_train_frac = 0.85, seed = 1): - + return pert_single_train, unseen_single, { + 'unseen_single': unseen_single + } + + def get_simulation_split(self, + pert_list, + train_gene_set_size=0.85, + combo_seen2_train_frac=0.85, + seed=1): + unique_pert_genes = self.get_genes_from_perts(pert_list) - + pert_train = [] pert_test = [] np.random.seed(seed=seed) ## a pre-specified list of genes - train_gene_candidates = np.random.choice(unique_pert_genes, - int(len(unique_pert_genes) * train_gene_set_size), replace = False) - + train_gene_candidates = np.random.choice( + unique_pert_genes, + int(len(unique_pert_genes) * train_gene_set_size), + replace=False) + ## ood genes - ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) + ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) - pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single') - pert_combo = self.get_perts_from_genes(train_gene_candidates, pert_list,'combo') + pert_single_train = self.get_perts_from_genes(train_gene_candidates, + pert_list, 'single') + pert_combo = self.get_perts_from_genes(train_gene_candidates, pert_list, + 'combo') pert_train.extend(pert_single_train) - + ## the combo set with one of them in OOD - combo_seen1 = [x for x in pert_combo if len([t for t in x.split('+') if - t in train_gene_candidates]) == 1] + combo_seen1 = [ + x for x in pert_combo + if len([t for t in x.split('+') if t in train_gene_candidates]) == 1 + ] pert_test.extend(combo_seen1) - + pert_combo = np.setdiff1d(pert_combo, combo_seen1) ## randomly sample the combo seen 2 as a test set, the rest in training set np.random.seed(seed=seed) - pert_combo_train = np.random.choice(pert_combo, int(len(pert_combo) * combo_seen2_train_frac), replace = False) - + pert_combo_train = np.random.choice( + pert_combo, + int(len(pert_combo) * combo_seen2_train_frac), + replace=False) + combo_seen2 = np.setdiff1d(pert_combo, pert_combo_train).tolist() pert_test.extend(combo_seen2) pert_train.extend(pert_combo_train) - + ## unseen single - unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single') + unseen_single = self.get_perts_from_genes(ood_genes, pert_list, + 'single') combo_ood = self.get_perts_from_genes(ood_genes, pert_list, 'combo') pert_test.extend(unseen_single) - + ## here only keeps the seen 0, since seen 1 is tackled above - combo_seen0 = [x for x in combo_ood if len([t for t in x.split('+') if - t in train_gene_candidates]) == 0] + combo_seen0 = [ + x for x in combo_ood + if len([t for t in x.split('+') if t in train_gene_candidates]) == 0 + ] pert_test.extend(combo_seen0) #assert len(combo_seen1) + len(combo_seen0) + len(unseen_single) + len(pert_train) + len(combo_seen2) == len(pert_list) - return pert_train, pert_test, {'combo_seen0': combo_seen0, - 'combo_seen1': combo_seen1, - 'combo_seen2': combo_seen2, - 'unseen_single': unseen_single} - + return pert_train, pert_test, { + 'combo_seen0': combo_seen0, + 'combo_seen1': combo_seen1, + 'combo_seen2': combo_seen2, + 'unseen_single': unseen_single + } def get_split(self, - ratios=[0.8, 0.1, 0.1], - unseen=False, - use_random=False, - random_state=42, - train_val_gene_set_size = 0.75, - combo_seen2_train_frac = 0.75): + ratios=[0.8, 0.1, 0.1], + unseen=False, + use_random=False, + random_state=42, + train_val_gene_set_size=0.75, + combo_seen2_train_frac=0.75): """obtain train/dev/test splits for each cell_line counterfactual prediction model is trained on a single cell line and then evaluated on same cell line and against new cell lines @@ -240,61 +269,64 @@ def get_split(self, # check if this data has single or combo perturbations train_gene_set_size = train_val_gene_set_size combo_seen2_train_frac = combo_seen2_train_frac - + if self.is_combo: + def map_name(x): if x == 'control': return 'ctrl' else: - return '+'.join(x.split('_')) if '_' in x else x + '+ctrl' - self.adata.obs['condition'] = self.adata.obs.perturbation.apply(lambda x: map_name(x)) + return '+'.join( + x.split('_')) if '_' in x else x + '+ctrl' + + self.adata.obs['condition'] = self.adata.obs.perturbation.apply( + lambda x: map_name(x)) unique_perts = self.adata.obs.condition.unique() - train, test, test_subgroup = self.get_simulation_split(unique_perts, - train_gene_set_size, - combo_seen2_train_frac, - random_state) - train, val, val_subgroup = self.get_simulation_split(train, - 0.9, - 0.9, - random_state) + train, test, test_subgroup = self.get_simulation_split( + unique_perts, train_gene_set_size, combo_seen2_train_frac, + random_state) + train, val, val_subgroup = self.get_simulation_split( + train, 0.9, 0.9, random_state) else: - self.adata.obs['condition'] = self.adata.obs.perturbation.apply(lambda x: x + '+ctrl' if x!='control' else 'ctrl') + self.adata.obs['condition'] = self.adata.obs.perturbation.apply( + lambda x: x + '+ctrl' if x != 'control' else 'ctrl') unique_perts = self.adata.obs.condition.unique() - train, test, test_subgroup = self.get_simulation_split_single(unique_perts, - train_gene_set_size, - random_state) - train, val, val_subgroup = self.get_simulation_split_single(train, - 0.9, - random_state) - + train, test, test_subgroup = self.get_simulation_split_single( + unique_perts, train_gene_set_size, random_state) + train, val, val_subgroup = self.get_simulation_split_single( + train, 0.9, random_state) + map_dict = {x: 'train' for x in train} map_dict.update({x: 'val' for x in val}) map_dict.update({x: 'test' for x in test}) map_dict.update({'ctrl': 'train'}) self.adata.obs['split'] = self.adata.obs['condition'].map(map_dict) - adata_out = {"train": self.adata[self.adata.obs.split == 'train'], - "dev": self.adata[self.adata.obs.split == 'val'], - "test": self.adata[self.adata.obs.split == 'test']} - subgroup = {'test_subgroup': test_subgroup, - 'dev_subgroup': val_subgroup} + adata_out = { + "train": self.adata[self.adata.obs.split == 'train'], + "dev": self.adata[self.adata.obs.split == 'val'], + "test": self.adata[self.adata.obs.split == 'test'] + } + subgroup = { + 'test_subgroup': test_subgroup, + 'dev_subgroup': val_subgroup + } return adata_out, subgroup - if not use_random: return self.get_cellline_split(split_to_unseen=unseen, - ratios=ratios, - random_state=random_state) + ratios=ratios, + random_state=random_state) df = self.get_data() # just do a random split, otherwise you'll split by cell line... control = df[df["perturbation"] == "control"] perturbs = df[df["perturbation"] != "control"] train, tmp = train_test_split(perturbs, - test_size=ratios[1] + ratios[2], - random_state=random_state) + test_size=ratios[1] + ratios[2], + random_state=random_state) test, dev = train_test_split(tmp, - test_size=ratios[2] / - (ratios[1] + ratios[2]), - random_state=random_state) + test_size=ratios[2] / + (ratios[1] + ratios[2]), + random_state=random_state) return {"control": control, "train": train, "dev": dev, "test": test} diff --git a/tdc/test/test_benchmark.py b/tdc/test/test_benchmark.py index eb6a6f33..a8c2540a 100644 --- a/tdc/test/test_benchmark.py +++ b/tdc/test/test_benchmark.py @@ -7,7 +7,7 @@ sys.path.append( os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -from tdc.benchmark_group import admet_group, scdti_group, counterfactual_group +from tdc.benchmark_group import admet_group, scdti_group, counterfactual_group, geneperturb_group def is_classification(values): @@ -130,6 +130,14 @@ def test_counterfactual(self): assert not group_gene.is_drug assert set(group_gene.dataset_names) == set(scperturb_gene_datasets) + @unittest.skip( + "counterfactual test is taking up too much memory" + ) #FIXME: please run if making changes to counterfactual benchmark or core code. + def test_gene_perturb(self): + group = geneperturb_group.GenePerturbGroup() + group.get_train_valid_split() + group.get_test() + def test_proteinpeptide(self): from tdc.benchmark_group.protein_peptide_group import ProteinPeptideGroup from tdc.multi_pred.proteinpeptide import ProteinPeptide