Skip to content

Commit

Permalink
adding train_size for file sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Apr 7, 2024
1 parent 5b2d834 commit 8c8e9d8
Show file tree
Hide file tree
Showing 14 changed files with 564 additions and 133 deletions.
25 changes: 13 additions & 12 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,15 @@ def main():
test_dir = None if len(args.test_dir) == 0 else args.test_dir
images, labels, image_names, train_probs = None, None, None, None
test_images, test_labels, image_names_test, test_probs = None, None, None, None
compute_flows = False
if len(args.file_list) > 0:
if os.path.exists(args.file_list):
dat = np.load(args.file_list, allow_pickle=True).item()
image_names = dat["train_files"]
image_names_test = dat["test_files"]
if "train_probs" in dat:
train_probs = dat["train_probs"]
test_probs = dat["test_probs"]
image_names_test = dat.get("test_files", None)
train_probs = dat.get("train_probs", None)
test_probs = dat.get("test_probs", None)
compute_flows = dat.get("compute_flows", False)
load_files = False
else:
logger.critical(f"ERROR: {args.file_list} does not exist")
Expand All @@ -263,7 +264,7 @@ def main():
channels = None
else:
nchan = 2

# model path
szmean = args.diam_mean
if not os.path.exists(pretrained_model) and model_type is None:
Expand Down Expand Up @@ -291,24 +292,22 @@ def main():
test_data=test_images, test_labels=test_labels,
test_files=image_names_test,
train_probs=train_probs, test_probs=test_probs,
load_files=load_files, learning_rate=args.learning_rate,
weight_decay=args.weight_decay, channels=channels,
compute_flows=compute_flows, load_files=load_files,
normalize=(not args.no_norm), channels=channels,
channel_axis=args.channel_axis, rgb=(nchan==3),
save_path=os.path.realpath(args.dir), save_every=args.save_every,
learning_rate=args.learning_rate, weight_decay=args.weight_decay,
SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size,
min_train_masks=args.min_train_masks,
nimg_per_epoch=args.nimg_per_epoch, normalize=(not args.no_norm),
nimg_per_epoch=args.nimg_per_epoch,
nimg_test_per_epoch=args.nimg_test_per_epoch,
save_path=os.path.realpath(args.dir), save_every=args.save_every,
model_name=args.model_name_out)
model.pretrained_model = cpmodel_path
logger.info(">>>> model trained and saved to %s" % cpmodel_path)

# train size model
if args.train_size:
sz_model = models.SizeModel(cp_model=model, device=device)
masks = [lbl[0] for lbl in labels]
test_masks = [lbl[0] for lbl in test_labels
] if test_labels is not None else test_labels
# data has already been normalized and reshaped
sz_model.params = train.train_size(model.net, model.pretrained_model,
images, labels, train_files=image_names,
Expand All @@ -322,6 +321,8 @@ def main():
nimg_test_per_epoch=args.nimg_test_per_epoch,
batch_size=args.batch_size)
if test_images is not None:
test_masks = [lbl[0] for lbl in test_labels
] if test_labels is not None else test_labels
predicted_diams, diams_style = sz_model.eval(
test_images, channels=channels)
ccs = np.corrcoef(
Expand Down
2 changes: 1 addition & 1 deletion cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,
(faster if augment is False)
Args:
net (class): cellpose network (model.net)
imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
Expand Down Expand Up @@ -240,7 +241,6 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,

return y, style


def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0.1):
"""
Run network on tiles of size [bsize x bsize]
Expand Down
6 changes: 0 additions & 6 deletions cellpose/gui/menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ def mainmenu(parent):
file_menu.addAction(parent.saveFlows)
parent.saveFlows.setEnabled(False)

parent.saveServer = QAction("Send manually labelled data to server", parent)
parent.saveServer.triggered.connect(lambda: save_server(parent))
file_menu.addAction(parent.saveServer)
parent.saveServer.setEnabled(False)


def editmenu(parent):
main_menu = parent.menuBar()
edit_menu = main_menu.addMenu("&Edit")
Expand Down
41 changes: 1 addition & 40 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,43 +744,4 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[
imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"),
(flows[0] * (2**16 - 1)).astype(np.uint16))
#save full flow data
imsave(os.path.join(flowdir, basename + "_dP" + suffix + ".tif"), flows[1])


def save_server(parent=None, filename=None):
    """Upload the ``*_seg.npy`` file belonging to an image to the data bucket.

    When a GUI window is supplied the user is first asked to confirm; on
    confirmation the path is taken from ``parent.filename`` (any ``filename``
    argument is then ignored). Without a GUI window the ``filename`` argument
    is used directly. Nothing happens if no path ends up being available.

    Args:
        parent (PyQt.MainWindow, optional): GUI window to grab file info from. Defaults to None.
        filename (str, optional): if no GUI, send this file to server. Defaults to None.
    """
    if parent is not None:
        answer = QMessageBox.question(
            parent, "Send to server",
            "Are you sure? Only send complete and fully manually segmented data.\n (do not send partially automated segmentations)",
            QMessageBox.Yes | QMessageBox.No)
        if answer != QMessageBox.Yes:
            # user declined: do not upload anything
            return
        filename = parent.filename

    if filename is None:
        return

    # point the storage client at the key file located relative to this module
    module_dir = os.path.dirname(os.path.realpath(__file__))
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join(
        module_dir, "key/cellpose-data-writer.json")

    # the segmentation file sits next to the image: <image stem>_seg.npy
    seg_path = os.path.splitext(filename)[0] + "_seg.npy"
    io_logger.info(f"sending {seg_path} to server")

    # the remote object is named after the upload time, not the local file
    stamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S.%f")
    remote_name = stamp + ".npy"
    io_logger.info(f"name on server: {remote_name}")

    blob = storage.Client().bucket("cellpose_data").blob(remote_name)
    blob.upload_from_filename(seg_path)

    io_logger.info("File {} uploaded to {}.".format(seg_path, remote_name))
imsave(os.path.join(flowdir, basename + "_dP" + suffix + ".tif"), flows[1])
13 changes: 0 additions & 13 deletions cellpose/key/cellpose-data-writer.json

This file was deleted.

103 changes: 59 additions & 44 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
MODEL_NAMES = [
"cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3",
"yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto",
"transformer_cp3"
"transformer_cp3", "neurips_cellpose_default", "neurips_cellpose_transformer",
"neurips_grayscale_cyto2"
]

MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
Expand All @@ -43,16 +44,13 @@
}

def model_path(model_type, model_index=0):
if not os.path.exists(model_type):
torch_str = "torch"
if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei":
basename = "%s%s_%d" % (model_type, torch_str, model_index)
else:
basename = model_type
return cache_model_path(basename)
torch_str = "torch"
if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei":
basename = "%s%s_%d" % (model_type, torch_str, model_index)
else:
return model_type

basename = model_type
return cache_model_path(basename)

def size_model_path(model_type):
if os.path.exists(model_type):
return model_type + "_size.npy"
Expand Down Expand Up @@ -139,7 +137,7 @@ def __init__(self, gpu=False, model_type="cyto3", nchan=2,
self.sz.model_type = model_type

def eval(self, x, batch_size=8, channels=[0,0], channel_axis=None, invert=False,
normalize=True, diameter=30., do_3D=False, **kwargs):
normalize=True, diameter=30., do_3D=False, find_masks=True, **kwargs):
"""Run cellpose size model and mask model and get masks.
Args:
Expand Down Expand Up @@ -197,14 +195,12 @@ def eval(self, x, batch_size=8, channels=[0,0], channel_axis=None, invert=False,
else:
diams = diameter

tic = time.time()
models_logger.info("~~~ FINDING MASKS ~~~")
masks, flows, styles = self.cp.eval(x, channels=channels,
channel_axis=channel_axis,
batch_size=batch_size, normalize=normalize,
invert=invert, diameter=diams, do_3D=do_3D,
**kwargs)

models_logger.info(">>>> TOTAL TIME %0.2f sec" % (time.time() - tic0))

return masks, flows, styles, diams
Expand Down Expand Up @@ -250,37 +246,47 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
nchan (int, optional): Number of channels to use as input to network, default is 2 (cyto + nuclei) or (nuclei + zeros).
"""
self.diam_mean = diam_mean
builtin = True

### set model path
default_model = "cyto3" if backbone=="default" else "transformer_cp3"
if model_type is not None or (pretrained_model and
not os.path.exists(pretrained_model)):
pretrained_model_string = model_type if model_type is not None else default_model
model_strings = get_user_models()
all_models = MODEL_NAMES.copy()
all_models.extend(model_strings)
if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]):
builtin = False
if (not os.path.exists(pretrained_model_string) and
~np.any([pretrained_model_string == s for s in all_models])):
pretrained_model_string = default_model
models_logger.warning("model_type does not exist / has incorrect path")

if (pretrained_model and not os.path.exists(pretrained_model)):
models_logger.warning("pretrained model has incorrect path")
models_logger.info(f">> {pretrained_model_string} << model set to be used")

if pretrained_model_string == "nuclei":
builtin = False
use_default = False
model_strings = get_user_models()
all_models = MODEL_NAMES.copy()
all_models.extend(model_strings)

# check if pretrained_model is builtin or custom user model saved in .cellpose/models
# if yes, then set to model_type
if (pretrained_model and not Path(pretrained_model).exists() and
np.any([pretrained_model == s for s in all_models])):
model_type = pretrained_model

# check if model_type is builtin or custom user model saved in .cellpose/models
if model_type is not None and np.any([model_type == s for s in all_models]):
if np.any([model_type == s for s in MODEL_NAMES]):
builtin = True
models_logger.info(f">> {model_type} << model set to be used")
if model_type == "nuclei":
self.diam_mean = 17.
pretrained_model = model_path(model_type)
# if model_type is not None and does not exist, use default model
elif model_type is not None:
if Path(model_type).exists():
pretrained_model = model_type
else:
self.diam_mean = 30.
pretrained_model = model_path(pretrained_model_string)
models_logger.warning("model_type does not exist, using default model")
use_default = True
# if model_type is None...
else:
builtin = False
if pretrained_model:
pretrained_model_string = pretrained_model
models_logger.info(f">>>> loading model {pretrained_model_string}")

# assign network device
# if pretrained_model does not exist, use default model
if pretrained_model and not Path(pretrained_model).exists():
models_logger.warning("pretrained_model path does not exist, using default model")
use_default = True

builtin = True if use_default else builtin
self.pretrained_model = model_path(default_model) if use_default else pretrained_model

### assign model device
self.mkldnn = None
if device is None:
sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
Expand All @@ -291,12 +297,11 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
if not self.gpu:
self.mkldnn = check_mkl(True)

# create network
### create neural network
self.nchan = nchan
self.nclasses = 3
nbase = [32, 64, 128, 256]
self.nbase = [nchan, *nbase]

self.pretrained_model = pretrained_model
if backbone=="default":
self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn,
Expand All @@ -306,7 +311,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
self.net = Transformer(encoder_weights="imagenet" if not self.pretrained_model else None,
diam_mean=diam_mean).to(self.device)

### load model weights
if self.pretrained_model:
models_logger.info(f">>>> loading model {pretrained_model}")
self.net.load_model(self.pretrained_model, device=self.device)
if not builtin:
self.diam_mean = self.net.diam_mean.data.cpu().numpy()[0]
Expand All @@ -318,6 +325,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
models_logger.info(
f">>>> model diam_labels = {self.diam_labels: .3f} (mean diameter of training ROIs)"
)
else:
models_logger.info(f">>>> no model weights loaded")
self.diam_labels = self.diam_mean

self.net_type = f"cellpose_{backbone}"

Expand Down Expand Up @@ -382,12 +392,14 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
"""
if isinstance(x, list) or x.squeeze().ndim == 5:
self.timing = []
masks, styles, flows = [], [], []
tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
nimg = len(x)
iterator = trange(nimg, file=tqdm_out,
mininterval=30) if nimg > 1 else range(nimg)
for i in iterator:
tic = time.time()
maski, flowi, stylei = self.eval(
x[i], batch_size=batch_size,
channels=channels[i] if channels is not None and
Expand All @@ -409,6 +421,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
masks.append(maski)
flows.append(flowi)
styles.append(stylei)
self.timing.append(time.time() - tic)
return masks, flows, styles

else:
Expand Down Expand Up @@ -499,7 +512,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
if rescale != 1.0:
img = transforms.resize_image(img, rsz=rescale)
yf, style = run_net(self.net, img, bsize=bsize, augment=augment,
tile=tile, tile_overlap=tile_overlap)
tile=tile, tile_overlap=tile_overlap)
if resample:
yf = transforms.resize_image(yf, shape[1], shape[2])

Expand Down Expand Up @@ -661,14 +674,15 @@ def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False
- diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps.
- diam_style (np.ndarray): Estimated diameters from style alone.
"""

if isinstance(x, list):
self.timing = []
diams, diams_style = [], []
nimg = len(x)
tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
iterator = trange(nimg, file=tqdm_out,
mininterval=30) if nimg > 1 else range(nimg)
for i in iterator:
tic = time.time()
diam, diam_style = self.eval(
x[i], channels=channels[i] if
(channels is not None and len(channels) == len(x) and
Expand All @@ -679,6 +693,7 @@ def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False
batch_size=batch_size, progress=progress)
diams.append(diam)
diams_style.append(diam_style)
self.timing.append(time.time() - tic)

return diams, diams_style

Expand Down
Loading

0 comments on commit 8c8e9d8

Please sign in to comment.