From 9a84df543e730199a93dbb3538318d0ae7592463 Mon Sep 17 00:00:00 2001 From: Joshua Talks Date: Thu, 25 Apr 2024 14:53:30 +0200 Subject: [PATCH] fix overwriting of prediction for multiple thresholds --- pytorch3dunet/unet3d/metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch3dunet/unet3d/metrics.py b/pytorch3dunet/unet3d/metrics.py index 2b60b4b7..6764eec9 100644 --- a/pytorch3dunet/unet3d/metrics.py +++ b/pytorch3dunet/unet3d/metrics.py @@ -205,16 +205,16 @@ def input_to_segm(self, input): for predictions in input: for th in self.thresholds: # threshold probability maps - predictions = predictions > th + predictions_th = predictions > th if self.invert_pmaps: # for connected component analysis we need to treat boundary signal as background # assign 0-label to boundary mask - predictions = np.logical_not(predictions) + predictions_th = np.logical_not(predictions_th) - predictions = predictions.astype(np.uint8) + predictions_th = predictions_th.astype(np.uint8) # run connected components on the predicted mask; consider only 1-connectivity - seg = measure.label(predictions, background=0, connectivity=1) + seg = measure.label(predictions_th, background=0, connectivity=1) segs.append(seg) return np.stack(segs)