Skip to content

Commit

Permalink
update sc_seg_operator.py: adapt to aux_out ver
Browse files Browse the repository at this point in the history
  • Loading branch information
dummyindex committed Mar 27, 2024
1 parent 941a8d7 commit f6d9e73
Showing 1 changed file with 66 additions and 50 deletions.
116 changes: 66 additions & 50 deletions livecellx/core/sc_seg_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,62 @@
from livecellx.core.datasets import SingleImageDataset


def correct_sc_segment(
sc,
model,
create_ou_input_kwargs={
"padding_pixels": 50,
"dtype": float,
"remove_bg": False,
"one_object": True,
"scale": 0,
},
) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray, torch.Tensor]:
import torch
from torchvision import transforms
from livecellx.model_zoo.segmentation.sc_correction_aux import CorrectSegNetAux

# padding_pixels=padding_pixels, dtype=dtype, remove_bg=remove_bg, one_object=one_object, scale=scale
input_transforms = transforms.Compose(
[
transforms.Resize(size=(412, 412)),
]
)

temp_sc = sc.copy()
new_contour = np.array(temp_sc.contour)
new_contour = new_contour[:, -2:] # remove slice index (time)
temp_sc.update_contour(new_contour)
temp_sc.update_bbox()
res_bbox = temp_sc.bbox
ou_input = create_ou_input_from_sc(temp_sc, **create_ou_input_kwargs)
# ou_input = create_ou_input_from_sc(self.sc, **create_ou_input_kwargs)
original_shape = ou_input.shape

# TODO: change to comply with the training data preparation
# for now we simply use one of the input types during training: raw_aug_duplicate.
# Please read sc_correction_dataset impl.
ou_input = input_transforms(torch.tensor([ou_input]))
ou_input = torch.stack([ou_input, ou_input, ou_input], dim=1)
ou_input = ou_input.float().cuda()

back_transforms = transforms.Compose(
[
transforms.Resize(size=(original_shape[0], original_shape[1])),
]
)
seg_output, aux_output = None, None
if isinstance(model, CorrectSegNetAux):
model_output = model(ou_input)
seg_output, aux_output = model_output
else:
seg_output = model(ou_input)
seg_output = back_transforms(seg_output)
if not model.apply_gt_seg_edt:
seg_output = torch.sigmoid(seg_output)
return ou_input, seg_output, res_bbox, aux_output


class ScSegOperator:
"""
A class for performing segmentation on single cell images.
Expand All @@ -40,11 +96,15 @@ class ScSegOperator:
DEFAULT_CSN_MODEL = None

@staticmethod
def load_default_csn_model(path, cuda=True):
def load_default_csn_model(path, cuda=True, has_aux=True):
import torch
from livecellx.model_zoo.segmentation.sc_correction import CorrectSegNet
from livecellx.model_zoo.segmentation.sc_correction_aux import CorrectSegNetAux

model = CorrectSegNet.load_from_checkpoint(path)
if has_aux:
model = CorrectSegNetAux.load_from_checkpoint(path)
else:
model = CorrectSegNet.load_from_checkpoint(path)
if cuda:
model.cuda()
model.eval()
Expand Down Expand Up @@ -139,54 +199,10 @@ def correct_segment(self, model, create_ou_input_kwargs=None):
# padding_pixels=padding_pixels, dtype=dtype, remove_bg=remove_bg, one_object=one_object, scale=scale
temp_sc = self.sc.copy()
if create_ou_input_kwargs is None:
return self.correct_sc_segment(temp_sc, model)
# Use default values
return correct_sc_segment(temp_sc, model)
else:
return self.correct_sc_segment(temp_sc, model, create_ou_input_kwargs=create_ou_input_kwargs)

def correct_sc_segment(
sc,
model,
create_ou_input_kwargs={
"padding_pixels": 50,
"dtype": float,
"remove_bg": False,
"one_object": True,
"scale": 0,
},
) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
import torch
from torchvision import transforms

# padding_pixels=padding_pixels, dtype=dtype, remove_bg=remove_bg, one_object=one_object, scale=scale
input_transforms = transforms.Compose(
[
transforms.Resize(size=(412, 412)),
]
)
temp_sc = sc.copy()
new_contour = np.array(temp_sc.contour)
new_contour = new_contour[:, -2:] # remove slice index (time)
temp_sc.update_contour(new_contour)
temp_sc.update_bbox()
res_bbox = temp_sc.bbox
ou_input = create_ou_input_from_sc(temp_sc, **create_ou_input_kwargs)
# ou_input = create_ou_input_from_sc(self.sc, **create_ou_input_kwargs)
original_shape = ou_input.shape

ou_input = input_transforms(torch.tensor([ou_input]))
ou_input = torch.stack([ou_input, ou_input, ou_input], dim=1)
ou_input = ou_input.float().cuda()

back_transforms = transforms.Compose(
[
transforms.Resize(size=(original_shape[0], original_shape[1])),
]
)
output = model(ou_input)
output = back_transforms(output)
if not model.apply_gt_seg_edt:
output = torch.sigmoid(output)
return ou_input, output, res_bbox
return correct_sc_segment(temp_sc, model, create_ou_input_kwargs=create_ou_input_kwargs)

def replace_sc_contour(self, contour, padding_pixels=0, refresh=True):
self.sc.contour = contour + self.sc.bbox[:2] - padding_pixels
Expand Down Expand Up @@ -324,7 +340,7 @@ def csn_correct_seg_callback(self, padding_pixels=50, threshold=0.5):
"one_object": True,
"scale": 0,
}
model_ou_input, output, res_bbox = self.correct_segment(
model_ou_input, output, res_bbox, aux_output = self.correct_segment(
self.csn_model, create_ou_input_kwargs=create_ou_input_kwargs
)
bin_mask = output[0].cpu().detach().numpy()[0] > threshold
Expand Down

0 comments on commit f6d9e73

Please sign in to comment.