Skip to content

Commit

Permalink
fixing bug where batch_size is not passed to core functions (#964)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Sep 12, 2024
1 parent 209a232 commit 8bc3f62
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
36 changes: 20 additions & 16 deletions cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,47 +277,51 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0
"""
nout = net.nout
if imgi.ndim == 4:
Lz, nchan = imgi.shape[:2]
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi[0], bsize=bsize,
augment=augment,
tile_overlap=tile_overlap)
ny, nx, nchan, ly, lx = IMG.shape
batch_size *= max(4, (bsize**2 // (ly * lx))**0.5)
Lz, nchan, Ly, Lx = imgi.shape
if augment:
ny = max(2, int(np.ceil(2. * Ly / bsize)))
nx = max(2, int(np.ceil(2. * Lx / bsize)))
ly, lx = bsize, bsize
else:
ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
ly, lx = min(bsize, Ly), min(bsize, Lx)
yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32)
styles = []
if ny * nx > batch_size:
ziterator = (trange(Lz, file=tqdm_out, mininterval=30)
if Lz > 1 else range(Lz))
for i in ziterator:
yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize,
tile_overlap=tile_overlap)
batch_size=batch_size, tile_overlap=tile_overlap)
yf[i] = yfi
styles.append(stylei)
else:
# run multiple slices at the same time
ntiles = ny * nx
nimgs = max(2, int(np.round(batch_size / ntiles)))
nimgs = batch_size // ntiles # number of z-slices to run at the same time
niter = int(np.ceil(Lz / nimgs))
ziterator = (trange(niter, file=tqdm_out, mininterval=30)
if Lz > 1 else range(niter))
for k in ziterator:
IMGa = np.zeros((ntiles * nimgs, nchan, ly, lx), np.float32)
for i in range(min(Lz - k * nimgs, nimgs)):
inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")
for i, b in enumerate(inds):
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(
imgi[k * nimgs + i], bsize=bsize, augment=augment,
imgi[b], bsize=bsize, augment=augment,
tile_overlap=tile_overlap)
IMGa[i * ntiles:(i + 1) * ntiles] = np.reshape(
IMG, (ny * nx, nchan, ly, lx))
IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG,
(ny * nx, nchan, ly, lx))
ya, stylea = _forward(net, IMGa)
for i in range(min(Lz - k * nimgs, nimgs)):
y = ya[i * ntiles:(i + 1) * ntiles]
for i, b in enumerate(inds):
y = ya[i * ntiles : (i + 1) * ntiles]
if augment:
y = np.reshape(y, (ny, nx, 3, ly, lx))
y = transforms.unaugment_tiles(y)
y = np.reshape(y, (-1, 3, ly, lx))
yfi = transforms.average_tiles(y, ysub, xsub, Ly, Lx)
yfi = yfi[:, :imgi.shape[2], :imgi.shape[3]]
yf[k * nimgs + i] = yfi
yf[b] = yfi
stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
stylei /= (stylei**2).sum()**0.5
styles.append(stylei)
Expand Down
14 changes: 8 additions & 6 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
nchan=self.nchan)
if x.ndim < 4:
x = x[np.newaxis, ...]
self.batch_size = batch_size


if diameter is not None and diameter > 0:
rescale = self.diam_mean / diameter
elif rescale is None:
Expand All @@ -465,7 +464,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
masks, styles, dP, cellprob, p = self._run_cp(
x, compute_masks=compute_masks, normalize=normalize, invert=invert,
rescale=rescale, resample=resample, augment=augment, tile=tile,
tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold,
batch_size=batch_size, tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size,
max_size_fraction=max_size_fraction, do_3D=do_3D, anisotropy=anisotropy, niter=niter,
stitch_threshold=stitch_threshold)
Expand All @@ -474,7 +473,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
return masks, flows, styles

def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=None,
rescale=1.0, resample=True, augment=False, tile=True, tile_overlap=0.1,
rescale=1.0, resample=True, augment=False, tile=True,
batch_size=8, tile_overlap=0.1,
cellprob_threshold=0.0, bsize=224, flow_threshold=0.4, min_size=15,
max_size_fraction=0.4, interp=True, anisotropy=1.0, do_3D=False,
stitch_threshold=0.0):
Expand Down Expand Up @@ -507,7 +507,8 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
if do_3D:
img = np.asarray(x)
yf, styles = run_3D(self.net, img, rsz=rescale, anisotropy=anisotropy,
augment=augment, tile=tile, tile_overlap=tile_overlap)
batch_size=batch_size, augment=augment, tile=tile,
tile_overlap=tile_overlap)
cellprob = yf[0][-1] + yf[1][-1] + yf[2][-1]
dP = np.stack(
(yf[1][0] + yf[2][0], yf[0][0] + yf[2][1], yf[0][1] + yf[1][1]),
Expand All @@ -521,7 +522,8 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
if rescale != 1.0:
img = transforms.resize_image(img, rsz=rescale)
yf, style = run_net(self.net, img, bsize=bsize, augment=augment,
tile=tile, tile_overlap=tile_overlap)
batch_size=batch_size, tile=tile,
tile_overlap=tile_overlap)
if resample:
yf = transforms.resize_image(yf, shape[1], shape[2])
dP = np.moveaxis(yf[..., :2], source=-1, destination=0).copy()
Expand Down
1 change: 1 addition & 0 deletions cellpose/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
if Lx < bsize:
imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2)
Ly, Lx = imgi.shape[-2:]

# tiles overlap by half of tile size
ny = max(2, int(np.ceil(2. * Ly / bsize)))
nx = max(2, int(np.ceil(2. * Lx / bsize)))
Expand Down

0 comments on commit 8bc3f62

Please sign in to comment.