Skip to content

Commit

Permalink
DCO Remediation Commit for Zifu Wang <[email protected]>
Browse files Browse the repository at this point in the history
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: 3f74183
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: a778e58
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: aeef0af
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: 58c5396

Signed-off-by: Zifu Wang <[email protected]>
  • Loading branch information
zifuwanggg committed Dec 2, 2024
1 parent f07925e commit 185d2e1
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions monai/losses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ def compute_tp_fp_fn(
decoupled: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
reduce_axis: the axis to be reduced.
ord: the order of the vector norm.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
decoupled: whether the input and the target should be decoupled when computing fp and fn.
Only for the original implementation when soft_label is False.
Adapted from:
https://github.com/zifuwanggg/JDTLosses
"""
Expand All @@ -39,6 +49,8 @@ def compute_tp_fp_fn(
else:
fp = torch.sum(input * (1 - target), dim=reduce_axis)
fn = torch.sum((1 - input) * target, dim=reduce_axis)
# the new implementation that is correct with soft labels
# and it is identical to the original implementation with hard labels
else:
pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
Expand Down

0 comments on commit 185d2e1

Please sign in to comment.