From 8c8e9d8a93bb204b0f9787037d1109b87600a8c7 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sun, 7 Apr 2024 09:59:59 -0400 Subject: [PATCH] adding train_size for file sampling --- cellpose/__main__.py | 25 +-- cellpose/core.py | 2 +- cellpose/gui/menus.py | 6 - cellpose/io.py | 41 +--- cellpose/key/cellpose-data-writer.json | 13 -- cellpose/models.py | 103 +++++---- cellpose/train.py | 4 +- cellpose/transforms.py | 1 - paper/3.0/fig_utils.py | 12 -- paper/neurips/analysis.py | 160 ++++++++++++++ paper/neurips/fig_utils.py | 37 ++++ paper/neurips/figures.py | 288 +++++++++++++++++++++++++ setup.py | 1 - tests/test_import.py | 4 +- 14 files changed, 564 insertions(+), 133 deletions(-) delete mode 100644 cellpose/key/cellpose-data-writer.json create mode 100644 paper/neurips/analysis.py create mode 100644 paper/neurips/fig_utils.py create mode 100644 paper/neurips/figures.py diff --git a/cellpose/__main__.py b/cellpose/__main__.py index c48c47c6..5e301513 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -236,14 +236,15 @@ def main(): test_dir = None if len(args.test_dir) == 0 else args.test_dir images, labels, image_names, train_probs = None, None, None, None test_images, test_labels, image_names_test, test_probs = None, None, None, None + compute_flows = False if len(args.file_list) > 0: if os.path.exists(args.file_list): dat = np.load(args.file_list, allow_pickle=True).item() image_names = dat["train_files"] - image_names_test = dat["test_files"] - if "train_probs" in dat: - train_probs = dat["train_probs"] - test_probs = dat["test_probs"] + image_names_test = dat.get("test_files", None) + train_probs = dat.get("train_probs", None) + test_probs = dat.get("test_probs", None) + compute_flows = dat.get("compute_flows", False) load_files = False else: logger.critical(f"ERROR: {args.file_list} does not exist") @@ -263,7 +264,7 @@ def main(): channels = None else: nchan = 2 - + # model path szmean = args.diam_mean if not 
os.path.exists(pretrained_model) and model_type is None: @@ -291,14 +292,15 @@ def main(): test_data=test_images, test_labels=test_labels, test_files=image_names_test, train_probs=train_probs, test_probs=test_probs, - load_files=load_files, learning_rate=args.learning_rate, - weight_decay=args.weight_decay, channels=channels, + compute_flows=compute_flows, load_files=load_files, + normalize=(not args.no_norm), channels=channels, channel_axis=args.channel_axis, rgb=(nchan==3), - save_path=os.path.realpath(args.dir), save_every=args.save_every, + learning_rate=args.learning_rate, weight_decay=args.weight_decay, SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size, min_train_masks=args.min_train_masks, - nimg_per_epoch=args.nimg_per_epoch, normalize=(not args.no_norm), + nimg_per_epoch=args.nimg_per_epoch, nimg_test_per_epoch=args.nimg_test_per_epoch, + save_path=os.path.realpath(args.dir), save_every=args.save_every, model_name=args.model_name_out) model.pretrained_model = cpmodel_path logger.info(">>>> model trained and saved to %s" % cpmodel_path) @@ -306,9 +308,6 @@ def main(): # train size model if args.train_size: sz_model = models.SizeModel(cp_model=model, device=device) - masks = [lbl[0] for lbl in labels] - test_masks = [lbl[0] for lbl in test_labels - ] if test_labels is not None else test_labels # data has already been normalized and reshaped sz_model.params = train.train_size(model.net, model.pretrained_model, images, labels, train_files=image_names, @@ -322,6 +321,8 @@ def main(): nimg_test_per_epoch=args.nimg_test_per_epoch, batch_size=args.batch_size) if test_images is not None: + test_masks = [lbl[0] for lbl in test_labels + ] if test_labels is not None else test_labels predicted_diams, diams_style = sz_model.eval( test_images, channels=channels) ccs = np.corrcoef( diff --git a/cellpose/core.py b/cellpose/core.py index 822b8eb3..746c1d38 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -188,6 +188,7 @@ def run_net(net, imgs, 
batch_size=8, augment=False, tile=True, tile_overlap=0.1, (faster if augment is False) Args: + net (class): cellpose network (model.net) imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]. batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8. rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0. @@ -240,7 +241,6 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1, return y, style - def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0.1): """ Run network on tiles of size [bsize x bsize] diff --git a/cellpose/gui/menus.py b/cellpose/gui/menus.py index 202dbacf..36ae7f14 100644 --- a/cellpose/gui/menus.py +++ b/cellpose/gui/menus.py @@ -70,12 +70,6 @@ def mainmenu(parent): file_menu.addAction(parent.saveFlows) parent.saveFlows.setEnabled(False) - parent.saveServer = QAction("Send manually labelled data to server", parent) - parent.saveServer.triggered.connect(lambda: save_server(parent)) - file_menu.addAction(parent.saveServer) - parent.saveServer.setEnabled(False) - - def editmenu(parent): main_menu = parent.menuBar() edit_menu = main_menu.addMenu("&Edit") diff --git a/cellpose/io.py b/cellpose/io.py index 2abcd6b4..776d1cfd 100644 --- a/cellpose/io.py +++ b/cellpose/io.py @@ -744,43 +744,4 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[ imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"), (flows[0] * (2**16 - 1)).astype(np.uint16)) #save full flow data - imsave(os.path.join(flowdir, basename + "_dP" + suffix + ".tif"), flows[1]) - - -def save_server(parent=None, filename=None): - """ Uploads a *_seg.npy file to the bucket. - - Args: - parent (PyQt.MainWindow, optional): GUI window to grab file info from. Defaults to None. - filename (str, optional): if no GUI, send this file to server. Defaults to None. 
- """ - if parent is not None: - q = QMessageBox.question( - parent, "Send to server", - "Are you sure? Only send complete and fully manually segmented data.\n (do not send partially automated segmentations)", - QMessageBox.Yes | QMessageBox.No) - if q != QMessageBox.Yes: - return - else: - filename = parent.filename - - if filename is not None: - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "key/cellpose-data-writer.json") - bucket_name = "cellpose_data" - base = os.path.splitext(filename)[0] - source_file_name = base + "_seg.npy" - io_logger.info(f"sending {source_file_name} to server") - time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S.%f") - filestring = time + ".npy" - io_logger.info(f"name on server: {filestring}") - destination_blob_name = filestring - storage_client = storage.Client() - bucket = storage_client.bucket(bucket_name) - blob = bucket.blob(destination_blob_name) - - blob.upload_from_filename(source_file_name) - - io_logger.info("File {} uploaded to {}.".format(source_file_name, - destination_blob_name)) + imsave(os.path.join(flowdir, basename + "_dP" + suffix + ".tif"), flows[1]) \ No newline at end of file diff --git a/cellpose/key/cellpose-data-writer.json b/cellpose/key/cellpose-data-writer.json deleted file mode 100644 index 3feaf95b..00000000 --- a/cellpose/key/cellpose-data-writer.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "type": "service_account", - "project_id": "pachitariu-lab", - "private_key_id": "a112e329e705a5c51bfbba5277f52116e4bfc1af", - "private_key": "-----BEGIN PRIVATE 
KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCsmcCIGjXqyWsl\n++lxWxzSL5QWQTPkIyP+j99gFJQ1c/Z4PYC7YrtIdfu/PsLwqBAvOHJCsrq0m+y0\nB2vqUJJ2mcvKQaWygAWKwoikFUI1AtjkH9GnSHxXmTH9WqXBw2pQv1xBfOvoMptJ\n8ArXyAMyBzSYN0xyOqEQEsrSzqEUvZw467YZmyf24ZhjFXze0Rrn1znkQprMB8GK\nPz7Vr6iiRPd8mtjzjS2WC0Fat+RIhh0o5lb93woGoFbWAhvuK8xadGOYP80Mcr/b\nsgFM1D2k/HHdjJiHRe7jwRMIRhA5478QwoBNN00W/zEgDV+8SSl9rs3yUAUJG9mf\nEHkdlzexAgMBAAECggEAUFbMyE0y9ZNFfYuxUGMxmiAtVOKKrdExiuca+VT625qb\nicJO7mn5dLP+NzmWcYA48FHc1XDt+O1vEyk1MP7J/cx+kClYYCq46aq9AWsnwxcN\nL7oj0zKpNfkHzL7p0rQMA4PfBFiKUi1kHNlPorrlyd6Su5tZyP3DRIEKyW8GiWko\n19DEcBUhe2uEW+claFSiy6fYfXFXtzYln8mWKAWjOxw6LcQBc6KRMdYh79d09/bj\nthnnNeMLK6FSiKTXuT/a84qzxNkj549H0ILwolVKn1vPe35WV7ZJJpugYvnbek5s\nZjwN2fMUa/GD9agbfS4LXbnC4hYwUnLhAn2zSEjlBQKBgQDWxwNoo0CuEB8s+VCN\n9J80vqTQfXSnSqoAS0BaDaivGoNUHtRGjXs+Qqrzlikk9032s9eRtX+q7uhR7wIo\nIilO1GhVCr+OYPWTBoJ/ARgDrhqgXPjyQ97wSDz9PHloFnGHf+HFcqzhlhNGaoS9\nQaZO3nhx3Du7YV8kaefm0iODnwKBgQDNumbAi0L9jTDguOtD/3w32MuodQfL35Ok\nRvdkQphZL2jFqiztlwVU11SOibnVc7ZAjFt4tPSS0veFQmC8JfnpdYSDk1NodwMG\nhC3UTyjrvIltEmPzSQczRMK5PCRZnvP6SBKnUSEsAk+bZO1nyFNPNNmn5J/LQ+24\n2MGAgxcCrwKBgB0pFBtm3udDJRh0GS3M4rjEkZgFEIuOJZq4nNodNKPhk6ceMHAL\n0YnYf2FnJ9rvANTYAhK0c8r/eOd27fIJAVbEnA2/0dZA79awcZNQ0LPfNZpERUCP\nWnuBM1ammU06jtt4z2yBb1uJhsBuwer4ON5Ick3zOuDsDYDiKCw8p7m9AoGBAIpp\nuAohaA/pN5JqN7eHI877KIKNQpKTOOVU7cth1thiQk6DITk021xqh7Riy0nmUR96\nj2xV6xsBn5DjyOutbUf6Tg6sR3jIYZu3wJHQNIruTVO6BM9BOfvvbkdsRFSb0jB4\n3zv9JKFUaLT3IZcqu4pV137THgOHD2DHTOEm0Yt3AoGBALMOKoajLFTBT+9bEpXZ\nB6ES4W1KNPKw+1Y0n2gGJxL+zAXFm1MJozJbyZSwZHb8TyOMNqP59+tzh080oUoc\nlMWJzS4xxuGx+JAULtDT4ko3+3Q19H3dyJNAW9SoY1lX47JMrEB1qYLNx7o78nzB\niCWxdweJjlpKijcUP9keCmIW\n-----END PRIVATE KEY-----\n", - "client_email": "cellpose-data-writer@pachitariu-lab.iam.gserviceaccount.com", - "client_id": "105326682635824364397", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": 
"https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/cellpose-data-writer%40pachitariu-lab.iam.gserviceaccount.com" -} - diff --git a/cellpose/models.py b/cellpose/models.py index c77a740f..3cff3f49 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -25,7 +25,8 @@ MODEL_NAMES = [ "cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto", - "transformer_cp3" + "transformer_cp3", "neurips_cellpose_default", "neurips_cellpose_transformer", + "neurips_grayscale_cyto2" ] MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt")) @@ -43,16 +44,13 @@ } def model_path(model_type, model_index=0): - if not os.path.exists(model_type): - torch_str = "torch" - if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei": - basename = "%s%s_%d" % (model_type, torch_str, model_index) - else: - basename = model_type - return cache_model_path(basename) + torch_str = "torch" + if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei": + basename = "%s%s_%d" % (model_type, torch_str, model_index) else: - return model_type - + basename = model_type + return cache_model_path(basename) + def size_model_path(model_type): if os.path.exists(model_type): return model_type + "_size.npy" @@ -139,7 +137,7 @@ def __init__(self, gpu=False, model_type="cyto3", nchan=2, self.sz.model_type = model_type def eval(self, x, batch_size=8, channels=[0,0], channel_axis=None, invert=False, - normalize=True, diameter=30., do_3D=False, **kwargs): + normalize=True, diameter=30., do_3D=False, find_masks=True, **kwargs): """Run cellpose size model and mask model and get masks. 
Args: @@ -197,14 +195,12 @@ def eval(self, x, batch_size=8, channels=[0,0], channel_axis=None, invert=False, else: diams = diameter - tic = time.time() models_logger.info("~~~ FINDING MASKS ~~~") masks, flows, styles = self.cp.eval(x, channels=channels, channel_axis=channel_axis, batch_size=batch_size, normalize=normalize, invert=invert, diameter=diams, do_3D=do_3D, **kwargs) - models_logger.info(">>>> TOTAL TIME %0.2f sec" % (time.time() - tic0)) return masks, flows, styles, diams @@ -250,37 +246,47 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, nchan (int, optional): Number of channels to use as input to network, default is 2 (cyto + nuclei) or (nuclei + zeros). """ self.diam_mean = diam_mean - builtin = True + + ### set model path default_model = "cyto3" if backbone=="default" else "transformer_cp3" - if model_type is not None or (pretrained_model and - not os.path.exists(pretrained_model)): - pretrained_model_string = model_type if model_type is not None else default_model - model_strings = get_user_models() - all_models = MODEL_NAMES.copy() - all_models.extend(model_strings) - if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]): - builtin = False - if (not os.path.exists(pretrained_model_string) and - ~np.any([pretrained_model_string == s for s in all_models])): - pretrained_model_string = default_model - models_logger.warning("model_type does not exist / has incorrect path") - - if (pretrained_model and not os.path.exists(pretrained_model)): - models_logger.warning("pretrained model has incorrect path") - models_logger.info(f">> {pretrained_model_string} << model set to be used") - - if pretrained_model_string == "nuclei": + builtin = False + use_default = False + model_strings = get_user_models() + all_models = MODEL_NAMES.copy() + all_models.extend(model_strings) + + # check if pretrained_model is builtin or custom user model saved in .cellpose/models + # if yes, then set to model_type + if (pretrained_model and not 
Path(pretrained_model).exists() and + np.any([pretrained_model == s for s in all_models])): + model_type = pretrained_model + + # check if model_type is builtin or custom user model saved in .cellpose/models + if model_type is not None and np.any([model_type == s for s in all_models]): + if np.any([model_type == s for s in MODEL_NAMES]): + builtin = True + models_logger.info(f">> {model_type} << model set to be used") + if model_type == "nuclei": self.diam_mean = 17. + pretrained_model = model_path(model_type) + # if model_type is not None and does not exist, use default model + elif model_type is not None: + if Path(model_type).exists(): + pretrained_model = model_type else: - self.diam_mean = 30. - pretrained_model = model_path(pretrained_model_string) + models_logger.warning("model_type does not exist, using default model") + use_default = True + # if model_type is None... else: - builtin = False - if pretrained_model: - pretrained_model_string = pretrained_model - models_logger.info(f">>>> loading model {pretrained_model_string}") - - # assign network device + # if pretrained_model does not exist, use default model + if pretrained_model and not Path(pretrained_model).exists(): + models_logger.warning("pretrained_model path does not exist, using default model") + use_default = True + + builtin = True if use_default else builtin + self.pretrained_model = model_path(default_model) if use_default else pretrained_model + + ### assign model device self.mkldnn = None if device is None: sdevice, gpu = assign_device(use_torch=True, gpu=gpu) @@ -291,12 +297,11 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, if not self.gpu: self.mkldnn = check_mkl(True) - # create network + ### create neural network self.nchan = nchan self.nclasses = 3 nbase = [32, 64, 128, 256] self.nbase = [nchan, *nbase] - self.pretrained_model = pretrained_model if backbone=="default": self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, @@ -306,7 +311,9 @@ 
def __init__(self, gpu=False, pretrained_model=False, model_type=None, self.net = Transformer(encoder_weights="imagenet" if not self.pretrained_model else None, diam_mean=diam_mean).to(self.device) + ### load model weights if self.pretrained_model: + models_logger.info(f">>>> loading model {pretrained_model}") self.net.load_model(self.pretrained_model, device=self.device) if not builtin: self.diam_mean = self.net.diam_mean.data.cpu().numpy()[0] @@ -318,6 +325,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, models_logger.info( f">>>> model diam_labels = {self.diam_labels: .3f} (mean diameter of training ROIs)" ) + else: + models_logger.info(f">>>> no model weights loaded") + self.diam_labels = self.diam_mean self.net_type = f"cellpose_{backbone}" @@ -382,12 +392,14 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, """ if isinstance(x, list) or x.squeeze().ndim == 5: + self.timing = [] masks, styles, flows = [], [], [] tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) nimg = len(x) iterator = trange(nimg, file=tqdm_out, mininterval=30) if nimg > 1 else range(nimg) for i in iterator: + tic = time.time() maski, flowi, stylei = self.eval( x[i], batch_size=batch_size, channels=channels[i] if channels is not None and @@ -409,6 +421,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, masks.append(maski) flows.append(flowi) styles.append(stylei) + self.timing.append(time.time() - tic) return masks, flows, styles else: @@ -499,7 +512,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non if rescale != 1.0: img = transforms.resize_image(img, rsz=rescale) yf, style = run_net(self.net, img, bsize=bsize, augment=augment, - tile=tile, tile_overlap=tile_overlap) + tile=tile, tile_overlap=tile_overlap) if resample: yf = transforms.resize_image(yf, shape[1], shape[2]) @@ -661,14 +674,15 @@ def eval(self, x, channels=None, channel_axis=None, 
normalize=True, invert=False - diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps. - diam_style (np.ndarray): Estimated diameters from style alone. """ - if isinstance(x, list): + self.timing = [] diams, diams_style = [], [] nimg = len(x) tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) iterator = trange(nimg, file=tqdm_out, mininterval=30) if nimg > 1 else range(nimg) for i in iterator: + tic = time.time() diam, diam_style = self.eval( x[i], channels=channels[i] if (channels is not None and len(channels) == len(x) and @@ -679,6 +693,7 @@ def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False batch_size=batch_size, progress=progress) diams.append(diam) diams_style.append(diam_style) + self.timing.append(time.time() - tic) return diams, diams_style diff --git a/cellpose/train.py b/cellpose/train.py index aa70d656..8c86ebb8 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -192,7 +192,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files ] train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)] - if test_data is not None or test_files is not None and test_labels_files is None: + if (test_data is not None or test_files is not None) and test_labels_files is None: test_labels_files = [ os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files ] @@ -415,7 +415,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, LR = LR[:-50] for i in range(10): LR = np.append(LR, LR[-1] / 2 * np.ones(5)) - n_epochs = len(LR) + LR = LR train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}") if not SGD: diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 516b5a06..402d0eed 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -38,7 +38,6 @@ def _taper_mask(ly=224, lx=224, sig=7.5): bsize // 2 
- lx // 2:bsize // 2 + lx // 2 + lx % 2] return mask - def unaugment_tiles(y): """Reverse test-time augmentations for averaging (includes flipping of flowsY and flowsX). diff --git a/paper/3.0/fig_utils.py b/paper/3.0/fig_utils.py index 474a8c84..3dc0882d 100644 --- a/paper/3.0/fig_utils.py +++ b/paper/3.0/fig_utils.py @@ -10,17 +10,6 @@ from matplotlib.colors import ListedColormap from cellpose import utils -cmap_emb = ListedColormap(plt.get_cmap("gist_ncar")(np.linspace(0.05, 0.95), 100)) - -kp_colors = np.array([ - [0.55, 0.55, 0.55], - [0., 0., 1], - [0.8, 0, 0], - [1., 0.4, 0.2], - [0, 0.6, 0.4], - [0.2, 1, 0.5], -]) - default_font = 12 rcParams["font.family"] = "Arial" rcParams["savefig.dpi"] = 300 @@ -34,7 +23,6 @@ fs_title = 16 weight_title = "normal" - def plot_label(ltr, il, ax, trans, fs_title=20): ax.text( 0.0, diff --git a/paper/neurips/analysis.py b/paper/neurips/analysis.py new file mode 100644 index 00000000..7d55c446 --- /dev/null +++ b/paper/neurips/analysis.py @@ -0,0 +1,160 @@ +import os +import numpy as np +from cellpose import io, transforms, utils, models, dynamics, metrics, resnet_torch, denoise +from natsort import natsorted +from pathlib import Path +from glob import glob + +from cellpose.io import logger_setup + +def prediction_test_hidden(root): + """ root is path to Hidden folder """ + root = Path(root) + + logger_setup() + # path to images + fall = natsorted(glob((root / "images" / "*").as_posix())) + img_files = [f for f in fall if "_masks" not in f and "_flows" not in f] + + # load images + imgs = [io.imread(f) for f in img_files] + nimg = len(imgs) + + # for 3 channel model, normalize images and convert to 3 channels if needed + imgs_norm = [] + for img in imgs: + if img.ndim==2: + img = np.tile(img[:,:,np.newaxis], (1,1,3)) + img = transforms.normalize_img(img, axis=-1) + imgs_norm.append(img.transpose(2,0,1)) + + dat = {} + for mtype in ["default", "transformer"]: + if mtype=="default": + model = models.Cellpose(gpu=True, nchan=3, 
model_type="neurips_cellpose_default") + channels = None + normalize = False + diams = None # Cellpose will estimate diameter + elif mtype=="transformer": + model = models.CellposeModel(gpu=True, nchan=3, model_type="neurips_cellpose_transformer", backbone="transformer") + channels = None + normalize = False + diams = dat["diams_pred"] # (use diams from Cellpose default model for transformer) + + out = model.eval(imgs_norm, diameter=diams, + channels=channels, normalize=normalize, + tile_overlap=0.6, augment=True) + # predicted masks + dat[mtype] = out[0] + + if mtype=="default": + diams = out[-1] + dat["diams_pred"] = diams + dat[f"{mtype}_size_timing"] = model.sz.timing + dat[f"{mtype}_mask_timing"] = model.cp.timing + else: + dat[f"{mtype}_mask_timing"] = model.timing + + np.savez_compressed("neurips_test_results.npz", dat) + +def prediction_tuning(root, root2=None): + """ root is path to Tuning folder, root2 is path to mediar results """ + root = Path(root) + logger_setup() + + # path to images and masks + fall = natsorted(glob((root / "images" / "*").as_posix())) + # (exclude last image) + img_files = [f for f in fall if "_masks" not in f and "_flows" not in f][:-1] + mask_files = natsorted(glob((root / "labels" / "*").as_posix()))[:-1] + + # load images and masks + imgs = [io.imread(f) for f in img_files] + masks = [io.imread(f) for f in mask_files] + nimg = len(imgs) + + # for 3 channel model, normalize images and convert to 3 channels if needed + imgs_norm = [] + for img in imgs: + if img.ndim==2: + img = np.tile(img[:,:,np.newaxis], (1,1,3)) + img = transforms.normalize_img(img, axis=-1) + imgs_norm.append(img.transpose(2,0,1)) + + dat = {} + + ### RUN MODELS + model_types = ["grayscale", "default", "transformer", "maetal", "mediar"] + for mtype in model_types[:-1]: + print(mtype) + if mtype=="grayscale" or mtype=="maetal": + if mtype=="grayscale": + model = models.CellposeModel(gpu=True, model_type="neurips_grayscale_cyto2") + else: + ### need to download 
cellpose model from Ma et al + # https://github.com/JunMa11/NeurIPS-CellSeg/tree/main/cellpose-omnipose-KIT-GE + pretrained_model = "/home/carsen/Downloads/model.501776_epoch_499" + if not os.path.exists(pretrained_model): + print("need to download cellpose model from Ma et al; https://github.com/JunMa11/NeurIPS-CellSeg/tree/main/cellpose-omnipose-KIT-GE") + print("skipping Ma et al model") + del model_types[-2] + break + model = models.CellposeModel(gpu=True, pretrained_model=pretrained_model) + channels = [0, 0] + normalize = True + diams = None # CellposeModel will use mean diameter from training set + elif mtype=="default": + model = models.Cellpose(gpu=True, nchan=3, model_type="neurips_cellpose_default") + channels = None + normalize = False + diams = None # Cellpose will estimate diameter + elif mtype=="transformer": + model = models.CellposeModel(gpu=True, nchan=3, model_type="neurips_cellpose_transformer", backbone="transformer") + channels = None + normalize = False + diams = dat["diams_pred"] # (use diams from Cellpose default model for transformer) + + out = model.eval(imgs if mtype=="grayscale" else imgs_norm, diameter=diams, + channels=channels, normalize=normalize, + tile_overlap=0.6, augment=True) + if mtype=="default": + diams = out[-1] + dat["diams_pred"] = diams + + dat[mtype] = out[0] + + ### load Mediar results + if root2 is not None: + root2 = Path(root2) + masks_pred_mediar = [] + for imgf in img_files: + maskf = root2 / (os.path.splitext(os.path.split(imgf)[-1])[0] + "_label.tiff") + m = io.imread(maskf) + m = np.unique(m, return_inverse=True)[1].reshape(m.shape) + masks_pred_mediar.append(m) + + dat["mediar"] = masks_pred_mediar + else: + print("no path to mediar files specified") + print("skipping mediar") + del model_types[-1] + + ### EVALUATION + thresholds = np.arange(0.5, 1.05, 0.05) + dat["thresholds"] = thresholds + masks_true = [lbl.astype("uint32") for lbl in masks] + for mtype in model_types: + print(mtype) + masks_pred = 
dat[mtype] + ap, tp, fp, fn = metrics.average_precision(masks_true, masks_pred, threshold=thresholds) + f1 = 2 * tp / (2 * tp + fp + fn) + print(f"{mtype}, F1 score @ 0.5 = {np.median(f1[:,0]):.3f}") + + dat[mtype+"_f1"] = f1 + dat[mtype+"_tp"] = tp + dat[mtype+"_fp"] = fp + dat[mtype+"_fn"] = fn + + np.savez_compressed("neurips_eval_results.npz", dat) + + return imgs_norm, masks, dat \ No newline at end of file diff --git a/paper/neurips/fig_utils.py b/paper/neurips/fig_utils.py new file mode 100644 index 00000000..48070f3d --- /dev/null +++ b/paper/neurips/fig_utils.py @@ -0,0 +1,37 @@ +""" +Copyright © 2024 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" +import string +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.transforms as mtransforms +import numpy as np +from matplotlib import rcParams +from matplotlib.colors import ListedColormap +from cellpose import utils + +default_font = 12 +rcParams["font.family"] = "Arial" +rcParams["savefig.dpi"] = 300 +rcParams["axes.spines.top"] = False +rcParams["axes.spines.right"] = False +rcParams["axes.titlelocation"] = "left" +rcParams["axes.titleweight"] = "normal" +rcParams["font.size"] = default_font + +ltr = string.ascii_lowercase +fs_title = 16 +weight_title = "normal" + +def plot_label(ltr, il, ax, trans, fs_title=20): + ax.text( + 0.0, + 1.0, + ltr[il], + transform=ax.transAxes + trans, + va="bottom", + fontsize=fs_title, + fontweight="bold", + ) + il += 1 + return il \ No newline at end of file diff --git a/paper/neurips/figures.py b/paper/neurips/figures.py new file mode 100644 index 00000000..3d4bca8a --- /dev/null +++ b/paper/neurips/figures.py @@ -0,0 +1,288 @@ + +from fig_utils import * +import matplotlib.patheffects as pe + +def fig1(imgs_norm, masks_true, dat, timings, save_fig=False): + fig = plt.figure(figsize=(14, 5.5)) + thresholds = dat["thresholds"] + grid = plt.GridSpec(2, 5, figure=fig, left=0.05, right=0.98, top=0.94, bottom=0.09, + 
wspace=0.4, hspace=0.6) + il = 0 + transl = mtransforms.ScaledTranslation(-20 / 72, 7 / 72, fig.dpi_scale_trans) + transl1 = mtransforms.ScaledTranslation(-38 / 72, 7 / 72, fig.dpi_scale_trans) + + iex = 54 + img0 = np.clip(imgs_norm[iex].copy().transpose(1,2,0), 0, 1) + xlim = [300, 660] + ylim = [350, 650] + + cols = {"grayscale": [0.5, 0.5, 1], + "maetal": "b", + "default": "g", + "mediar": "r", + "transformer": [0,1,0]} + titles = {"grayscale": "Cellpose (impaired)", + "maetal": "Cellpose (Ma et al)", + "default": "Cellpose (default)", + "mediar": "Mediar", + "transformer": "Cellpose (transformer)"} + + ax = plt.subplot(grid[0,0]) + pos = ax.get_position().bounds + ax.set_position([pos[0] - 0.03, pos[1], pos[2] + 0.035, pos[3]]) + ax.imshow(img0)#, aspect="auto") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + il = plot_label(ltr, il, ax, transl, fs_title) + ax.set_title("Example validation image") + + maskk = masks_true[iex].copy() + outlines_gt = utils.outlines_list(maskk, multiprocessing=False) + + pltmasks = [(0, 1, "maetal"), + (0, 2, "default"), + (1, 0, "mediar"), + (1, 1, "transformer"), + ] + + for k,pltmask in enumerate(pltmasks): + ax = plt.subplot(grid[pltmask[0], pltmask[1]]) + pos = ax.get_position().bounds + ax.set_position([pos[0] - 0.03, pos[1], pos[2] + 0.035, pos[3]]) + il = plot_label(ltr, il, ax, transl, fs_title) + ax.imshow(img0) + maskk = dat[pltmask[2]][iex].copy() + outlines = utils.outlines_list(maskk, multiprocessing=False) + for o in outlines_gt: + ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=1, ls="-") + for o in outlines: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + if k==0: + ax.set_title("Cellpose (Ma et al, 2024)", color=cols[pltmask[2]]) + else: + ax.set_title(titles[pltmask[2]], color=cols[pltmask[2]]) + if k==0: + ax.text(-0.1, -0.1, "ground-truth", color=[0.7, 0.4, 1], transform=ax.transAxes, + ha="left", fontweight="normal", 
fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + ax.text(-0.1, -0.22, "model", color=[1, 1, 0.3], transform=ax.transAxes, + ha="left", fontweight="normal", fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + f1 = dat[pltmask[2]+"_f1"][iex, 0] + ax.text(1, -0.1, f"F1@0.5 = {f1:.2f}", transform=ax.transAxes, ha="right") + + ax = plt.subplot(grid[1,2]) + il = plot_label(ltr, il, ax, transl1, fs_title) + mtypes = ["default", "transformer", "mediar"] + dx = 0.4 + for k, mtype in enumerate(mtypes): + tsec = timings[:,k] + vp = ax.violinplot(tsec, positions=k*np.ones(1), bw_method=0.1, + showextrema=False, showmedians=False)#, quantiles=[[0.25, 0.5, 0.75]]) + ax.plot(dx*np.arange(-1, 2, 2) + k, + np.median(tsec) * np.ones(2), + color=cols[mtype]) + vp["bodies"][0].set_facecolor(cols[mtype]) + ax.text(k+0.2 if k>0 else k-0.1, -1, titles[mtype].replace(" (", "\n("), + color=cols[mtype], rotation=0, + va="top", ha="center") + ax.set_xticklabels([]) + ax.text(-0.1, 1.05, "Test set runtimes", + fontsize="large", transform=ax.transAxes) + ax.set_ylabel("runtime per image (sec.)") + + ax = plt.subplot(grid[:2, 3]) + il = plot_label(ltr, il, ax, transl1, fs_title) + f1s = np.array([[0.8612, 0.8346, 0.7976, 0.7013, 0.4116], + [0.8484, 0.8190, 0.7761, 0.6744, 0.3907], + [0.8263, 0.7903, 0.7371, 0.6063, 0.2911], + ]) + mtypes = ["default", "transformer", "mediar"] + for k, mtype in enumerate(mtypes): + ax.plot(np.arange(0.5, 1, 0.1), f1s[k], color=cols[mtype], lw=3) + ax.text(0.1, 0.5-k*0.13 if k<2 else 0.5-k*0.13+0.05, titles[mtype].replace(" (", "\n("), + color=cols[mtype], fontsize="large", transform=ax.transAxes) + ax.set_ylim([0, 0.9]) + ax.set_xlim([0.49, 0.91]) + ax.set_xticks([0.5, 0.7, 0.9]) + ax.set_xlabel("IoU threshold") + ax.set_ylabel("F1 score") + ax.set_title("Test set results") + + mtypes = ["default", "transformer", "mediar", "maetal", "grayscale"] + dx = 0.3 + stype = "f1" + ax = plt.subplot(grid[0,4]) + 
for k, mtype in enumerate(mtypes): #enumerate(model_types): + score = dat[f"{mtype}_{stype}"][:,0] + vp = ax.violinplot(score, positions=k*np.ones(1), bw_method=0.1, + showextrema=False, showmedians=False)#, quantiles=[[0.25, 0.5, 0.75]]) + ax.plot(dx*np.arange(-1, 2, 2) + k, + np.median(score) * np.ones(2), + color=cols[mtype]) + vp["bodies"][0].set_facecolor(cols[mtype]) + ax.text(k+0.2, -0.06, titles[mtype].replace(" (", "\n("), + color=cols[mtype], rotation=90, + va="top", ha="center") + ax.text(-0.1, 1.05, "Validation set scores", + fontsize="large", transform=ax.transAxes) + il = plot_label(ltr, il, ax, transl1, fs_title) + ax.set_ylabel("F1 score @ 0.5 IoU") + ax.set_xticks(np.arange(len(mtypes))) + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([-0.01, 1.01]) + ax.set_xticklabels([]) + + ax = plt.subplot(grid[1,4]) + for k, mtype in enumerate(mtypes): + ax.errorbar(thresholds, np.median(dat[f"{mtype}_f1"], axis=0), + dat[f"{mtype}_f1"].std(axis=0) / ((dat[f"{mtype}_f1"].shape[0]-1)**0.5), + color=cols[mtype], lw=2, #if mtype=="grayscale" else 1, + ls="--" if mtype=="transformer" else "-", zorder=30 if mtype=="maetal" else 0) + ax.set_ylabel("F1 score") + ax.set_xlabel("IoU threshold") + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([-0.01, 1.01]) + ax.set_xlim([0.49, 1.01]) + ax.set_xticks([0.5, 0.75, 1.0]) + + if save_fig: + fig.savefig("figs/fig1_neurips.pdf", dpi=100) + +def fig2(imgs_norm, masks_true, dat, type_names, types, emb, emb_test, save_fig=False): # figure 2: example validation outlines (top two rows) and four t-SNE scatter panels (bottom rows); saved as a PDF when save_fig is true + ids = [0, 3, 56, 55, 81, 75] # indices of the six example images, one per column of the top two rows + + outlines_gt = [utils.outlines_list(masks_true[iex], multiprocessing=False) for iex in ids] + outlines_cp = [utils.outlines_list(dat["default"][iex], multiprocessing=False) for iex in ids] + outlines_m = [utils.outlines_list(dat["mediar"][iex], multiprocessing=False) for iex in ids] + + fig = plt.figure(figsize=(14,10)) + grid = plt.GridSpec(4, 6, figure=fig, left=0.025, right=0.98, top=0.97, bottom=0.04, + wspace=0.1, hspace=0.2) + il = 0 + transl =
mtransforms.ScaledTranslation(-20 / 72, 14 / 72, fig.dpi_scale_trans) + transl1 = mtransforms.ScaledTranslation(-18 / 72, 7 / 72, fig.dpi_scale_trans) + + ylims = [[0, 500], [1550, 1950], [450, 700], [250, 500], [300, 700], [400, 800]] # per-example y-axis crop, same order as ids + xlims = [[0, 600], [500, 900], [300, 550], [100, 350], [0, 400], [200, 600]] # per-example x-axis crop, same order as ids + + for j in range(len(ids)): + iex = ids[j] + img0 = np.clip(imgs_norm[iex].transpose(1,2,0).copy(), 0, 1) + + ax = plt.subplot(grid[0,j]) + maskk = dat["default"][iex].copy() + ax.imshow(img0) + for o in outlines_gt[j]: + ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2, ls="-", rasterized=True) + for o in outlines_cp[j]: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--", rasterized=True) + ax.set_ylim(ylims[j]) + ax.set_xlim(xlims[j]) + ax.axis("off") + f1 = dat["default_f1"][iex,0] + ax.text(1, -0.1, f"F1@0.5 = {f1:.2f}", transform=ax.transAxes, ha="right") + if j==0: + ax.text(-0.1, 0.5, "Cellpose (default)", rotation=90, va="center", transform=ax.transAxes) + ax.set_title("Example validation images", y=1.07) + il = plot_label(ltr, il, ax, transl, fs_title) + ax.text(-0.1, -0.18, "ground-truth", color=[0.7, 0.4, 1], transform=ax.transAxes, + ha="left", fontweight="normal", fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + ax.text(-0.1, -0.3, "model", color=[1, 1, 0.3], transform=ax.transAxes, + ha="left", fontweight="normal", fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + + ax = plt.subplot(grid[1,j]) + maskk = dat["mediar"][iex].copy() + ax.imshow(img0) + for o in outlines_gt[j]: + ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2, ls="-", rasterized=True) + for o in outlines_m[j]: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--", rasterized=True) + ax.set_ylim(ylims[j]) + ax.set_xlim(xlims[j]) + f1 = dat["mediar_f1"][iex,0] + ax.text(1, -0.1, f"F1@0.5 = {f1:.2f}", transform=ax.transAxes, ha="right") + ax.axis("off") + if j==0: + ax.text(-0.1, 0.5, "Mediar",
rotation=90, va="center", transform=ax.transAxes) + + ax = plt.subplot(grid[2:, :]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1], pos[2]-0.02, pos[3]-0.05]) + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=ax, + wspace=0.15, hspace=0.15) + ax.remove() + cols = plt.get_cmap("tab10")(np.linspace(0, 1, 10)) + + ax = plt.subplot(grid1[0,0]) + cols0 = plt.get_cmap("Paired")(np.linspace(0, 1, 12)) + cols = np.zeros((len(type_names), 4)) + cols[:,-1] = 1 + cols[:2] = cols0[:2] + cols[2] = cols0[3] + cols[4] = cols0[-3] + cols[-2:] = cols0[4:6] + cols[3] = np.array([0,1.,1.,1.]) + cols[6] = cols0[6] + cols[7] = cols0[-1] + irand = np.random.permutation(len(emb)-100) # random draw order for every row of emb except the last 100, which are plotted as "validation" in the next panel + ax.scatter(emb[:-100,1][irand], emb[:-100,0][irand], color=cols[types[:-100]][irand],#, cmap="tab10", + s=1, alpha=0.5, marker="o", rasterized=True, zorder=-10) + new_names = ["Omnipose (fluor.)", "Omnipose (phase)", "Cellpose", "DeepBacs", "Livecell", "Ma et al, 2024", "Nuclei", "Tissuenet", "YeaZ (BF)", "YeaZ (phase)"] + ax.set_title("t-SNE of image style vectors\n(training set)", va="top", y=1.05) + torder = np.array([5, 2, 6, 4, 7, 0, 1, 3, 8, 9]) + for k in range(len(type_names)): + th = (torder==k).nonzero()[0][0] + ax.text(0.9, 0.93-0.045*th, new_names[k], color=cols[k], + transform=ax.transAxes, fontsize="small") + ax.axis("off") + il = plot_label(ltr, il, ax, transl1, fs_title) + + dx = 0.03 + ax = plt.subplot(grid1[0,1]) + pos = ax.get_position().bounds + ax.set_position([pos[0]+dx, pos[1], pos[2], pos[3]]) + ax.scatter(emb[:-100,1], emb[:-100,0], color=0.8*np.ones(3), s=1, rasterized=True) + s1 = ax.scatter(emb[-100:,1], emb[-100:,0], color="k", + s=50, marker="x", alpha=1, lw=0.5, rasterized=True) + s2 = ax.scatter(emb_test[:,1], emb_test[:,0], color="k", facecolors='none', + s=50, marker="o", alpha=1, lw=0.5, rasterized=True) + ax.axis("off") + ax.set_title("Validation and test set\n(Ma et al, 2024)", va="top", y=1.05) + ax.legend([s1, s2],
["validation", "test"], frameon=False, loc="upper left") + il = plot_label(ltr, il, ax, transl1, fs_title) + + ax = plt.subplot(grid1[0,2]) + pos = ax.get_position().bounds + ax.set_position([pos[0]+dx, pos[1], pos[2], pos[3]]) + pos = ax.get_position().bounds + ax.scatter(emb[:-100,1], emb[:-100,0], color=0.8*np.ones(3), s=1, rasterized=True) + im = ax.scatter(emb[-100:,1], emb[-100:,0], c=dat["default_f1"][:,0], lw=2, + s=60, marker="x", alpha=1, cmap="plasma", vmin=0, vmax=1, rasterized=True) + ax.axis("off") + cax = fig.add_axes([pos[0]+pos[2]-0.02, pos[1]+pos[3]-0.12, 0.005, 0.11]) + plt.colorbar(im, cax=cax) + ax.set_title("F1 score for Cellpose (default)") + il = plot_label(ltr, il, ax, transl1, fs_title) + + ax = plt.subplot(grid1[0,3]) + pos = ax.get_position().bounds + ax.set_position([pos[0]+dx, pos[1], pos[2], pos[3]]) + pos = ax.get_position().bounds + ax.scatter(emb[:-100,1], emb[:-100,0], color=0.8*np.ones(3), s=1, rasterized=True) + im = ax.scatter(emb[-100:,1], emb[-100:,0], c=dat["default_f1"][:,0] - dat["mediar_f1"][:,0], lw=2, + s=60, marker="x", alpha=1, cmap="coolwarm", vmin=-0.3, vmax=0.3, rasterized=True) + ax.axis("off") + cax = fig.add_axes([pos[0]+pos[2]-0.02, pos[1]+pos[3]-0.12, 0.005, 0.11]) + plt.colorbar(im, cax=cax) + ax.set_title(r"$\Delta$F1, Cellpose (default) - Mediar") # raw string: "\D" is not a valid escape sequence in a normal string literal + il = plot_label(ltr, il, ax, transl1, fs_title) + + if save_fig: + fig.savefig("figs/fig2_neurips.pdf", dpi=200) diff --git a/setup.py b/setup.py index 18a73d10..2a0ae5f0 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,6 @@ gui_deps = [ 'pyqtgraph>=0.11.0rc0', "pyqt6", "pyqt6.sip", 'qtpy', 'superqt', - 'google-cloud-storage' ] docs_deps = [ diff --git a/tests/test_import.py b/tests/test_import.py index 2cc9f01e..d0526e31 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -7,7 +7,9 @@ def test_cellpose_imports_without_error(): def test_model_zoo_imports_without_error(): from cellpose import models, denoise for model_name in models.MODEL_NAMES: -
model = models.CellposeModel(model_type=model_name) + if "neurips" not in model_name: + model = models.CellposeModel(model_type=model_name, + backbone="transformer" if "transformer" in model_name else "default") def test_gui_imports_without_error():