diff --git a/tdc/metadata.py b/tdc/metadata.py index da34214e..7cdf164c 100644 --- a/tdc/metadata.py +++ b/tdc/metadata.py @@ -199,6 +199,47 @@ "opentargets_ibd_drug_evidence", "opentargets_ra_data_splits_idx", "opentargets_ibd_data_splits_idx", + "tchard_full", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4", + "tchard_pep_cdr3b_only_neg_assays_test-0", + "tchard_pep_cdr3b_only_neg_assays_test-1", + "tchard_pep_cdr3b_only_neg_assays_test-2", + "tchard_pep_cdr3b_only_neg_assays_test-3", + "tchard_pep_cdr3b_only_neg_assays_test-4", + "tchard_pep_cdr3b_only_neg_assays_train-0", + "tchard_pep_cdr3b_only_neg_assays_train-1", + "tchard_pep_cdr3b_only_neg_assays_train-2", + "tchard_pep_cdr3b_only_neg_assays_train-3", + "tchard_pep_cdr3b_only_neg_assays_train-4", + "tchard_pep_cdr3b_only_sampled_negs_test-0", + "tchard_pep_cdr3b_only_sampled_negs_test-1", + "tchard_pep_cdr3b_only_sampled_negs_test-2", + "tchard_pep_cdr3b_only_sampled_negs_test-3", + "tchard_pep_cdr3b_only_sampled_negs_test-4", + "tchard_pep_cdr3b_only_sampled_negs_train-0", + "tchard_pep_cdr3b_only_sampled_negs_train-1", + "tchard_pep_cdr3b_only_sampled_negs_train-2", + "tchard_pep_cdr3b_only_sampled_negs_train-3", + "tchard_pep_cdr3b_only_sampled_negs_train-4", ] resources = { @@ -217,7 +258,76 @@ "opentargets_ra_drug_evidence", "opentargets_ibd_drug_evidence", ], - } + }, + "tchard": { + "splits_raw": { + "train": { + "tchard_pep_cdr3b_only_neg_assays": { + 0: "tchard_pep_cdr3b_only_neg_assays_train-0", + 1: "tchard_pep_cdr3b_only_neg_assays_train-1", + 2: "tchard_pep_cdr3b_only_neg_assays_train-2", + 3: "tchard_pep_cdr3b_only_neg_assays_train-3", + 4: "tchard_pep_cdr3b_only_neg_assays_train-4", + }, + "tchard_pep_cdr3b_only_sampled_negs_train": { + 0: "tchard_pep_cdr3b_only_sampled_negs_train-0", + 1: "tchard_pep_cdr3b_only_sampled_negs_train-1", + 2: "tchard_pep_cdr3b_only_sampled_negs_train-2", + 3: "tchard_pep_cdr3b_only_sampled_negs_train-3", + 4: "tchard_pep_cdr3b_only_sampled_negs_train-4", + }, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train": { + 0: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0", + 1: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1", + 2: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2", + 3: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3", + 4: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4", + }, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train": { + 0: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0", + 1: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1", + 2: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2", + 3: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3", + 4: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4", + } + }, + "test": { + "tchard_pep_cdr3b_only_neg_assays": { + 0: "tchard_pep_cdr3b_only_neg_assays_test-0", + 1: "tchard_pep_cdr3b_only_neg_assays_test-1", + 2: "tchard_pep_cdr3b_only_neg_assays_test-2", + 3: "tchard_pep_cdr3b_only_neg_assays_test-3", + 4: "tchard_pep_cdr3b_only_neg_assays_test-4", + }, + "tchard_pep_cdr3b_only_sampled_negs_train": { + 0: "tchard_pep_cdr3b_only_sampled_negs_test-0", + 1: "tchard_pep_cdr3b_only_sampled_negs_test-1", + 2: "tchard_pep_cdr3b_only_sampled_negs_test-2", + 3: "tchard_pep_cdr3b_only_sampled_negs_test-3", + 4: "tchard_pep_cdr3b_only_sampled_negs_test-4", + }, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train": { + 0: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0", + 1: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1", + 2: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2", + 3: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3", + 4: "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4", + }, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train": { + 0: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0", + 1: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1", + 2: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2", + 3: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3", + 4: "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4", + } + }, + "dev": {} # no dev set on tchard + }, + "all": ["tchard_full",], + "config": { + "Y": "label", + } + }, } #################################### @@ -785,6 +895,47 @@ def get_task2category(): "opentargets_ibd_data_splits_idx": "json", "opentargets_ra_drug_evidence": "tab", "opentargets_ibd_drug_evidence": "tab", + "tchard_full": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3": "tab", + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4": "tab", + "tchard_pep_cdr3b_only_neg_assays_test-0": "tab", + "tchard_pep_cdr3b_only_neg_assays_test-1": "tab", + "tchard_pep_cdr3b_only_neg_assays_test-2": "tab", + "tchard_pep_cdr3b_only_neg_assays_test-3": "tab", + "tchard_pep_cdr3b_only_neg_assays_test-4": "tab", + "tchard_pep_cdr3b_only_neg_assays_train-0": "tab", + "tchard_pep_cdr3b_only_neg_assays_train-1": "tab", + "tchard_pep_cdr3b_only_neg_assays_train-2": "tab", + "tchard_pep_cdr3b_only_neg_assays_train-3": "tab", + "tchard_pep_cdr3b_only_neg_assays_train-4": "tab", + "tchard_pep_cdr3b_only_sampled_negs_test-0": "tab", + "tchard_pep_cdr3b_only_sampled_negs_test-1": "tab", + "tchard_pep_cdr3b_only_sampled_negs_test-2": "tab", + "tchard_pep_cdr3b_only_sampled_negs_test-3": "tab", + "tchard_pep_cdr3b_only_sampled_negs_test-4": "tab", + "tchard_pep_cdr3b_only_sampled_negs_train-0": "tab", + "tchard_pep_cdr3b_only_sampled_negs_train-1": "tab", + "tchard_pep_cdr3b_only_sampled_negs_train-2": "tab", + "tchard_pep_cdr3b_only_sampled_negs_train-3": "tab", + "tchard_pep_cdr3b_only_sampled_negs_train-4": "tab", } name2id = { @@ -914,6 +1065,47 @@ def get_task2category(): "opentargets_ibd_data_splits_idx": 10143573, "opentargets_ra_drug_evidence": 10141153, "opentargets_ibd_drug_evidence": 10141154, + "tchard_full": 10228321, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-0": 10228304, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-1": 10228296, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-2": 10228328, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-3": 10228299, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_test-4": 10228330, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-0": 10228331, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-1": 10228334, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-2": 10228324, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-3": 10228325, + "tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train-4": 10228327, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-0": 10228320, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-1": 10228295, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-2": 10228297, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-3": 10228294, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_test-4": 10228309, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-0": 10228301, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-1": 10228310, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-2": 10228315, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-3": 10228311, + "tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train-4": 10228335, + "tchard_pep_cdr3b_only_neg_assays_test-0": 10228300, + "tchard_pep_cdr3b_only_neg_assays_test-1": 10228302, + "tchard_pep_cdr3b_only_neg_assays_test-2": 10228305, + "tchard_pep_cdr3b_only_neg_assays_test-3": 10228298, + "tchard_pep_cdr3b_only_neg_assays_test-4": 10228319, + "tchard_pep_cdr3b_only_neg_assays_train-0": 10228312, + "tchard_pep_cdr3b_only_neg_assays_train-1": 10228317, + "tchard_pep_cdr3b_only_neg_assays_train-2": 10228333, + "tchard_pep_cdr3b_only_neg_assays_train-3": 10228318, + "tchard_pep_cdr3b_only_neg_assays_train-4": 10228314, + "tchard_pep_cdr3b_only_sampled_negs_test-0": 10228329, + "tchard_pep_cdr3b_only_sampled_negs_test-1": 10228332, + "tchard_pep_cdr3b_only_sampled_negs_test-2": 10228303, + "tchard_pep_cdr3b_only_sampled_negs_test-3": 10228306, + "tchard_pep_cdr3b_only_sampled_negs_test-4": 10228308, + "tchard_pep_cdr3b_only_sampled_negs_train-0": 10228323, + "tchard_pep_cdr3b_only_sampled_negs_train-1": 10228313, + "tchard_pep_cdr3b_only_sampled_negs_train-2": 10228322, + "tchard_pep_cdr3b_only_sampled_negs_train-3": 10228316, + "tchard_pep_cdr3b_only_sampled_negs_train-4": 10228326, } oracle2type = { diff --git a/tdc/resource/dataloader.py b/tdc/resource/dataloader.py index 3084a706..66ff73c1 100644 --- a/tdc/resource/dataloader.py +++ b/tdc/resource/dataloader.py @@ -1,5 +1,8 @@ """ Dataloader class for resource datasets + +TODO: we should create a way for task-specific loaders to leverage this in the backend via the dsl config. +TODO: for example, the tchard dataset uses this loader, but it is a multi_pred task (tct-epitope) """ import pandas as pd @@ -33,9 +36,34 @@ def __init__(self, name, path="./data", dataset_names=None): self[n] = load.resource_dataset_load(n, path, dataset_names) for nsplit in metadata.resources[name].get("splits", []): self[nsplit] = {"splits": self[nsplit]} + if "splits_raw" in metadata.resources[name]: + self.df_key = "all_data" + self[self.df_key] = pd.concat([ + load.resource_dataset_load(dfname, path, dataset_names) + for dfname in metadata.resources[name].get("all", []) + ]) + self.split_key = "splits" + self[self.split_key] = metadata.resources[name].get( + "splits_raw", []) + # for now, we have hard-coded config embedded in metadata; should implement dsl config support for this case as well. + self.config = metadata.resources[name].get("config") + # now just need to load the actual data into the splits + for split_type, labeled_splits in self[self.split_key].items(): + for label, keyed_splits in labeled_splits.items(): + for key, split_name in keyed_splits.items(): + splitdf = load.resource_dataset_load( + split_name, path, dataset_names) + # apply config + splitdf["Y"] = splitdf[self.config["Y"]] + # store dataframe + self[self.split_key][split_type][label][key] = splitdf + # apply label + self[self.df_key]["Y"] = self[self.df_key][self.config["Y"]] + return # no further data processing needed. TODO: can support dsl configs as well here. + # raise Exception("keys", self.keys()) self.config = config_map.ConfigMap().get(name) - self.config = self.config() + self.config = self.config() if self.config is not None else None assert self.config is not None, "resource.DataLoader requires a corresponding config" assert isinstance( self.config, config.ResourceConfig @@ -52,14 +80,16 @@ def __init__(self, name, path="./data", dataset_names=None): def get_data(self, df_key=None, **kwargs): # TODO: can call parent's get_data(**kwargs) function if dataset not pre-loaded - df_key = df_key or self.config.df_key + df_key = df_key or self.config.df_key if type( + self.config) != dict else self.df_key assert df_key in self, "{} key hasn't been set in the loader, please set it by using the resource.DataLoader".format( df_key) return self[df_key] def get_split(self, split_key=None, **kwargs): # TODO: can call parent's get_split(**kwargs) function if splits not pre-loaded - split_key = split_key or self.config.split_key + split_key = split_key or self.config.split_key if type( + self.config) != dict else self.split_key assert split_key in self, "{} key hasn't been set in the loader, please set it by using the resource.DataLoader".format( split_key) return self[split_key] diff --git a/tdc/test/test_dataloaders.py b/tdc/test/test_dataloaders.py index e8c3906e..b43e1f5e 100644 --- a/tdc/test/test_dataloaders.py +++ b/tdc/test/test_dataloaders.py @@ -105,6 +105,25 @@ def test_resource_dataverse_dataloader(self): assert len(split["test"]) > 0 assert isinstance(split["train"], pd.DataFrame) + def test_resource_dataverse_dataloader_raw_splits(self): + import pandas as pd + from tdc.resource.dataloader import DataLoader + data = DataLoader(name="tchard") + df = data.get_data() + assert isinstance(df, pd.DataFrame) + assert "Y" in df.columns + assert "splits" in data + splits = data.get_split() + assert "train" in splits + assert "dev" in splits + assert "test" in splits + assert isinstance( + splits["train"]["tchard_pep_cdr3b_only_neg_assays"][0], + pd.DataFrame) + assert isinstance(splits["test"]["tchard_pep_cdr3b_only_neg_assays"][2], + pd.DataFrame) + assert not splits["dev"] + def tearDown(self): try: print(os.getcwd())