Skip to content

Commit

Permalink
add class label option to write metric report to improve readability … (
Browse files Browse the repository at this point in the history
#7249)

add class label option to write metric report to improve readability,
without that option in case of many classes the resulting report is very
hard to interpret.


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).

---------

Signed-off-by: elitap <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
elitap and pre-commit-ci[bot] authored Nov 22, 2023
1 parent c300b36 commit 8e134b8
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def write_metrics_reports(
summary_ops: str | Sequence[str] | None,
deli: str = ",",
output_type: str = "csv",
class_labels: list[str] | None = None,
) -> None:
"""
Utility function to write the metrics into files, contains 3 parts:
Expand Down Expand Up @@ -94,6 +95,8 @@ class mean median max 5percentile 95percentile notnans
deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
output_type: expected output file type, supported types: ["csv"], default to "csv".
class_labels: list of class names used to name the classes in the output report, if None,
"class0", ..., "classn" are used, default to None.
"""
if output_type.lower() != "csv":
Expand All @@ -118,7 +121,12 @@ class mean median max 5percentile 95percentile notnans
v = v.reshape((-1, 1))

# add the average value of all classes to v
class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"]
if class_labels is None:
class_labels = ["class" + str(i) for i in range(v.shape[1])]
else:
class_labels = [str(i) for i in class_labels] # ensure to have a list of str

class_labels += ["mean"]
v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1)

with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f:
Expand Down

0 comments on commit 8e134b8

Please sign in to comment.