From 8e134b8cb92e3c624b23d4d10c5d4596bb5b9d9b Mon Sep 17 00:00:00 2001 From: elitap Date: Wed, 22 Nov 2023 06:06:17 +0100 Subject: [PATCH] =?UTF-8?q?add=20class=20label=20option=20to=20write=20met?= =?UTF-8?q?ric=20report=20to=20improve=20readability=20=E2=80=A6=20(#7249)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add class label option to write metric report to improve readability, without that option in case of many classes the resulting report is very hard to interpret. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: elitap Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/handlers/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 58a3fd36f3..0cd31b89c2 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -61,6 +61,7 @@ def write_metrics_reports( summary_ops: str | Sequence[str] | None, deli: str = ",", output_type: str = "csv", + class_labels: list[str] | None = None, ) -> None: """ Utility function to write the metrics into files, contains 3 parts: @@ -94,6 +95,8 @@ class mean median max 5percentile 95percentile notnans deli: the delimiter character in the saved file, default to "," as the default output type is `csv`. to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. output_type: expected output file type, supported types: ["csv"], default to "csv". + class_labels: list of class names used to name the classes in the output report, if None, + "class0", ..., "classn" are used, default to None. """ if output_type.lower() != "csv": @@ -118,7 +121,12 @@ class mean median max 5percentile 95percentile notnans v = v.reshape((-1, 1)) # add the average value of all classes to v - class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"] + if class_labels is None: + class_labels = ["class" + str(i) for i in range(v.shape[1])] + else: + class_labels = [str(i) for i in class_labels] # ensure to have a list of str + + class_labels += ["mean"] v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1) with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f: