diff --git a/cellpose/train.py b/cellpose/train.py index 6f17b8c2..88338c09 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -293,6 +293,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, if train_probs is not None: train_probs = train_probs[ikeep] diam_train = diam_train[ikeep] + nimg = len(train_data) ### normalize probabilities train_probs = 1. / nimg * np.ones(nimg, @@ -410,7 +411,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, "channel_axis": channel_axis, "rgb": rgb } - + net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device) nimg = len(train_data) if train_data is not None else len(train_files)