diff --git a/cellpose/core.py b/cellpose/core.py index b48d83cc..678cbf5c 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -277,12 +277,15 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 """ nout = net.nout if imgi.ndim == 4: - Lz, nchan = imgi.shape[:2] - IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi[0], bsize=bsize, - augment=augment, - tile_overlap=tile_overlap) - ny, nx, nchan, ly, lx = IMG.shape - batch_size *= max(4, (bsize**2 // (ly * lx))**0.5) + Lz, nchan, Ly, Lx = imgi.shape + if augment: + ny = max(2, int(np.ceil(2. * Ly / bsize))) + nx = max(2, int(np.ceil(2. * Lx / bsize))) + ly, lx = bsize, bsize + else: + ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize)) + nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize)) + ly, lx = min(bsize, Ly), min(bsize, Lx) yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32) styles = [] if ny * nx > batch_size: @@ -290,34 +293,35 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 if Lz > 1 else range(Lz)) for i in ziterator: yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize, - tile_overlap=tile_overlap) + batch_size=batch_size, tile_overlap=tile_overlap) yf[i] = yfi styles.append(stylei) else: # run multiple slices at the same time ntiles = ny * nx - nimgs = max(2, int(np.round(batch_size / ntiles))) + nimgs = batch_size // ntiles # number of z-slices to run at the same time niter = int(np.ceil(Lz / nimgs)) ziterator = (trange(niter, file=tqdm_out, mininterval=30) if Lz > 1 else range(niter)) for k in ziterator: - IMGa = np.zeros((ntiles * nimgs, nchan, ly, lx), np.float32) - for i in range(min(Lz - k * nimgs, nimgs)): + inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs)) + IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32") + for i, b in enumerate(inds): IMG, ysub, xsub, Ly, Lx = transforms.make_tiles( - imgi[k * nimgs + i], bsize=bsize, augment=augment, + imgi[b], bsize=bsize, augment=augment, tile_overlap=tile_overlap) - IMGa[i * ntiles:(i + 1) * ntiles] = np.reshape( - IMG, (ny * nx, nchan, ly, lx)) + IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG, + (ny * nx, nchan, ly, lx)) ya, stylea = _forward(net, IMGa) - for i in range(min(Lz - k * nimgs, nimgs)): - y = ya[i * ntiles:(i + 1) * ntiles] + for i, b in enumerate(inds): + y = ya[i * ntiles : (i + 1) * ntiles] if augment: y = np.reshape(y, (ny, nx, 3, ly, lx)) y = transforms.unaugment_tiles(y) y = np.reshape(y, (-1, 3, ly, lx)) yfi = transforms.average_tiles(y, ysub, xsub, Ly, Lx) yfi = yfi[:, :imgi.shape[2], :imgi.shape[3]] - yf[k * nimgs + i] = yfi + yf[b] = yfi stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0) stylei /= (stylei**2).sum()**0.5 styles.append(stylei) diff --git a/cellpose/models.py b/cellpose/models.py index 7ca20747..967abd97 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -454,8 +454,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, nchan=self.nchan) if x.ndim < 4: x = x[np.newaxis, ...] - self.batch_size = batch_size - + if diameter is not None and diameter > 0: rescale = self.diam_mean / diameter elif rescale is None: @@ -465,7 +464,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, masks, styles, dP, cellprob, p = self._run_cp( x, compute_masks=compute_masks, normalize=normalize, invert=invert, rescale=rescale, resample=resample, augment=augment, tile=tile, - tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold, + batch_size=batch_size, tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size, max_size_fraction=max_size_fraction, do_3D=do_3D, anisotropy=anisotropy, niter=niter, stitch_threshold=stitch_threshold) @@ -474,7 +473,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, return masks, flows, styles def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=None, - rescale=1.0, resample=True, augment=False, tile=True, tile_overlap=0.1, + rescale=1.0, resample=True, augment=False, tile=True, + batch_size=8, tile_overlap=0.1, cellprob_threshold=0.0, bsize=224, flow_threshold=0.4, min_size=15, max_size_fraction=0.4, interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0): @@ -507,7 +507,8 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non if do_3D: img = np.asarray(x) yf, styles = run_3D(self.net, img, rsz=rescale, anisotropy=anisotropy, - augment=augment, tile=tile, tile_overlap=tile_overlap) + batch_size=batch_size, augment=augment, tile=tile, + tile_overlap=tile_overlap) cellprob = yf[0][-1] + yf[1][-1] + yf[2][-1] dP = np.stack( (yf[1][0] + yf[2][0], yf[0][0] + yf[2][1], yf[0][1] + yf[1][1]), @@ -521,7 +522,8 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non if rescale != 1.0: img = transforms.resize_image(img, rsz=rescale) yf, style = run_net(self.net, img, bsize=bsize, augment=augment, - tile=tile, tile_overlap=tile_overlap) + batch_size=batch_size, tile=tile, + tile_overlap=tile_overlap) if resample: yf = transforms.resize_image(yf, shape[1], shape[2]) dP = np.moveaxis(yf[..., :2], source=-1, destination=0).copy() diff --git a/cellpose/transforms.py b/cellpose/transforms.py index cbb18b05..d72a9f3b 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -114,6 +114,7 @@ def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1): if Lx < bsize: imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2) Ly, Lx = imgi.shape[-2:] + # tiles overlap by half of tile size ny = max(2, int(np.ceil(2. * Ly / bsize))) nx = max(2, int(np.ceil(2. * Lx / bsize)))