From 1f549c287c68df0bd19e4b22f256ecdcd8e56f6e Mon Sep 17 00:00:00 2001
From: Alejandro Velez Arce
Date: Tue, 7 May 2024 12:13:18 -0400
Subject: [PATCH 1/2] Update run_tests.py

Testing
---
 run_tests.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/run_tests.py b/run_tests.py
index 56b529b5..77d2b41d 100644
--- a/run_tests.py
+++ b/run_tests.py
@@ -1,6 +1,7 @@
 import unittest
 import sys
 
+
 if __name__ == '__main__':
     loader = unittest.TestLoader()
     start_dir = 'tdc/test'

From f436bdc4695fa1b96c53b57964f120af3de08739 Mon Sep 17 00:00:00 2001
From: Alejandro Velez
Date: Wed, 8 May 2024 10:52:09 -0400
Subject: [PATCH 2/2] implement evaluate_many on scdti benchmark

---
 tdc/benchmark_group/scdti_group.py | 17 +++++++++++++++++
 tdc/test/test_benchmark.py         |  4 ++++
 2 files changed, 21 insertions(+)

diff --git a/tdc/benchmark_group/scdti_group.py b/tdc/benchmark_group/scdti_group.py
index 685beb09..eda47efb 100644
--- a/tdc/benchmark_group/scdti_group.py
+++ b/tdc/benchmark_group/scdti_group.py
@@ -47,3 +47,20 @@ def evaluate(self, y_pred):
         accuracy = accuracy_score(y_true, y_pred)
         f1 = f1_score(y_true, y_pred)
         return [precision, recall, accuracy, f1]
+
+    def evaluate_many(self, preds):
+        from numpy import mean, std
+        if len(preds) < 5:
+            raise Exception(
+                "Run your model on at least 5 seeds to compare results and provide your outputs in preds."
+            )
+        out = dict()
+        preds = [self.evaluate(p) for p in preds]
+        out["precision"] = (mean([x[0] for x in preds]),
+                            std([x[0] for x in preds]))
+        out["recall"] = (mean([x[1] for x in preds]),
+                         std([x[1] for x in preds]))
+        out["accuracy"] = (mean([x[2] for x in preds]),
+                           std([x[2] for x in preds]))
+        out["f1"] = (mean([x[3] for x in preds]), std([x[3] for x in preds]))
+        return out

diff --git a/tdc/test/test_benchmark.py b/tdc/test/test_benchmark.py
index 56c2b0a1..6a89c299 100644
--- a/tdc/test/test_benchmark.py
+++ b/tdc/test/test_benchmark.py
@@ -80,6 +80,10 @@ def test_SCDTI_benchmark(self):
         zero_pred = [0] * len(y_true)
         results = group.evaluate(zero_pred)
         assert results[-1] != 1.0  # should not be perfect F1 score
+        many_results = group.evaluate_many([y_true] * 5)
+        assert "f1" in many_results
+        # should include mean and standard deviation
+        assert len(many_results["f1"]) == 2
 
 
 if __name__ == "__main__":
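
For context, the new evaluate_many wraps the existing evaluate: it scores each per-seed prediction list as [precision, recall, accuracy, f1] and returns a dict mapping each metric name to a (mean, std) tuple across seeds, raising an Exception if fewer than 5 prediction lists are supplied. A rough usage sketch follows; the SCDTIGroup class name, its no-argument constructor, and the run_my_model helper are assumptions for illustration, not APIs confirmed by this patch.

# Hypothetical usage sketch for the new evaluate_many API.
# Assumption: the benchmark group class in tdc/benchmark_group/scdti_group.py
# is importable as SCDTIGroup and can be constructed without arguments.
from tdc.benchmark_group.scdti_group import SCDTIGroup

group = SCDTIGroup()

# One list of binary predictions per random seed; evaluate_many raises an
# Exception if fewer than 5 seeds are provided.
per_seed_preds = [run_my_model(seed=s) for s in range(5)]  # run_my_model is a placeholder

results = group.evaluate_many(per_seed_preds)
# results maps metric name -> (mean, std) across seeds, e.g.
# {"precision": (0.91, 0.02), "recall": (0.88, 0.03), "accuracy": (...), "f1": (...)}
mean_f1, std_f1 = results["f1"]
print(f"F1 across seeds: {mean_f1:.3f} +/- {std_f1:.3f}")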