Skip to content

Commit

Permalink
adding new 3d options to gui and to RTD
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Oct 29, 2024
1 parent 51535c9 commit efcb4b9
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 157 deletions.
16 changes: 1 addition & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,4 @@ Check out [Omnipose](https://github.com/kevinjohncutler/omnipose), an extension
Pytorch is now the default deep neural network software for cellpose. Mxnet will still be supported. To install mxnet (CPU), run `pip install mxnet-mkl`. To use mxnet in a notebook, declare `torch=False` when creating a model, e.g. `model = models.Cellpose(torch=False)`. To use mxnet on the command line, add the flag `--mxnet`, e.g. `python -m cellpose --dir ~/images/ --mxnet`. The pytorch implementation is 20% faster than the mxnet implementation when running on the GPU and 20% slower when running on the CPU.
Dynamics are computed using bilinear interpolation by default instead of nearest neighbor interpolation. Set `interp=False` in `model.eval` to turn off. The bilinear interpolation will be slightly slower on the CPU, but it is faster than nearest neighbor if using torch and the GPU is enabled.
### Timing (v0.6)
You can check if cellpose is running the MKL version (if you are using the CPU not the GPU) by adding the flag `--check_mkl`. If you are not using MKL cellpose will be much slower. Here are Cellpose run times divided into the time it takes to run the deep neural network (DNN) and the time for postprocessing (gradient tracking, segmentation, quality control etc.). The DNN runtime is shown using either a GPU (Nvidia GTX 1080Ti) or a CPU (Intel 10-core 7900X), with or without network ensembling (4net vs 1net). The postprocessing runtime is similar regardless of ensembling or CPU/GPU version. Runtime is shown for different image sizes, all with a cell diameter of 30 pixels (the average from our training set).
| | 256 pix | 512 pix | 1024 pix |
|----|-------|------|----------|
| DNN (1net, GPU) | 0.054 s | 0.12 s | 0.31 s |
| DNN (1net, CPU) | 0.30 s | 0.65 s | 2.4 s |
| DNN (4net, GPU) | 0.23 s | 0.41 s | 1.3 s |
| DNN (4net, CPU) | 1.3 s | 2.5 s | 9.1 s |
| | | | |
| Postprocessing (CPU) | 0.32 s | 1.2 s | 6.1 s |
Dynamics are computed using bilinear interpolation by default instead of nearest neighbor interpolation. Set `interp=False` in `model.eval` to turn off. The bilinear interpolation will be slightly slower on the CPU, but it is faster than nearest neighbor if using torch and the GPU is enabled.
56 changes: 46 additions & 10 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ def __init__(self, image=None, logger=None):

self.load_3D = False
self.stitch_threshold = 0.
self.dP_smooth = 0.
self.anisotropy = 1.
self.min_size = 15
self.resample = True

self.setAcceptDrops(True)
self.win.show()
Expand Down Expand Up @@ -2414,6 +2418,15 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
do_3D = self.load_3D
stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
self.stitch_threshold, float) else self.stitch_threshold
anisotropy = float(self.anisotropy.text()) if not isinstance(
self.anisotropy, float) else self.anisotropy
dP_smooth = float(self.dP_smooth.text()) if not isinstance(
self.dP_smooth, float) else self.dP_smooth
min_size = int(self.min_size.text()) if not isinstance(
self.min_size, int) else self.min_size
resample = self.resample.isChecked() if not isinstance(
self.resample, bool) else self.resample

do_3D = False if stitch_threshold > 0. else do_3D

channels = self.get_channels()
Expand All @@ -2433,6 +2446,8 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
normalize=normalize_params, stitch_threshold=stitch_threshold,
anisotropy=anisotropy, resample=resample, dP_smooth=dP_smooth,
min_size=min_size,
progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2]
except Exception as e:
print("NET ERROR: %s" % e)
Expand All @@ -2452,17 +2467,38 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
else:
flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))

if self.restore and "upsample" in self.restore:
self.Ly, self.Lx = self.Lyr, self.Lxr

if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
self.flows = []
for j in range(len(flows_new)):
self.flows.append(
resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
interpolation=cv2.INTER_NEAREST))
if not self.load_3D:
if self.restore and "upsample" in self.restore:
self.Ly, self.Lx = self.Lyr, self.Lxr

if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
self.flows = []
for j in range(len(flows_new)):
self.flows.append(
resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
interpolation=cv2.INTER_NEAREST))
else:
self.flows = flows_new
else:
self.flows = flows_new
if not resample:
self.flows = []
Lz, Ly, Lx = self.NZ, self.Ly, self.Lx
Lz0, Ly0, Lx0 = flows_new[0].shape[:3]
print("GUI_INFO: resizing flows to original image size")
for j in range(len(flows_new)):
flow0 = flows_new[j]
if Ly0 != Ly:
flow0 = resize_image(flow0, Ly=Ly, Lx=Lx,
no_channels=flow0.ndim==3,
interpolation=cv2.INTER_NEAREST)
if Lz0 != Lz:
flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1),
Ly=Lz, Lx=Lx,
no_channels=flow0.ndim==3,
interpolation=cv2.INTER_NEAREST), 0, 1)
self.flows.append(flow0)
else:
self.flows = flows_new

# add first axis
if self.NZ == 1:
Expand Down
62 changes: 57 additions & 5 deletions cellpose/gui/gui3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,73 @@ def __init__(self, image=None, logger=None):

b = 22

b += 1
label = QLabel("3D stitch threshold:")
label = QLabel("stitch threshold:")
label.setToolTip(
"for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
)
label.setFont(self.medfont)
self.segBoxG.addWidget(label, b, 0, 1, 6)
self.segBoxG.addWidget(label, b, 0, 1, 4)
self.stitch_threshold = QLineEdit()
self.stitch_threshold.setText("0.0")
self.stitch_threshold.setFixedWidth(40)
self.stitch_threshold.setFixedWidth(30)
self.stitch_threshold.setFont(self.medfont)
self.stitch_threshold.setToolTip(
"for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
)
self.segBoxG.addWidget(self.stitch_threshold, b, 7, 1, 2)
self.segBoxG.addWidget(self.stitch_threshold, b, 4, 1, 1)

label = QLabel("dP_smooth:")
label.setToolTip(
"for 3D volumes, smooth flows by a Gaussian with standard deviation dP_smooth (see docs for details)"
)
label.setFont(self.medfont)
self.segBoxG.addWidget(label, b, 5, 1, 3)
self.dP_smooth = QLineEdit()
self.dP_smooth.setText("0.0")
self.dP_smooth.setFixedWidth(30)
self.dP_smooth.setFont(self.medfont)
self.dP_smooth.setToolTip(
"for 3D volumes, smooth flows by a Gaussian with standard deviation dP_smooth (see docs for details)"
)
self.segBoxG.addWidget(self.dP_smooth, b, 8, 1, 1)

b+=1
label = QLabel("anisotropy:")
label.setToolTip(
"for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
)
label.setFont(self.medfont)
self.segBoxG.addWidget(label, b, 0, 1, 4)
self.anisotropy = QLineEdit()
self.anisotropy.setText("1.0")
self.anisotropy.setFixedWidth(30)
self.anisotropy.setFont(self.medfont)
self.anisotropy.setToolTip(
"for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
)
self.segBoxG.addWidget(self.anisotropy, b, 4, 1, 1)

self.resample = QCheckBox("resample")
self.resample.setToolTip("reample before creating masks; if diameter > 30 resample will use more CPU+GPU memory (see docs for more details)")
self.resample.setFont(self.medfont)
self.resample.setChecked(True)
self.segBoxG.addWidget(self.resample, b, 5, 1, 4)

b+=1
label = QLabel("min_size:")
label.setToolTip(
"all masks less than this size in pixels (volume) will be removed"
)
label.setFont(self.medfont)
self.segBoxG.addWidget(label, b, 0, 1, 4)
self.min_size = QLineEdit()
self.min_size.setText("15")
self.min_size.setFixedWidth(50)
self.min_size.setFont(self.medfont)
self.min_size.setToolTip(
"all masks less than this size in pixels (volume) will be removed"
)
self.segBoxG.addWidget(self.min_size, b, 4, 1, 3)

b += 1
self.orthobtn = QCheckBox("ortho")
Expand Down
28 changes: 28 additions & 0 deletions docs/benchmark.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Timing + memory usage
------------------------------------

The algorithm runtime and memory usage increases with the data size. The runtimes
shown below are for a single image run for the first time on an A100 with a batch_size of 32
- this timing includes warm-up of GPU, thus runtimes
will be faster for subsequent images. It will also be faster if you run many images of the same size
input as an array into Cellpose with a large batch_size. The runtimes will also be
slightly faster if you have fewer cells/cell pixels.

.. image:: https://www.cellpose.org/static/images/benchmark_plot.png
:width: 600

Table for 2D:

.. image:: https://www.cellpose.org/static/images/benchmark_2d.png
:width: 400

Table for 3D:

.. image:: https://www.cellpose.org/static/images/benchmark_3d.png
:width: 400

If you are running out of GPU memory for your images, you can reduce the
``batch_size`` parameter in the ``model.eval`` function or in the CLI (default is 8).

If you have even larger images than above, you may want to tile them
before running Cellpose.
134 changes: 134 additions & 0 deletions docs/do3d.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
.. _do3d:

3D segmentation
------------------------------------

Input format
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tiffs with multiple planes and multiple channels are supported in the GUI (can
drag-and-drop tiffs) and supported when running in a notebook.
To open the GUI with z-stack support, use ``python -m cellpose --Zstack``.
Multiplane images should be of shape nplanes x channels x nY x nX or as
nplanes x nY x nX. You can test this by running in python

::

import tifffile
data = tifffile.imread('img.tif')
print(data.shape)

If drag-and-drop of the tiff into
the GUI does not work correctly, then it's likely that the shape of the tiff is
incorrect. If drag-and-drop works (you can see a tiff with multiple planes),
then the GUI will automatically run 3D segmentation and display it in the GUI. Watch
the command line for progress. It is recommended to use a GPU to speed up processing.

In the CLI/notebook, you can specify the ``channel_axis`` and/or ``z_axis``
parameters to specify the axis (0-based) of the image which corresponds to the image channels and to the z axis.
For example an image with 2 channels of shape (1024,1024,2,105,1) can be
specified with ``channel_axis=2`` and ``z_axis=3``. If ``channel_axis=None``
cellpose will try to automatically determine the channel axis by choosing
the dimension with the minimal size after squeezing. If ``z_axis=None``
cellpose will automatically select the first non-channel axis of the image
to be the Z axis. These parameters can be specified using the command line
with ``--channel_axis`` or ``--z_axis`` or as inputs to ``model.eval`` for
the ``Cellpose`` or ``CellposeModel`` model.

Volumetric stacks do not always have the same sampling in XY as they do in Z.
Therefore you can set an ``anisotropy`` parameter in CLI/notebook to allow for differences in
sampling, e.g. set to 2.0 if Z is sampled half as dense as X or Y, and then in the algorithm
Z is upsampled by 2x.

Segmentation settings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The default segmentation in the GUI is 2.5D segmentation, where the flows are computed
on each YX, ZY and ZX slice and then averaged, and then the dynamics are run in 3D.
Specify this segmentation format in the notebook with ``do_3D=True`` or in the CLI with ``--do_3D``
(with the CLI it will segment all tiffs in the folder as 3D tiffs if possible).

If you see many cells that are fragmented, you can smooth the flows before the dynamics
are run in 3D using the ``dP_smooth`` parameter, which specifies the standard deviation of
a Gaussian for smoothing the flows. The default is 0.0, which means no smoothing. Alternatively/additionally,
you may want to train a model on 2D slices from your 3D data to improve the segmentation (see below).

The network rescales images using the user diameter and the model ``diam_mean`` (usually 30),
so for example if you input a diameter of 90 and the model was trained with a diameter of 30,
then the image will be downsampled by a factor of 3 for computing the flows. If ``resample``
is enabled, then the image will then be upsampled for finding the masks. This will take
additional CPU and GPU memory, so for 3D you may want to set ``resample=False`` or in the CLI ``--no_resample``
(more details here :ref:`resample`).

There may be additional differences in YZ and XZ slices
that make them unable to be used for 3D segmentation.
I'd recommend viewing the volume in those dimensions if
the segmentation is failing, using the orthoviews (activate in the bottom left of the GUI).
In those instances, you may want to turn off
3D segmentation (``do_3D=False``) and run instead with ``stitch_threshold>0``.
Cellpose will create ROIs in 2D on each XY slice and then stitch them across
slices if the IoU between the mask on the current slice and the next slice is
greater than or equal to the ``stitch_threshold``. Alternatively, you can train a separate model for
YX slices vs ZY and ZX slices, and then specify the separate model for ZY/ZX slices
using the ``pretrained_model_ortho`` option in ``CellposeModel``.

3D segmentation ignores the ``flow_threshold`` because we did not find that
it helped to filter out false positives in our test 3D cell volume. Instead,
we found that setting ``min_size`` is a good way to remove false positives.

Training for 3D segmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can create image crops from z-stacks (in YX, YZ and XZ) using the script ``cellpose/gui/make_train.py``.
If you have anisotropic volumes, then set the ``--anisotropy`` flag to the ratio between pixel size in Z and in YX,
e.g. set ``--anisotropy 5`` for pixel size of 1.0 um in YX and 5.0 um in Z. Now you can
drag-and-drop an image from the folder into the GUI and start to re-train a model
by labeling your crops and using the ``Train`` option in the GUI (see the
Cellpose2 tutorial for more advice). If the model with all crops
isn't working well, you can alternatively separate the crops
into two folders (YX and ZY/ZX) and train separate networks, and use
``pretrained_model_ortho`` when declaring your model.

See the help message for more information:

::
python cellpose\gui\make_train.py --help
usage: make_train.py [-h] [--dir DIR] [--image_path IMAGE_PATH] [--look_one_level_down] [--img_filter IMG_FILTER]
[--channel_axis CHANNEL_AXIS] [--z_axis Z_AXIS] [--chan CHAN] [--chan2 CHAN2] [--invert]
[--all_channels] [--anisotropy ANISOTROPY] [--sharpen_radius SHARPEN_RADIUS]
[--tile_norm TILE_NORM] [--nimg_per_tif NIMG_PER_TIF] [--crop_size CROP_SIZE]

cellpose parameters

options:
-h, --help show this help message and exit

input image arguments:
--dir DIR folder containing data to run or train on.
--image_path IMAGE_PATH
if given and --dir not given, run on single image instead of folder (cannot train with this
option)
--look_one_level_down
run processing on all subdirectories of current folder
--img_filter IMG_FILTER
end string for images to run on
--channel_axis CHANNEL_AXIS
axis of image which corresponds to image channels
--z_axis Z_AXIS axis of image which corresponds to Z dimension
--chan CHAN channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: 0
--chan2 CHAN2 nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: 0
--invert invert grayscale channel
--all_channels use all channels in image if using own model and images with special channels
--anisotropy ANISOTROPY
anisotropy of volume in 3D

algorithm arguments:
--sharpen_radius SHARPEN_RADIUS
high-pass filtering radius. Default: 0.0
--tile_norm TILE_NORM
tile normalization block size. Default: 0
--nimg_per_tif NIMG_PER_TIF
number of crops in XY to save per tiff. Default: 10
--crop_size CROP_SIZE
size of random crop to save. Default: 512
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Cellpose: a generalist algorithm for cellular segmentation
inputs
settings
outputs
do3d
models
restore
train
Expand Down
Loading

0 comments on commit efcb4b9

Please sign in to comment.