Skip to content

Commit

Permalink
Merge branch 'mims-harvard:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
abearab authored May 6, 2024
2 parents 078a656 + 5fbd717 commit 16d23e6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
49 changes: 49 additions & 0 deletions tdc/benchmark_group/scdti_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Author: TDC Team
# License: MIT
import os

from .base_group import BenchmarkGroup


class SCDTIGroup(BenchmarkGroup):
"""Create SCDTI Group Class object. This is for single-cell drug-target identification task benchmark.
Args:
path (str, optional): the path to store/retrieve the SCDTI group datasets.
"""

def __init__(self, path="./data", file_format="csv"):
"""Create an SCDTI benchmark group class."""
# super().__init__(name="SCDTI_Group", path=path)
self.name = "SCDTI_Group"
self.path = os.path.join(path, self.name)
# self.datasets = ["pinnacle_dti"]
self.dataset_names = ["pinnacle_dti"]
self.file_format = file_format
self.split = None

def get_train_valid_split(self):
"""parameters included for compatibility. this benchmark has a fixed train/test split."""
from ..resource.dataloader import DataLoader
if self.split is None:
dl = DataLoader(name="pinnacle_dti")
self.split = dl.get_split()
return self.split["train"], self.split["dev"]

def get_test(self):
from ..resource.dataloader import DataLoader
if self.split is None:
dl = DataLoader(name="pinnacle_dti")
self.split = dl.get_split()
return self.split["test"]

def evaluate(self, y_pred):
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
y_true = self.get_test()["Y"]
# Calculate metrics
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
return [precision, recall, accuracy, f1]
21 changes: 20 additions & 1 deletion tdc/test/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from tdc.benchmark_group import admet_group
from tdc.benchmark_group import admet_group, scdti_group


def is_classification(values):
Expand Down Expand Up @@ -62,6 +62,25 @@ def test_ADME_evaluate_many(self):
for my_group in self.group:
self.assertTrue(my_group["name"] in results)

def test_SCDTI_benchmark(self):
from tdc.resource.dataloader import DataLoader

data = DataLoader(name="pinnacle_dti")
group = scdti_group.SCDTIGroup()
train, val = group.get_train_valid_split()
assert len(val) == 0 # this benchmark has no validation set
# test simple preds
y_true = group.get_test()["Y"]
results = group.evaluate(y_true)
assert results[-1] == 1.0 # should be perfect F1 score
# assert it matches the PINNACLE official test scores
tst = data.get_split()["test"]["Y"]
results = group.evaluate(tst)
assert results[-1] == 1.0
zero_pred = [0] * len(y_true)
results = group.evaluate(zero_pred)
assert results[-1] != 1.0 # should not be perfect F1 score


if __name__ == "__main__":
unittest.main()

0 comments on commit 16d23e6

Please sign in to comment.