From 51535c99e84c66ddaa51e5f2490d46a6862a3aee Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Mon, 28 Oct 2024 15:42:41 -0400 Subject: [PATCH] moving print statement for resize --- cellpose/dynamics.py | 19 ++++++++++--------- cellpose/models.py | 11 ++++++----- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index 160a63d7..8775b919 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -625,21 +625,22 @@ def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu" if major_version == "1" and int(minor_version) < 10: # for PyTorch version lower than 1.10 def mem_info(): - total_mem = torch.cuda.get_device_properties(0).total_memory - used_mem = torch.cuda.memory_allocated() - return total_mem, used_mem + total_mem = torch.cuda.get_device_properties(device0.index).total_memory + used_mem = torch.cuda.memory_allocated(device0.index) + free_mem = total_mem - used_mem + return total_mem, free_mem else: # for PyTorch version 1.10 and above def mem_info(): - total_mem, used_mem = torch.cuda.mem_get_info() - return total_mem, used_mem - - if masks.size * 20 > mem_info()[0]: + free_mem, total_mem = torch.cuda.mem_get_info(device0.index) + return total_mem, free_mem + total_mem, free_mem = mem_info() + if masks.size * 32 > free_mem: dynamics_logger.warning( "WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold" ) dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow") - device0 = None + device0 = torch.device("cpu") merrors, _ = metrics.flow_error(masks, flows, device0) badi = 1 + (merrors > threshold).nonzero()[0] @@ -904,7 +905,7 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, # calculate masks mask = get_masks_torch(p_final, inds, dP.shape[1:], max_size_fraction=max_size_fraction) - + del p_final # flow thresholding factored out of get_masks if not do_3D: if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0: diff --git a/cellpose/models.py b/cellpose/models.py index 952a3176..70ad2bc7 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -565,11 +565,12 @@ def _run_net(self, x, rescale=1.0, resample=True, augment=False, batch_size=batch_size, augment=augment, tile_overlap=tile_overlap, net_ortho=self.net_ortho) if resample: - models_logger.info("resizing 3D flows and cellprob to original image size") - if rescale != 1.0: - yf = transforms.resize_image(yf, Ly=Ly, Lx=Lx) - if Lz != yf.shape[0]: - yf = transforms.resize_image(yf.transpose(1,0,2,3), + if rescale != 1.0 or Lz != yf.shape[0]: + models_logger.info("resizing 3D flows and cellprob to original image size") + if rescale != 1.0: + yf = transforms.resize_image(yf, Ly=Ly, Lx=Lx) + if Lz != yf.shape[0]: + yf = transforms.resize_image(yf.transpose(1,0,2,3), Ly=Lz, Lx=Lx).transpose(1,0,2,3) cellprob = yf[..., -1] dP = yf[..., :-1].transpose((3, 0, 1, 2))