Skip to content

Commit

Permalink
Merge pull request #115 from Josh-Talks/fix-multi-threshold
Browse files Browse the repository at this point in the history
fix overwriting of prediction for multiple thresholds
  • Loading branch information
wolny authored May 27, 2024
2 parents 2eaf45a + 9a84df5 commit 203225f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pytorch3dunet/unet3d/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,16 @@ def input_to_segm(self, input):
for predictions in input:
for th in self.thresholds:
# threshold probability maps
predictions = predictions > th
predictions_th = predictions > th

if self.invert_pmaps:
# for connected component analysis we need to treat boundary signal as background
# assign 0-label to boundary mask
predictions = np.logical_not(predictions)
predictions_th = np.logical_not(predictions_th)

predictions = predictions.astype(np.uint8)
predictions_th = predictions_th.astype(np.uint8)
# run connected components on the predicted mask; consider only 1-connectivity
seg = measure.label(predictions, background=0, connectivity=1)
seg = measure.label(predictions_th, background=0, connectivity=1)
segs.append(seg)

return np.stack(segs)
Expand Down

0 comments on commit 203225f

Please sign in to comment.