From 3b63be8269290f0739981aec4f4e47ecc9752926 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 12 Sep 2024 10:47:52 +0300 Subject: [PATCH] adding max_size_fraction option (#796) --- cellpose/dynamics.py | 19 +++++++++++++------ cellpose/models.py | 21 +++++++++++++-------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index d38e07ba..812d11d9 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -663,7 +663,7 @@ def mem_info(): return masks -def get_masks(p, iscell=None, rpad=20): +def get_masks(p, iscell=None, rpad=20, max_size_fraction=0.4): """Create masks using pixel convergence after running dynamics. Makes a histogram of final pixel locations p, initializes masks @@ -677,6 +677,8 @@ def get_masks(p, iscell=None, rpad=20): iscell (bool, 2D or 3D array): If iscell is not None, set pixels that are iscell False to stay in their original location. rpad (int, optional): Histogram edge padding. Default is 20. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. Returns: M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed, @@ -750,7 +752,7 @@ def get_masks(p, iscell=None, rpad=20): # remove big masks uniq, counts = fastremap.unique(M0, return_counts=True) - big = np.prod(shape0) * 0.4 + big = np.prod(shape0) * max_size_fraction bigc = uniq[counts > big] if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0): M0 = fastremap.mask(M0, bigc) @@ -761,7 +763,7 @@ def get_masks(p, iscell=None, rpad=20): def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, flow_threshold=0.4, interp=True, do_3D=False, min_size=15, - resize=None, device=None): + max_size_fraction=0.4, resize=None, device=None): """Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None. Args: @@ -774,6 +776,8 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. min_size (int, optional): The minimum size of the masks. Defaults to 15. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. resize (tuple, optional): The desired size for resizing the masks. Defaults to None. device (str, optional): The torch device to use for computation. Defaults to None. @@ -783,7 +787,8 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold mask, p = compute_masks(dP, cellprob, p=p, niter=niter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, interp=interp, do_3D=do_3D, - min_size=min_size, device=device) + min_size=min_size, max_size_fraction=max_size_fraction, + device=device) if resize is not None: mask = transforms.resize_image(mask, resize[0], resize[1], @@ -798,7 +803,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, flow_threshold=0.4, interp=True, do_3D=False, min_size=15, - device=None): + max_size_fraction=0.4, device=None): """Compute masks using dynamics from dP and cellprob. Args: @@ -811,6 +816,8 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. min_size (int, optional): The minimum size of the masks. Defaults to 15. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. device (str, optional): The torch device to use for computation. Defaults to None. Returns: @@ -831,7 +838,7 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, return mask, p #calculate masks - mask = get_masks(p, iscell=cp_mask) + mask = get_masks(p, iscell=cp_mask, max_size_fraction=max_size_fraction) # flow thresholding factored out of get_masks if not do_3D: diff --git a/cellpose/models.py b/cellpose/models.py index 8f6b8ff7..32ba223a 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -140,7 +140,7 @@ def __init__(self, gpu=False, model_type="cyto3", nchan=2, device=None, self.sz.model_type = model_type def eval(self, x, batch_size=8, channels=[0, 0], channel_axis=None, invert=False, - normalize=True, diameter=30., do_3D=False, find_masks=True, **kwargs): + normalize=True, diameter=30., do_3D=False, **kwargs): """Run cellpose size model and mask model and get masks. Args: @@ -353,9 +353,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, - stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile=True, - tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, - progress=None): + stitch_threshold=0.0, min_size=15, max_size_fraction=0.4, niter=None, + augment=False, tile=True, tile_overlap=0.1, bsize=224, + interp=True, compute_masks=True, progress=None): """ segment list of images x, or 4D array - Z x nchan x Y x X Args: @@ -394,6 +394,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None. stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0. min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15. + max_size_fraction (float, optional): max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None. augment (bool, optional): tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False. tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True. @@ -435,7 +437,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, tile_overlap=tile_overlap, bsize=bsize, resample=resample, interp=interp, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, compute_masks=compute_masks, - min_size=min_size, stitch_threshold=stitch_threshold, + min_size=min_size, max_size_fraction=max_size_fraction, + stitch_threshold=stitch_threshold, progress=progress, niter=niter) masks.append(maski) flows.append(flowi) @@ -464,7 +467,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, rescale=rescale, resample=resample, augment=augment, tile=tile, tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size, - do_3D=do_3D, anisotropy=anisotropy, niter=niter, + max_size_fraction=max_size_fraction, do_3D=do_3D, anisotropy=anisotropy, niter=niter, stitch_threshold=stitch_threshold) flows = [plot.dx_to_circ(dP), dP, cellprob, p] @@ -473,7 +476,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=None, rescale=1.0, resample=True, augment=False, tile=True, tile_overlap=0.1, cellprob_threshold=0.0, bsize=224, flow_threshold=0.4, min_size=15, - interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0): + max_size_fraction=0.4, interp=True, anisotropy=1.0, do_3D=False, + stitch_threshold=0.0): if isinstance(normalize, dict): normalize_params = {**normalize_default, **normalize} @@ -538,7 +542,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non masks, p = dynamics.resize_and_compute_masks( dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, interp=interp, do_3D=do_3D, - min_size=min_size, resize=None, + min_size=min_size, max_size_fraction=max_size_fraction, resize=None, device=self.device if self.gpu else None) else: masks, p = [], [] @@ -557,6 +561,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non resize=resize, min_size=min_size if stitch_threshold == 0 or nimg == 1 else -1, # turn off for 3D stitching + max_size_fraction=max_size_fraction, device=self.device if self.gpu else None) masks.append(outputs[0]) p.append(outputs[1])