Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neurips benchmarks -- add tchard dataset along with hard splits #272

Merged
merged 2 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 193 additions & 1 deletion tdc/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,47 @@
"opentargets_ibd_drug_evidence",
"opentargets_ra_data_splits_idx",
"opentargets_ibd_data_splits_idx",
"tchard_full",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4",
"tchard_pep_cdr3b_only_neg_assays_test-0",
"tchard_pep_cdr3b_only_neg_assays_test-1",
"tchard_pep_cdr3b_only_neg_assays_test-2",
"tchard_pep_cdr3b_only_neg_assays_test-3",
"tchard_pep_cdr3b_only_neg_assays_test-4",
"tchard_pep_cdr3b_only_neg_assays_train-0",
"tchard_pep_cdr3b_only_neg_assays_train-1",
"tchard_pep_cdr3b_only_neg_assays_train-2",
"tchard_pep_cdr3b_only_neg_assays_train-3",
"tchard_pep_cdr3b_only_neg_assays_train-4",
"tchard_pep_cdr3b_only_sampled_negs_test-0",
"tchard_pep_cdr3b_only_sampled_negs_test-1",
"tchard_pep_cdr3b_only_sampled_negs_test-2",
"tchard_pep_cdr3b_only_sampled_negs_test-3",
"tchard_pep_cdr3b_only_sampled_negs_test-4",
"tchard_pep_cdr3b_only_sampled_negs_train-0",
"tchard_pep_cdr3b_only_sampled_negs_train-1",
"tchard_pep_cdr3b_only_sampled_negs_train-2",
"tchard_pep_cdr3b_only_sampled_negs_train-3",
"tchard_pep_cdr3b_only_sampled_negs_train-4",
]

resources = {
Expand All @@ -217,7 +258,76 @@
"opentargets_ra_drug_evidence",
"opentargets_ibd_drug_evidence",
],
}
},
"tchard": {
"splits_raw": {
"train": {
"tchard_pep_cdr3b_only_neg_assays": {
0: "tchard_pep_cdr3b_only_neg_assays_train-0",
1: "tchard_pep_cdr3b_only_neg_assays_train-1",
2: "tchard_pep_cdr3b_only_neg_assays_train-2",
3: "tchard_pep_cdr3b_only_neg_assays_train-3",
4: "tchard_pep_cdr3b_only_neg_assays_train-4",
},
"tchard_pep_cdr3b_only_sampled_negs_train": {
0: "tchard_pep_cdr3b_only_sampled_negs_train-0",
1: "tchard_pep_cdr3b_only_sampled_negs_train-1",
2: "tchard_pep_cdr3b_only_sampled_negs_train-2",
3: "tchard_pep_cdr3b_only_sampled_negs_train-3",
4: "tchard_pep_cdr3b_only_sampled_negs_train-4",
},
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train": {
0: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0",
1: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1",
2: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2",
3: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3",
4: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4",
},
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train": {
0: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0",
1: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1",
2: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2",
3: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3",
4: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4",
}
},
"test": {
"tchard_pep_cdr3b_only_neg_assays": {
0: "tchard_pep_cdr3b_only_neg_assays_test-0",
1: "tchard_pep_cdr3b_only_neg_assays_test-1",
2: "tchard_pep_cdr3b_only_neg_assays_test-2",
3: "tchard_pep_cdr3b_only_neg_assays_test-3",
4: "tchard_pep_cdr3b_only_neg_assays_test-4",
},
"tchard_pep_cdr3b_only_sampled_negs_train": {
0: "tchard_pep_cdr3b_only_sampled_negs_test-0",
1: "tchard_pep_cdr3b_only_sampled_negs_test-1",
2: "tchard_pep_cdr3b_only_sampled_negs_test-2",
3: "tchard_pep_cdr3b_only_sampled_negs_test-3",
4: "tchard_pep_cdr3b_only_sampled_negs_test-4",
},
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train": {
0: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0",
1: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1",
2: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2",
3: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3",
4: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4",
},
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train": {
0: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0",
1: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1",
2: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2",
3: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3",
4: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4",
}
},
"dev": {} # no dev set on tchard
},
"all": ["tchard_full",],
"config": {
"Y": "label",
}
},
}

####################################
Expand Down Expand Up @@ -785,6 +895,47 @@ def get_task2category():
"opentargets_ibd_data_splits_idx": "json",
"opentargets_ra_drug_evidence": "tab",
"opentargets_ibd_drug_evidence": "tab",
"tchard_full": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3": "tab",
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4": "tab",
"tchard_pep_cdr3b_only_neg_assays_test-0": "tab",
"tchard_pep_cdr3b_only_neg_assays_test-1": "tab",
"tchard_pep_cdr3b_only_neg_assays_test-2": "tab",
"tchard_pep_cdr3b_only_neg_assays_test-3": "tab",
"tchard_pep_cdr3b_only_neg_assays_test-4": "tab",
"tchard_pep_cdr3b_only_neg_assays_train-0": "tab",
"tchard_pep_cdr3b_only_neg_assays_train-1": "tab",
"tchard_pep_cdr3b_only_neg_assays_train-2": "tab",
"tchard_pep_cdr3b_only_neg_assays_train-3": "tab",
"tchard_pep_cdr3b_only_neg_assays_train-4": "tab",
"tchard_pep_cdr3b_only_sampled_negs_test-0": "tab",
"tchard_pep_cdr3b_only_sampled_negs_test-1": "tab",
"tchard_pep_cdr3b_only_sampled_negs_test-2": "tab",
"tchard_pep_cdr3b_only_sampled_negs_test-3": "tab",
"tchard_pep_cdr3b_only_sampled_negs_test-4": "tab",
"tchard_pep_cdr3b_only_sampled_negs_train-0": "tab",
"tchard_pep_cdr3b_only_sampled_negs_train-1": "tab",
"tchard_pep_cdr3b_only_sampled_negs_train-2": "tab",
"tchard_pep_cdr3b_only_sampled_negs_train-3": "tab",
"tchard_pep_cdr3b_only_sampled_negs_train-4": "tab",
}

name2id = {
Expand Down Expand Up @@ -914,6 +1065,47 @@ def get_task2category():
"opentargets_ibd_data_splits_idx": 10143573,
"opentargets_ra_drug_evidence": 10141153,
"opentargets_ibd_drug_evidence": 10141154,
"tchard_full": 10228321,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0": 10228304,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1": 10228296,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2": 10228328,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3": 10228299,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4": 10228330,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0": 10228331,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1": 10228334,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2": 10228324,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3": 10228325,
"tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4": 10228327,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0": 10228320,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1": 10228295,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2": 10228297,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3": 10228294,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4": 10228309,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0": 10228301,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1": 10228310,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2": 10228315,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3": 10228311,
"tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4": 10228335,
"tchard_pep_cdr3b_only_neg_assays_test-0": 10228300,
"tchard_pep_cdr3b_only_neg_assays_test-1": 10228302,
"tchard_pep_cdr3b_only_neg_assays_test-2": 10228305,
"tchard_pep_cdr3b_only_neg_assays_test-3": 10228298,
"tchard_pep_cdr3b_only_neg_assays_test-4": 10228319,
"tchard_pep_cdr3b_only_neg_assays_train-0": 10228312,
"tchard_pep_cdr3b_only_neg_assays_train-1": 10228317,
"tchard_pep_cdr3b_only_neg_assays_train-2": 10228333,
"tchard_pep_cdr3b_only_neg_assays_train-3": 10228318,
"tchard_pep_cdr3b_only_neg_assays_train-4": 10228314,
"tchard_pep_cdr3b_only_sampled_negs_test-0": 10228329,
"tchard_pep_cdr3b_only_sampled_negs_test-1": 10228332,
"tchard_pep_cdr3b_only_sampled_negs_test-2": 10228303,
"tchard_pep_cdr3b_only_sampled_negs_test-3": 10228306,
"tchard_pep_cdr3b_only_sampled_negs_test-4": 10228308,
"tchard_pep_cdr3b_only_sampled_negs_train-0": 10228323,
"tchard_pep_cdr3b_only_sampled_negs_train-1": 10228313,
"tchard_pep_cdr3b_only_sampled_negs_train-2": 10228322,
"tchard_pep_cdr3b_only_sampled_negs_train-3": 10228316,
"tchard_pep_cdr3b_only_sampled_negs_train-4": 10228326,
}

oracle2type = {
Expand Down
36 changes: 33 additions & 3 deletions tdc/resource/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""
Dataloader class for resource datasets

TODO: we should create a way for task-specific loaders to leverage this in the backend via the dsl config.
TODO: for example, the tchard dataset uses this loader, but it is a multi_pred task (tct-epitope)
"""

import pandas as pd
Expand Down Expand Up @@ -33,9 +36,34 @@ def __init__(self, name, path="./data", dataset_names=None):
self[n] = load.resource_dataset_load(n, path, dataset_names)
for nsplit in metadata.resources[name].get("splits", []):
self[nsplit] = {"splits": self[nsplit]}
if "splits_raw" in metadata.resources[name]:
self.df_key = "all_data"
self[self.df_key] = pd.concat([
load.resource_dataset_load(dfname, path, dataset_names)
for dfname in metadata.resources[name].get("all", [])
])
self.split_key = "splits"
self[self.split_key] = metadata.resources[name].get(
"splits_raw", [])
# for now, we have hard-coded config embedded in metadata; should implement dsl config support for this case as well.
self.config = metadata.resources[name].get("config")
# now just need to load the actual data into the splits
for split_type, labeled_splits in self[self.split_key].items():
for label, keyed_splits in labeled_splits.items():
for key, split_name in keyed_splits.items():
splitdf = load.resource_dataset_load(
split_name, path, dataset_names)
# apply config
splitdf["Y"] = splitdf[self.config["Y"]]
# store dataframe
self[self.split_key][split_type][label][key] = splitdf
# apply label
self[self.df_key]["Y"] = self[self.df_key][self.config["Y"]]
return # no further data processing needed. TODO: can support dsl configs as well here.

# raise Exception("keys", self.keys())
self.config = config_map.ConfigMap().get(name)
self.config = self.config()
self.config = self.config() if self.config is not None else None
assert self.config is not None, "resource.DataLoader requires a corresponding config"
assert isinstance(
self.config, config.ResourceConfig
Expand All @@ -52,14 +80,16 @@ def __init__(self, name, path="./data", dataset_names=None):

def get_data(self, df_key=None, **kwargs):
# TODO: can call parent's get_data(**kwargs) function if dataset not pre-loaded
df_key = df_key or self.config.df_key
df_key = df_key or self.config.df_key if type(
self.config) != dict else self.df_key
assert df_key in self, "{} key hasn't been set in the loader, please set it by using the resource.DataLoader".format(
df_key)
return self[df_key]

def get_split(self, split_key=None, **kwargs):
# TODO: can call parent's get_split(**kwargs) function if splits not pre-loaded
split_key = split_key or self.config.split_key
split_key = split_key or self.config.split_key if type(
self.config) != dict else self.split_key
assert split_key in self, "{} key hasn't been set in the loader, please set it by using the resource.DataLoader".format(
split_key)
return self[split_key]
19 changes: 19 additions & 0 deletions tdc/test/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,25 @@ def test_resource_dataverse_dataloader(self):
assert len(split["test"]) > 0
assert isinstance(split["train"], pd.DataFrame)

def test_resource_dataverse_dataloader_raw_splits(self):
import pandas as pd
from tdc.resource.dataloader import DataLoader
data = DataLoader(name="tchard")
df = data.get_data()
assert isinstance(df, pd.DataFrame)
assert "Y" in df.columns
assert "splits" in data
splits = data.get_split()
assert "train" in splits
assert "dev" in splits
assert "test" in splits
assert isinstance(
splits["train"]["tchard_pep_cdr3b_only_neg_assays"][0],
pd.DataFrame)
assert isinstance(splits["test"]["tchard_pep_cdr3b_only_neg_assays"][2],
pd.DataFrame)
assert not splits["dev"]

def tearDown(self):
try:
print(os.getcwd())
Expand Down
Loading