diff --git a/tdc/resource/dataloader.py b/tdc/resource/dataloader.py index 66ff73c1..0b966cd6 100644 --- a/tdc/resource/dataloader.py +++ b/tdc/resource/dataloader.py @@ -51,14 +51,21 @@ def __init__(self, name, path="./data", dataset_names=None): for split_type, labeled_splits in self[self.split_key].items(): for label, keyed_splits in labeled_splits.items(): for key, split_name in keyed_splits.items(): - splitdf = load.resource_dataset_load( - split_name, path, dataset_names) - # apply config - splitdf["Y"] = splitdf[self.config["Y"]] - # store dataframe - self[self.split_key][split_type][label][key] = splitdf + if type(split_name) == str: + splitdf = load.resource_dataset_load( + split_name, path, dataset_names) + # apply config + splitdf["Y"] = splitdf[self.config["Y"]] + # store dataframe + self[self. + split_key][split_type][label][key] = splitdf + else: + assert type(split_name) == pd.DataFrame, type( + split_name) + # apply label - self[self.df_key]["Y"] = self[self.df_key][self.config["Y"]] + if "Y" not in self[self.df_key].columns: # for repeat loader runs + self[self.df_key]["Y"] = self[self.df_key][self.config["Y"]] return # no further data processing needed. TODO: can support dsl configs as well here. # raise Exception("keys", self.keys()) diff --git a/tdc/test/test_benchmark.py b/tdc/test/test_benchmark.py index 7285dca0..444bed95 100644 --- a/tdc/test/test_benchmark.py +++ b/tdc/test/test_benchmark.py @@ -166,6 +166,15 @@ def test_proteinpeptide(self): res = group.evaluate(y_test) assert res[-1] == 1 and res[-2] == 1, res + def test_tcrepitope(self): + from tdc.benchmark_group.tcrepitope_group import TCREpitopeGroup + from tdc.resource.dataloader import DataLoader + data = DataLoader("tchard") + tst = data.get_split()["test"] + group = TCREpitopeGroup() + res = group.evaluate(tst) + assert res == 1 + if __name__ == "__main__": unittest.main()