
Commit

Merge pull request #910 from MouseLand/otest
branch with updates for neurips challenge
carsen-stringer authored Apr 7, 2024
2 parents 61c1a94 + 518f92b commit 1b534ce
Showing 19 changed files with 876 additions and 276 deletions.
84 changes: 62 additions & 22 deletions cellpose/__main__.py
@@ -90,6 +90,7 @@ def main():
else:
pretrained_model = args.pretrained_model


restore_type = args.restore_type
if restore_type is not None:
try:
@@ -98,6 +99,15 @@ def main():
raise ValueError("restore_type invalid")
if args.train or args.train_size:
raise ValueError("restore_type cannot be used with training on CLI yet")

if args.transformer and (restore_type is None):
default_model = "transformer_cp3"
backbone = "transformer"
elif args.transformer and restore_type is not None:
raise ValueError("no transformer based restoration")
else:
default_model = "cyto3"
backbone = "default"

model_type = None
if pretrained_model and not os.path.exists(pretrained_model):
@@ -106,13 +116,14 @@ def main():
all_models = models.MODEL_NAMES.copy()
all_models.extend(model_strings)
if ~np.any([model_type == s for s in all_models]):
model_type = "cyto"
logger.warning("pretrained model has incorrect path")
model_type = default_model
logger.warning(f"pretrained model has incorrect path, using {default_model}")
if model_type == "nuclei":
szmean = 17.
else:
szmean = 30.
builtin_size = model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei" or model_type == "cyto3"
builtin_size = (model_type == "cyto" or model_type == "cyto2" or
model_type == "nuclei" or model_type == "cyto3")

if len(args.image_path) > 0 and (args.train or args.train_size):
raise ValueError("ERROR: cannot train model with single image input")
@@ -138,7 +149,8 @@ def main():

# handle built-in model exceptions
if builtin_size and restore_type is None:
model = models.Cellpose(gpu=gpu, device=device, model_type=model_type)
model = models.Cellpose(gpu=gpu, device=device,
model_type=model_type, backbone=backbone)
else:
builtin_size = False
if args.all_channels:
Expand All @@ -147,7 +159,8 @@ def main():
if restore_type is None:
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type)
model_type=model_type,
backbone=backbone)
else:
model = denoise.CellposeDenoiseModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
@@ -221,21 +234,37 @@ def main():
else:

test_dir = None if len(args.test_dir) == 0 else args.test_dir
output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter,
args.look_one_level_down)
images, labels, image_names, test_images, test_labels, image_names_test = output
images, labels, image_names, train_probs = None, None, None, None
test_images, test_labels, image_names_test, test_probs = None, None, None, None
compute_flows = False
if len(args.file_list) > 0:
if os.path.exists(args.file_list):
dat = np.load(args.file_list, allow_pickle=True).item()
image_names = dat["train_files"]
image_names_test = dat.get("test_files", None)
train_probs = dat.get("train_probs", None)
test_probs = dat.get("test_probs", None)
compute_flows = dat.get("compute_flows", False)
load_files = False
else:
logger.critical(f"ERROR: {args.file_list} does not exist")
else:
output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter,
args.look_one_level_down)
images, labels, image_names, test_images, test_labels, image_names_test = output
load_files = True

# training with all channels
if args.all_channels:
img = images[0]
img = images[0] if images is not None else io.imread(image_names[0])
if img.ndim == 3:
nchan = min(img.shape)
elif img.ndim == 2:
nchan = 1
channels = None
else:
nchan = 2

# model path
szmean = args.diam_mean
if not os.path.exists(pretrained_model) and model_type is None:
@@ -252,37 +281,48 @@ def main():

# initialize model
model = models.CellposeModel(
device=device,
device=device, model_type=model_type, diam_mean=szmean, nchan=nchan,
pretrained_model=pretrained_model if model_type is None else None,
model_type=model_type, diam_mean=szmean, nchan=nchan)
backbone=backbone)

# train segmentation model
if args.train:
cpmodel_path = train.train_seg(
model.net, images, labels, train_files=image_names,
test_data=test_images, test_labels=test_labels,
test_files=image_names_test, learning_rate=args.learning_rate,
weight_decay=args.weight_decay, channels=channels,
channel_axis=args.channel_axis,
save_path=os.path.realpath(args.dir), save_every=args.save_every,
test_files=image_names_test,
train_probs=train_probs, test_probs=test_probs,
compute_flows=compute_flows, load_files=load_files,
normalize=(not args.no_norm), channels=channels,
channel_axis=args.channel_axis, rgb=(nchan==3),
learning_rate=args.learning_rate, weight_decay=args.weight_decay,
SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size,
min_train_masks=args.min_train_masks,
min_train_masks=args.min_train_masks,
nimg_per_epoch=args.nimg_per_epoch,
nimg_test_per_epoch=args.nimg_test_per_epoch,
save_path=os.path.realpath(args.dir), save_every=args.save_every,
model_name=args.model_name_out)
model.pretrained_model = cpmodel_path
logger.info(">>>> model trained and saved to %s" % cpmodel_path)

# train size model
if args.train_size:
sz_model = models.SizeModel(cp_model=model, device=device)
masks = [lbl[0] for lbl in labels]
test_masks = [lbl[0] for lbl in test_labels
] if test_labels is not None else test_labels
# data has already been normalized and reshaped
sz_model.params = train.train_size(model.net, model.pretrained_model,
images, masks, test_images,
test_masks, channels=channels,
images, labels, train_files=image_names,
test_data=test_images, test_labels=test_labels,
test_files=image_names_test,
train_probs=train_probs, test_probs=test_probs,
load_files=load_files, channels=channels,
min_train_masks=args.min_train_masks,
channel_axis=args.channel_axis, rgb=(nchan==3),
nimg_per_epoch=args.nimg_per_epoch, normalize=(not args.no_norm),
nimg_test_per_epoch=args.nimg_test_per_epoch,
batch_size=args.batch_size)
if test_images is not None:
test_masks = [lbl[0] for lbl in test_labels
] if test_labels is not None else test_labels
predicted_diams, diams_style = sz_model.eval(
test_images, channels=channels)
ccs = np.corrcoef(
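The hunks above route a new backbone argument into model construction and fall back to the "transformer_cp3" weights when --transformer is passed without --restore_type. As a rough Python-API sketch only (parameter names are taken from this diff, and the model name is assumed to resolve to downloadable weights), the CLI path corresponds roughly to:

import numpy as np
from cellpose import models

# Sketch: mirrors what __main__.py now does for --transformer without --restore_type.
# backbone="transformer" and model_type="transformer_cp3" come from the hunks above.
model = models.CellposeModel(gpu=True, model_type="transformer_cp3",
                             backbone="transformer")

img = np.random.rand(256, 256).astype("float32")  # placeholder grayscale image
masks, flows, styles = model.eval(img, channels=[0, 0], diameter=30.)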
14 changes: 12 additions & 2 deletions cellpose/cli.py
@@ -63,7 +63,7 @@ def get_arg_parser():
input_img_args.add_argument(
"--all_channels", action="store_true", help=
"use all channels in image if using own model and images with special channels")

# model settings
model_args = parser.add_argument_group("Model Arguments")
model_args.add_argument("--pretrained_model", required=False, default="cyto",
@@ -77,7 +77,9 @@ def get_arg_parser():
model_args.add_argument(
"--add_model", required=False, default=None, type=str,
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")

model_args.add_argument("--transformer", action="store_true",
help="use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")

# algorithm settings
algorithm_args = parser.add_argument_group("Algorithm Arguments")
algorithm_args.add_argument(
@@ -171,6 +173,10 @@ def get_arg_parser():
help="train size network at end of training")
training_args.add_argument("--test_dir", default=[], type=str,
help="folder containing test data (optional)")
training_args.add_argument(
"--file_list", default=[], type=str, help=
"path to list of files for training and testing and probabilities for each image (optional)"
)
training_args.add_argument(
"--mask_filter", default="_masks", type=str, help=
"end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
@@ -187,6 +193,10 @@ def get_arg_parser():
help="number of epochs. Default: %(default)s")
training_args.add_argument("--batch_size", default=8, type=int,
help="batch size. Default: %(default)s")
training_args.add_argument("--nimg_per_epoch", default=None, type=int,
help="number of train images per epoch. Default is to use all train images.")
training_args.add_argument("--nimg_test_per_epoch", default=None, type=int,
help="number of test images per epoch. Default is to use all test images.")
training_args.add_argument(
"--min_train_masks", default=5, type=int, help=
"minimum number of masks a training image must have to be used. Default: %(default)s"
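The new --file_list option replaces the directory scan with a saved numpy dictionary; __main__.py reads it with np.load(..., allow_pickle=True).item() and looks up the keys "train_files", "test_files", "train_probs", "test_probs", and "compute_flows". A minimal sketch of writing such a file (paths and probabilities below are placeholders, not from this commit):

import numpy as np

# Keys mirror the dat.get(...) calls in __main__.py; values here are illustrative only.
dat = {
    "train_files": ["/data/train/img000.tif", "/data/train/img001.tif"],
    "test_files": ["/data/test/img000.tif"],
    "train_probs": np.array([0.7, 0.3]),   # per-image sampling probabilities
    "test_probs": np.array([1.0]),
    "compute_flows": False,                # True if *_flows.tif have not been precomputed
}
np.save("file_list.npy", dat)  # np.load("file_list.npy", allow_pickle=True).item() recovers the dict

# then, e.g.: python -m cellpose --train --dir /data/train --file_list file_list.npy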
2 changes: 1 addition & 1 deletion cellpose/core.py
@@ -188,6 +188,7 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,
(faster if augment is False)
Args:
net (class): cellpose network (model.net)
imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
@@ -240,7 +241,6 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,

return y, style


def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0.1):
"""
Run network on tiles of size [bsize x bsize]
69 changes: 34 additions & 35 deletions cellpose/dynamics.py
@@ -74,9 +74,8 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
device = torch.device("cuda")

T = torch.zeros(shape, dtype=torch.double, device=device)

for i in range(n_iter):
T[meds[:, 0], meds[:, 1]] += 1
T[tuple(meds.T)] += 1
Tneigh = T[tuple(neighbors)]
Tneigh *= isneighbor
T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0)
@@ -90,11 +89,11 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
del grads
mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
else:
grads = T[:, pt[1:, :, 0], pt[1:, :, 1], pt[1:, :, 2]]
del pt
dz = grads[:, 0] - grads[:, 1]
dy = grads[:, 2] - grads[:, 3]
dx = grads[:, 4] - grads[:, 5]
grads = T[tuple(neighbors[:,1:])]
del neighbors
dz = grads[0] - grads[1]
dy = grads[2] - grads[3]
dx = grads[4] - grads[5]
del grads
mu_torch = np.stack(
(dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
@@ -161,7 +160,7 @@ def masks_to_flows_gpu(masks, device=None, niter=None):
neighborsY = torch.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), dim=0)
neighborsX = torch.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), dim=0)
neighbors = torch.stack((neighborsY, neighborsX), dim=0)
neighbor_masks = masks_padded[neighbors[0], neighbors[1]]
neighbor_masks = masks_padded[tuple(neighbors)]
isneighbor = neighbor_masks == neighbor_masks[0]

### get center-of-mass within cell
@@ -210,17 +209,17 @@ def masks_to_flows_gpu_3d(masks, device=None):
Lz0, Ly0, Lx0 = masks.shape
Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2

masks_padded = np.zeros((Lz, Ly, Lx), np.int64)
masks_padded[1:-1, 1:-1, 1:-1] = masks

masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))
# get mask pixel neighbors
z, y, x = np.nonzero(masks_padded)
neighborsZ = np.stack((z, z + 1, z - 1, z, z, z, z))
neighborsY = np.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
neighborsX = np.stack((x, x, x, x, x, x + 1, x - 1), axis=0)

neighbors = np.stack((neighborsZ, neighborsY, neighborsX), axis=-1)
z, y, x = torch.nonzero(masks_padded).T
neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)

neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)

# get mask centers
slices = find_objects(masks)

@@ -245,8 +244,7 @@ def masks_to_flows_gpu_3d(masks, device=None):
centers[i, 2] = xmed + sx.start

# get neighbor validator (not all neighbors are in same mask)
neighbor_masks = masks_padded[neighbors[:, :, 0], neighbors[:, :, 1],
neighbors[:, :, 2]]
neighbor_masks = masks_padded[tuple(neighbors)]
isneighbor = neighbor_masks == neighbor_masks[0]
ext = np.array(
[[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
Expand All @@ -262,7 +260,7 @@ def masks_to_flows_gpu_3d(masks, device=None):

# put into original image
mu0 = np.zeros((3, Lz0, Ly0, Lx0))
mu0[:, z - 1, y - 1, x - 1] = mu
mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
mu_c = np.zeros_like(mu0)
return mu0, mu_c

@@ -362,7 +360,8 @@ def masks_to_flows(masks, device=None, niter=None):
raise ValueError("masks_to_flows only takes 2D or 3D arrays")


def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None):
def labels_to_flows(labels, files=None, device=None,
redo_flows=False, niter=None, return_flows=True):
"""Converts labels (list of masks or flows) to flows for training model.
Args:
@@ -384,6 +383,7 @@ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=Non
if labels[0].ndim < 3:
labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]

flows = []
# flows need to be recomputed
if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
dynamics_logger.info("computing flows for labels")
@@ -392,23 +392,22 @@ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=Non
# make sure labels are unique!
labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
iterator = trange if nimg > 1 else range
veci = [
masks_to_flows(labels[n][0].astype(int), device=device, niter=niter)
for n in iterator(nimg)
]

# concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
flows = [
np.concatenate((labels[n], labels[n] > 0.5, veci[n]),
axis=0).astype(np.float32) for n in range(nimg)
]
if files is not None:
for flow, file in zip(flows, files):
file_name = os.path.splitext(file)[0]
for n in iterator(nimg):
labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
vecn = masks_to_flows(labels[n][0].astype(int), device=device, niter=niter)

# concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
axis=0).astype(np.float32)
if files is not None:
file_name = os.path.splitext(files[n])[0]
tifffile.imwrite(file_name + "_flows.tif", flow)
if return_flows:
flows.append(flow)
else:
dynamics_logger.info("flows precomputed")
flows = [labels[n].astype(np.float32) for n in range(nimg)]
if return_flows:
flows = [labels[n].astype(np.float32) for n in range(nimg)]
return flows


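Several of the hunks above swap per-axis indexing such as T[meds[:, 0], meds[:, 1]] for T[tuple(meds.T)] and T[tuple(neighbors)], so the same gather works for both 2D and 3D coordinate stacks. A small self-contained illustration of the idiom (toy shapes, not repository code):

import torch

T = torch.arange(25, dtype=torch.double).reshape(5, 5)

# nine neighbor offsets for three pixels, stacked as (2, 9, npix) like in masks_to_flows_gpu
y = torch.tensor([2, 2, 3])
x = torch.tensor([1, 3, 2])
neighborsY = torch.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), dim=0)
neighborsX = torch.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), dim=0)
neighbors = torch.stack((neighborsY, neighborsX), dim=0)

# T[tuple(neighbors)] is T[neighborsY, neighborsX]; adding a Z row extends it to 3D unchanged
vals = T[tuple(neighbors)]
print(vals.shape)  # torch.Size([9, 3])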
8 changes: 5 additions & 3 deletions cellpose/gui/io.py
@@ -166,9 +166,11 @@ def _initialize_images(parent, image, load_3D=False):
c = np.array(image.shape).argmin()
image = image.transpose(((c + 1) % 3, (c + 2) % 3, c))
elif load_3D:
# assume smallest dimension is Z and put first
z = np.array(image.shape).argmin()
image = image.transpose((z, (z + 1) % 3, (z + 2) % 3))
# assume smallest dimension is Z and put first if <3x max dim
shape = np.array(image.shape)
z = shape.argmin()
if shape[z] < shape.max()/3:
image = image.transpose((z, (z + 1) % 3, (z + 2) % 3))
image = image[..., np.newaxis]
elif image.ndim == 2:
if not load_3D:
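The io.py change above only reorders axes when the smallest dimension is clearly a Z axis (less than a third of the largest), so near-cubic volumes keep their original axis order. A quick standalone check of that heuristic (function name and shapes are mine, for illustration):

import numpy as np

def z_axis_if_clear(shape):
    # mirrors the < max/3 test added above; returns the axis to move first, or None
    shape = np.array(shape)
    z = shape.argmin()
    return int(z) if shape[z] < shape.max() / 3 else None

print(z_axis_if_clear((30, 512, 512)))   # 0 -> treated as a Z-stack and transposed
print(z_axis_if_clear((200, 512, 512)))  # None -> left as-is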
6 changes: 0 additions & 6 deletions cellpose/gui/menus.py
@@ -70,12 +70,6 @@ def mainmenu(parent):
file_menu.addAction(parent.saveFlows)
parent.saveFlows.setEnabled(False)

parent.saveServer = QAction("Send manually labelled data to server", parent)
parent.saveServer.triggered.connect(lambda: save_server(parent))
file_menu.addAction(parent.saveServer)
parent.saveServer.setEnabled(False)


def editmenu(parent):
main_menu = parent.menuBar()
edit_menu = main_menu.addMenu("&Edit")
