Skip to content

Commit

Permalink
moving print statement for resize
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Oct 28, 2024
1 parent c9b22d6 commit 51535c9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
19 changes: 10 additions & 9 deletions cellpose/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,21 +625,22 @@ def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu"
if major_version == "1" and int(minor_version) < 10:
# for PyTorch version lower than 1.10
def mem_info():
total_mem = torch.cuda.get_device_properties(0).total_memory
used_mem = torch.cuda.memory_allocated()
return total_mem, used_mem
total_mem = torch.cuda.get_device_properties(device0.index).total_memory
used_mem = torch.cuda.memory_allocated(device0.index)
free_mem = total_mem - used_mem
return total_mem, free_mem
else:
# for PyTorch version 1.10 and above
def mem_info():
total_mem, used_mem = torch.cuda.mem_get_info()
return total_mem, used_mem

if masks.size * 20 > mem_info()[0]:
free_mem, total_mem = torch.cuda.mem_get_info(device0.index)
return total_mem, free_mem
total_mem, free_mem = mem_info()
if masks.size * 32 > free_mem:
dynamics_logger.warning(
"WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold"
)
dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow")
device0 = None
device0 = torch.device("cpu")

merrors, _ = metrics.flow_error(masks, flows, device0)
badi = 1 + (merrors > threshold).nonzero()[0]
Expand Down Expand Up @@ -904,7 +905,7 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
# calculate masks
mask = get_masks_torch(p_final, inds, dP.shape[1:],
max_size_fraction=max_size_fraction)

del p_final
# flow thresholding factored out of get_masks
if not do_3D:
if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
Expand Down
11 changes: 6 additions & 5 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,12 @@ def _run_net(self, x, rescale=1.0, resample=True, augment=False,
batch_size=batch_size, augment=augment,
tile_overlap=tile_overlap, net_ortho=self.net_ortho)
if resample:
models_logger.info("resizing 3D flows and cellprob to original image size")
if rescale != 1.0:
yf = transforms.resize_image(yf, Ly=Ly, Lx=Lx)
if Lz != yf.shape[0]:
yf = transforms.resize_image(yf.transpose(1,0,2,3),
if rescale != 1.0 or Lz != yf.shape[0]:
models_logger.info("resizing 3D flows and cellprob to original image size")
if rescale != 1.0:
yf = transforms.resize_image(yf, Ly=Ly, Lx=Lx)
if Lz != yf.shape[0]:
yf = transforms.resize_image(yf.transpose(1,0,2,3),
Ly=Lz, Lx=Lx).transpose(1,0,2,3)
cellprob = yf[..., -1]
dP = yf[..., :-1].transpose((3, 0, 1, 2))
Expand Down

0 comments on commit 51535c9

Please sign in to comment.