
Commit 666e5be: filter out unseen
amva13 committed May 21, 2024
1 parent eb397e3
Showing 3 changed files with 48 additions and 12 deletions.
28 changes: 25 additions & 3 deletions tdc/benchmark_group/counterfactual_group.py
@@ -3,6 +3,7 @@
 # License: MIT
 import numpy as np
 import os
+import pandas as pd
 
 from .base_group import BenchmarkGroup
 from ..dataset_configs.config_map import scperturb_datasets, scperturb_gene_datasets
@@ -34,20 +35,28 @@ def __init__(self, path="./data", file_format="csv", is_drug=True):
         self.file_format = file_format
         self.split = None
 
-    def get_train_valid_split(self, dataset=None, only_seen=False):
+    def get_train_valid_split(self,
+                              dataset=None,
+                              split_to_unseen=False,
+                              remove_unseen=True):
         """parameters included for compatibility. this benchmark has a fixed train/test split."""
         from ..multi_pred.perturboutcome import PerturbOutcome
         dataset = dataset or "scperturb_drug_AissaBenevolenskaya2021"
         assert dataset in self.dataset_names, "{} dataset not in {}".format(
             dataset, self.dataset_names)
         data = PerturbOutcome(dataset)
-        self.split = data.get_split()
+        self.split = data.get_split(unseen=split_to_unseen,
+                                    remove_unseen=remove_unseen)
         cell_lines = list(self.split.keys())
+        self.split["adj"] = 0
         for line in cell_lines:
             print("processing benchmark line", line)
             for split, df in self.split[line].items():
                 if split not in self.split:
                     self.split[split] = {}
+                elif split == "adj":
+                    self.split["adj"] += df
+                    continue
                 self.split[split][line] = df
             print("done with line", line)
         return self.split["train"], self.split["dev"]
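
For context, the loop above pivots the nested split dictionary from {cell_line: {split: df}} into {split: {cell_line: df}}, summing the per-line integer "adj" counters into one top-level count. A toy equivalent with simplified control flow (illustrative only, not part of the commit):

    per_line = {"lineA": {"train": [1, 2], "adj": 3},
                "lineB": {"train": [4], "adj": 1}}
    per_split = {"adj": 0}
    for line, splits in per_line.items():
        for split, value in splits.items():
            if split == "adj":
                per_split["adj"] += value
            else:
                per_split.setdefault(split, {})[line] = value
    # per_split == {"adj": 4, "train": {"lineA": [1, 2], "lineB": [4]}}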
@@ -74,7 +83,20 @@ def evaluate(self, y_pred):
             categorical_cols = y_pred[cell_line].select_dtypes(
                 include=['object', 'category']).columns
             y_pred[cell_line] = y_pred[cell_line].drop(columns=categorical_cols)
-            r2 = r2_score(df.mean(), y_pred[cell_line].mean())
+            mdf = df.mean()
+            mpred = y_pred[cell_line].mean()
+            if len(mdf) != len(mpred):
+                raise Exception(
+                    "lengths of true and predicted mean vectors differ in cell line {}: {} vs {}"
+                    .format(cell_line, len(mdf), len(mpred)))
+            elif pd.isna(mdf.values).any():
+                raise Exception(
+                    "ground truth mean contains {} nan values".format(
+                        mdf.isna().sum()))
+            elif pd.isna(mpred.values).any():
+                raise Exception("prediction mean contains {} nan values".format(
+                    mpred.isna().sum()))
+            r2 = r2_score(mdf, mpred)
             r2vec.append(r2)
         return np.mean(r2vec)

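A minimal usage sketch of the updated benchmark API. The class name CounterfactualGroup and the identity-baseline predictions are assumptions for illustration; only get_train_valid_split, get_test, and evaluate appear in this commit:

    from tdc.benchmark_group.counterfactual_group import CounterfactualGroup  # assumed class name

    group = CounterfactualGroup(path="./data", is_drug=True)
    train, dev = group.get_train_valid_split(remove_unseen=True)
    test = group.get_test()
    # identity baseline: feeding copies of the ground-truth test frames back in as
    # predictions should score a perfect mean R^2, mirroring the unit test below
    print(group.evaluate({line: df.copy() for line, df in test.items()}))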
25 changes: 19 additions & 6 deletions tdc/multi_pred/perturboutcome.py
@@ -62,7 +62,8 @@ def get_dropout_genes(self):
     def get_cellline_split(self,
                            ratios=[0.8, 0.1, 0.1],
                            random_state=42,
-                           split_to_unseen=False):
+                           split_to_unseen=False,
+                           remove_unseen=True):
         df = self.get_data()
         print("got data grouping by cell line")
         cell_line_groups = df.groupby("cell_line")
@@ -82,11 +83,22 @@
                     test_size=ratios[2] /
                     (ratios[1] + ratios[2]),
                     random_state=random_state)
+                filter_test = test["perturbation"].isin(train["perturbation"])
+                filter_dev = dev["perturbation"].isin(train["perturbation"])
+                adj = 0
+                if remove_unseen:
+                    lbef = len(test), len(dev)
+                    test = test[filter_test]
+                    dev = dev[filter_dev]
+                    laft = len(test), len(dev)
+                    adj = sum(lbef) - sum(laft)
+                # TODO: filters might dilute test/dev significantly ...
                 cell_line_splits[cell_line] = {
                     "control": control,
                     "train": train,
                     "test": test,
-                    "dev": dev
+                    "dev": dev,
+                    "adj": adj,
                 }
             else:
                 perturbs = cell_line_group["perturbation"].unique()
@@ -257,12 +269,12 @@ def get_split(self,
                   use_random=False,
                   random_state=42,
                   train_val_gene_set_size=0.75,
-                  combo_seen2_train_frac=0.75):
+                  combo_seen2_train_frac=0.75,
+                  remove_unseen=True):
         """obtain train/dev/test splits for each cell_line
         counterfactual prediction model is trained on a single cell line and then evaluated on same cell line
         and against new cell lines
-        TODO: also allow for splitting by unseen perturbations
-        TODO: allow for evaluating within the same cell line"""
+        """
 
         if self.is_gene:
             ## use gene perturbation data split
@@ -317,7 +329,8 @@ def map_name(x):
         if not use_random:
             return self.get_cellline_split(split_to_unseen=unseen,
                                            ratios=ratios,
-                                           random_state=random_state)
+                                           random_state=random_state,
+                                           remove_unseen=remove_unseen)
         df = self.get_data()
         # just do a random split, otherwise you'll split by cell line...
         control = df[df["perturbation"] == "control"]
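To isolate what the remove_unseen path does, a small self-contained pandas sketch on toy frames (not TDC data): perturbations absent from train are dropped from test/dev, and adj records how many rows were removed so downstream count checks still reconcile.

    import pandas as pd

    train = pd.DataFrame({"perturbation": ["drugA", "drugB"], "y": [1, 2]})
    test = pd.DataFrame({"perturbation": ["drugA", "drugC"], "y": [3, 4]})

    seen = test["perturbation"].isin(train["perturbation"])  # True only for drugA
    before = len(test)
    test = test[seen]         # keep only perturbations seen during training
    adj = before - len(test)  # rows filtered out, reported alongside the splits
    print(test, adj)          # the drugA row remains; adj == 1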
7 changes: 4 additions & 3 deletions tdc/test/test_benchmark.py
@@ -103,19 +103,20 @@ def test_counterfactual(self):
         print("getting test data via dataloader")
         ct = len(test_data.get_data())
         print("getting splits")
-        train, val = group.get_train_valid_split()
+        train, val = group.get_train_valid_split(remove_unseen=False)
         test = group.get_test()
         print("got splits; checking counts")
         trainct = sum(len(x) for _, x in train.items())
         valct = sum(len(x) for _, x in val.items())
         testct = sum(len(x) for _, x in test.items())
         controlct = sum(len(x) for _, x in group.split["control"].items())
-        totalct = trainct + valct + testct + controlct
+        adjct = group.split["adj"]
+        totalct = trainct + valct + testct + controlct + adjct
         assert ct == totalct, "counts between original data and the 3 splits should match: original {} vs splits {}".format(
             ct, totalct)
         # basic test on perfect score
         print("benchmark - generating identical test set")
-        tst = test_data.get_split()
+        tst = test_data.get_split(remove_unseen=False)
         tstdict = {}
         for line, splits in tst.items():
             tstdict[line] = splits["test"]
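The updated assertion enforces a conservation identity: every row of the original data lands in exactly one of train/dev/test/control, or is counted by adj as filtered out. A sketch of that invariant as a reusable check (reconcile_counts is a hypothetical helper, not part of the commit):

    def reconcile_counts(total_rows, split_dicts, adj=0):
        # split_dicts: iterable of {cell_line: DataFrame} mappings, e.g. train/dev/test/control
        bucketed = sum(len(df) for d in split_dicts for df in d.values())
        return total_rows == bucketed + adj

    # usage mirroring the test:
    # assert reconcile_counts(ct, [train, val, test, group.split["control"]], group.split["adj"])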
