diff --git a/livecellx/core/utils.py b/livecellx/core/utils.py index eeb67fd..71c256f 100644 --- a/livecellx/core/utils.py +++ b/livecellx/core/utils.py @@ -89,7 +89,7 @@ def get_cv2_bbox(label_mask: np.array): return bboxes_cv2 -def label_mask_to_edt_mask(label_mask, bg_val=0): +def label_mask_to_edt_mask(label_mask, bg_val=0, dtype=np.uint8): labels = np.unique(label_mask) # remvoe bg_val labels = labels[labels != bg_val] @@ -101,7 +101,14 @@ def label_mask_to_edt_mask(label_mask, bg_val=0): normalized_mask = normalize_img_to_uint8(tmp_mask) tmp_mask[tmp_mask != bg_val] = normalized_mask[tmp_mask != bg_val] edt_mask += tmp_mask - return edt_mask.astype(np.uint8) + + # TODO: remove the guard below because it is unlikely that we will have a label mask with values > 255, but we should handle this case + # The reason for "unlikely" is that the label mask is usually generated from a binary mask + # And thus the normalization process in the loop above will ensure that the values are in the range [0, 255] + if edt_mask.max() > 255: + edt_mask = normalize_img_to_uint8(edt_mask) + + return edt_mask.astype(dtype) def clip_polygon(polygon, h, w):