Skip to content

Commit

Permalink
update sc_correction_aux.py and sc_correction_dataset.py: fix augment…
Browse files Browse the repository at this point in the history
…ation/edt order issue
  • Loading branch information
dummyindex committed Apr 3, 2024
1 parent 4f4dbf1 commit 1d1cdb1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 26 deletions.
8 changes: 0 additions & 8 deletions livecellx/model_zoo/segmentation/sc_correction_aux.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,6 @@ def training_step(self, batch, batch_idx):
# print("[train_step] x shape: ", batch["input"].shape)
# print("[train_step] y shape: ", batch["gt_mask"].shape)
x, y = batch["input"], batch["gt_mask"]
if self.apply_gt_seg_edt:
y = batch["gt_mask_edt"]
aux_target = batch["ou_aux"]
gt_pixel_weight = batch["gt_pixel_weight"]
output, aux_out = self(x)
Expand Down Expand Up @@ -346,9 +344,6 @@ def validation_step(self, batch, batch_idx, dataloader_idx):
self.test_step(batch, batch_idx)
return
x, y = batch["input"], batch["gt_mask"]
if self.apply_gt_seg_edt:
y = batch["gt_mask_edt"]

aux_target = batch["ou_aux"]
output, aux_out = self(x)
seg_loss, aux_loss = self.compute_loss(output, y, aux_out=aux_out, aux_target=aux_target)
Expand Down Expand Up @@ -388,9 +383,6 @@ def test_step(self, batch, batch_idx):
from livecellx.model_zoo.segmentation.eval_csn import compute_metrics

x, y = batch["input"], batch["gt_mask"]
if self.apply_gt_seg_edt:
y = batch["gt_mask_edt"]

aux_target = batch["ou_aux"]
output, aux_out = self(x)
seg_loss, aux_loss = self.compute_loss(output, y, aux_out=aux_out, aux_target=aux_target)
Expand Down
32 changes: 14 additions & 18 deletions livecellx/model_zoo/segmentation/sc_correction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torch.utils.data import DataLoader, random_split
import scipy.ndimage
import skimage.measure
from livecellx.core.utils import label_mask_to_edt_mask
from livecellx.preprocess.utils import normalize_img_to_uint8

# class CorrectSegNetData(data.Dataset):
Expand Down Expand Up @@ -173,10 +174,14 @@ def __getitem__(self, idx):
gt_pixel_weight = np.ones_like(gt_label_mask__np)
gt_pixel_weight = torch.tensor(gt_pixel_weight).float()

# transform to edt for inputs before augmentation
# Transform to edt for inputs before augmentation
if self.input_type == "edt_v0":
scaled_seg_mask = self.label_mask_to_edt(scaled_seg_mask)
# prepare for augmentation
scaled_seg_mask = label_mask_to_edt_mask(scaled_seg_mask, bg_val=self.bg_val)
scaled_seg_mask = torch.tensor(scaled_seg_mask).float()

gt_label_edt = torch.tensor(label_mask_to_edt_mask(gt_label_mask, bg_val=self.bg_val)).float()

# Prepare for augmentation
concat_img = torch.stack(
[
augmented_raw_img,
Expand All @@ -186,9 +191,11 @@ def __getitem__(self, idx):
aug_diff_img,
gt_label_mask,
gt_pixel_weight,
gt_label_edt,
],
dim=0,
)

if self.transform:
concat_img = self.transform(concat_img)

Expand Down Expand Up @@ -233,24 +240,13 @@ def __getitem__(self, idx):
gt_mask[gt_mask > 0.5] = 1
gt_mask[gt_mask <= 0.5] = 0
gt_binary = gt_mask
gt_mask_edt = None
gt_mask_edt = concat_img[7, :, :]

# apply edt to each label in gt label mask, and normalize edt to [0, 1]
if self.apply_gt_seg_edt:
augmented_gt_label_mask__np = augmented_gt_label_mask.numpy()
gt_mask_edt = np.zeros(augmented_gt_label_mask__np.shape)
gt_labels = set(np.unique(augmented_gt_label_mask__np))
if self.bg_val in gt_labels:
gt_labels.remove(self.bg_val)
for label in gt_labels:
tmp_bin_mask = augmented_gt_label_mask__np == label
tmp_edt = scipy.ndimage.morphology.distance_transform_edt(tmp_bin_mask)
gt_mask_edt = np.maximum(gt_mask_edt, tmp_edt)
# gt_mask = normalize_img_to_uint8(gt_mask_edt, dtype=float)
# gt_mask /= np.max(gt_mask)
gt_mask = torch.tensor(gt_mask_edt).float()

combined_gt = torch.stack([gt_mask, aug_diff_overseg, aug_diff_underseg], dim=0).float()
combined_gt = torch.stack([gt_mask_edt, aug_diff_overseg, aug_diff_underseg], dim=0).float()
else:
combined_gt = torch.stack([gt_mask, aug_diff_overseg, aug_diff_underseg], dim=0).float()

# Prepare ou_aux tensor: 4 classes auxillary output
ou_aux = torch.tensor([0, 0, 0, 0]).float()
Expand Down

0 comments on commit 1d1cdb1

Please sign in to comment.