Commit
dummyindex committed Apr 1, 2024
2 parents 05a1eff + 96815c2 commit 66a792c
Showing 9 changed files with 2,299 additions and 37 deletions.
149 changes: 146 additions & 3 deletions livecellx/model_zoo/segmentation/csn_configs.py
@@ -1,7 +1,11 @@
from functools import partial
import torch
from torchvision import transforms
from typing import Tuple

from livecellx.model_zoo.segmentation.custom_transforms import CustomTransformV5


def gen_train_transform_v0(
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float]
@@ -34,9 +38,7 @@ def gen_train_transform_v0(
return train_transforms


def gen_train_transform_v1(
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float]
) -> transforms.Compose:
def gen_train_transform_v1(degrees=0, translation_range=None, scale=None) -> transforms.Compose:
"""Generate the training data transformation.
Parameters
@@ -93,3 +95,144 @@ def gen_train_transform_v2(
]
)
return train_transforms


def gauss_noise_tensor(
img,
sigma=30.0,
):
assert isinstance(img, torch.Tensor)
dtype = img.dtype
if not img.is_floating_point():
img = img.to(torch.float32)

minus_or_plus = torch.randint(0, 2, (1,)).item()
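    # Note: randn_like is symmetric around zero, so both branches below draw from
    # the same noise distribution; the flip only randomizes the sign convention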
if minus_or_plus == 0:
out = img + sigma * torch.randn_like(img)
else:
out = img - sigma * torch.randn_like(img)

if out.dtype != dtype:
out = out.to(dtype)

return out
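
A minimal usage sketch of gauss_noise_tensor (the tensor shape and sigma below are illustrative, not from the repository):

img = torch.rand(1, 412, 412) * 255  # fake single-channel crop
noisy = gauss_noise_tensor(img, sigma=30.0)
assert noisy.shape == img.shape and noisy.dtype == img.dtype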


def gen_train_transform_v3(
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float], gauss_sigma=30
) -> transforms.Compose:
"""Generate the training data transformation.
Parameters
----------
degrees : float
The range of degrees to rotate the image.
translation_range : Tuple[float, float]
The range of translation in pixels.
scale : Tuple[float, float]
The range of scale factors.
gauss_sigma : float
The standard deviation of the additive Gaussian noise.
Returns
-------
transforms.Compose
The composed transformation for training data.
"""

train_transforms = transforms.Compose(
[
# transforms.Resize((412, 412)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale),
partial(gauss_noise_tensor, sigma=gauss_sigma),
transforms.Resize((412, 412)),
]
)
return train_transforms
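
A usage sketch for the generator above (argument values are illustrative; translation_range uses torchvision's RandomAffine convention of per-axis fractions):

train_transform = gen_train_transform_v3(
    degrees=15, translation_range=(0.1, 0.1), scale=(0.9, 1.1), gauss_sigma=20
)
crop = torch.rand(1, 412, 412) * 255  # fake single-channel crop
augmented = train_transform(crop)  # flip/affine, add noise, resize to 412x412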


def gen_train_transform_v4(
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float], gauss_sigma=30
) -> transforms.Compose:
"""Generate the training data transformation.
Parameters
----------
degrees : float
The range of degrees to rotate the image.
translation_range : Tuple[float, float]
The range of translation in pixels.
scale : Tuple[float, float]
The range of scale factors.
gauss_sigma : float
The standard deviation of the additive Gaussian noise.
Returns
-------
transforms.Compose
The composed transformation for training data.
"""

train_transforms = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale, shear=10),
partial(gauss_noise_tensor, sigma=gauss_sigma),
transforms.Resize((412, 412)),
transforms.Normalize([0.485], [0.229]),
]
)
return train_transforms


def gen_train_transform_v5(
degrees: float, translation_range: Tuple[float, float] = None, scale: Tuple[float, float] = None, gauss_sigma=30
) -> CustomTransformV5:
"""Generate the training data transformation.
Parameters
----------
degrees : float
The range of degrees to rotate the image.
translation_range : Tuple[float, float]
The range of translation in pixels.
scale : Tuple[float, float]
The range of scale factors.
gauss_sigma : float
Accepted for signature compatibility; not used in this version.
Returns
-------
CustomTransformV5
The composed transformation for training data.
"""

train_transforms = CustomTransformV5(degrees=degrees, translation_range=translation_range, scale=scale)
return train_transforms


def gen_train_transform_v6(
degrees: float, translation_range: Tuple[float, float] = None, scale: Tuple[float, float] = None, gauss_sigma=30
) -> transforms.Compose:
"""Generate the training data transformation.
Parameters
----------
degrees : float
The range of degrees to rotate the image.
translation_range : Tuple[float, float]
The range of translation in pixels.
scale : Tuple[float, float]
The range of scale factors.
gauss_sigma : float
Accepted for signature compatibility; no noise is applied in this version.
Returns
-------
transforms.Compose
The composed transformation for training data.
"""

train_transforms = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale, shear=10),
transforms.Resize((256, 256)),
]
)
return train_transforms
48 changes: 48 additions & 0 deletions livecellx/model_zoo/segmentation/custom_transforms.py
@@ -0,0 +1,48 @@
import torch
from torchvision import transforms
from typing import Tuple


class CustomTransformV5:
def __init__(
self, degrees: float, translation_range: Tuple[float, float] = None, scale: Tuple[float, float] = None
):
# Common transformations that should be applied to both images and masks
self.common_transforms = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale, shear=10),
transforms.Resize((256, 256)),
]
)
# Image-specific transformations that should not be applied to masks
self.image_transforms = transforms.Compose(
[
transforms.GaussianBlur(kernel_size=3, sigma=30),
transforms.Normalize([127], [30]), # Adjust channel numbers according to your images
]
)

def apply_common_transforms(self, tensor):
# Inputs are assumed to be tensors; convert from PIL Images first if your
# pipeline produces PIL data
tensor = self.common_transforms(tensor)
return tensor

def apply_image_transforms(self, image):
# Apply intensity transforms meant only for image channels (not masks);
# inputs are assumed to be tensors
image = self.image_transforms(image)
return image

def __call__(self, concat_img):
# Apply common transformations
concat_img = self.apply_common_transforms(concat_img)

# Apply image-specific transformations
# Apply image-specific transformations to the first two channels of concat_img,
# which are assumed to hold the raw image data
concat_img[:2] = self.apply_image_transforms(concat_img[:2])

return concat_img
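
A minimal call sketch for CustomTransformV5, assuming the concatenated input stacks two image channels first and mask channels after (the channel counts here are hypothetical):

transform = CustomTransformV5(degrees=15, translation_range=(0.1, 0.1), scale=(0.9, 1.1))
concat_img = torch.rand(4, 412, 412) * 255  # channels 0-1: raw images; 2-3: masks
out = transform(concat_img)  # geometric transforms hit all channels; blur/normalize only the first two
assert out.shape == (4, 256, 256)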
5 changes: 4 additions & 1 deletion livecellx/model_zoo/segmentation/eval_csn.py
@@ -22,7 +22,9 @@
from livecellx.model_zoo.segmentation.sc_correction_dataset import CorrectSegNetDataset


def assemble_dataset(df: pd.DataFrame, apply_gt_seg_edt=False, exclude_raw_input_bg=False, input_type=None):
def assemble_dataset(
df: pd.DataFrame, apply_gt_seg_edt=False, exclude_raw_input_bg=False, input_type=None, use_gt_pixel_weight=False
):
assert input_type is not None
raw_img_paths = list(df["raw"])
scaled_seg_mask_paths = list(df["seg"])
@@ -47,6 +49,7 @@ def assemble_dataset(df: pd.DataFrame, apply_gt_seg_edt=False, exclude_raw_input
exclude_raw_input_bg=exclude_raw_input_bg,
input_type=input_type,
raw_df=df,
use_gt_pixel_weight=use_gt_pixel_weight,
)
return dataset
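
An illustrative call (the paths and the input_type value are placeholders, and the elided body may read further DataFrame columns such as ground-truth paths):

df = pd.DataFrame({
    "raw": ["data/raw/img_000.tif"],  # hypothetical paths
    "seg": ["data/seg/img_000_seg.tif"],
})
dataset = assemble_dataset(df, apply_gt_seg_edt=True, input_type="raw_aug_seg", use_gt_pixel_weight=True)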

76 changes: 69 additions & 7 deletions livecellx/model_zoo/segmentation/sc_correction_aux.py
@@ -24,6 +24,38 @@

LOG_PROGRESS_BAR = False

import torch
import torch.nn.functional as F


def weighted_mse_loss(predict, target, weights=None):
"""
Compute the weighted MSE loss with an optional per-pixel weight map.
Parameters:
- predict: Tensor of predicted values (batch_size, channels, height, width).
- target: Tensor of target values with the same shape as predict.
- weights: Optional. Tensor of per-pixel weights broadcastable to predict,
e.g. (batch_size, 1, height, width). If None, standard MSE loss is calculated.
Returns:
- loss: Scalar tensor representing the weighted MSE loss.
"""
if weights is not None:
# Calculate squared differences
squared_diff = (predict - target) ** 2

# Apply weights
weighted_squared_diff = squared_diff * weights

# Calculate mean of the weighted squared differences
loss = weighted_squared_diff.mean()
else:
# If no weights are provided, calculate standard MSE loss
loss = F.mse_loss(predict, target, reduction="mean")

return loss
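
A quick sanity check (shapes illustrative): with uniform unit weights the weighted loss reduces to plain MSE.

predict = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)
weights = torch.ones(2, 1, 64, 64)  # broadcasts across the channel dimension
assert torch.isclose(weighted_mse_loss(predict, target, weights), F.mse_loss(predict, target))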


class CorrectSegNetAux(LightningModule):
def __init__(
@@ -32,7 +64,7 @@ def __init__(
batch_size=5,
class_weights=[1, 1, 1],
model_type=None,
num_workers=16,
num_workers=32,
train_input_paths=None,
train_transforms=None,
seed=99,
@@ -151,7 +183,9 @@ def forward(self, x: torch.Tensor):
else:
return x

def compute_loss(self, output: torch.tensor, target: torch.tensor, aux_out=None, aux_target=None):
def compute_loss(
self, output: torch.tensor, target: torch.tensor, aux_out=None, aux_target=None, gt_pixel_weight=None
):
"""Compute loss fuction
Parameters
@@ -178,27 +212,55 @@ def compute_loss(self, output: torch.tensor, target: torch.tensor, aux_out=None,
), "seg_output shape should be batch_size x num_classes x height x width, got %s" % str(seg_output.shape)

if self.loss_type == "CE":
return self.loss_func(seg_output, target), aux_loss
seg_loss = self.loss_func(seg_output, target)
elif self.loss_type == "MSE":
total_loss = 0
num_classes = seg_output.shape[1]
for cat_dim in range(0, num_classes):
temp_target = target[:, cat_dim, ...]
temp_output = seg_output[:, cat_dim, ...]
total_loss += self.loss_func(temp_output, temp_target) * self.class_weights[cat_dim]
return total_loss, aux_loss
total_loss += (
weighted_mse_loss(temp_output, temp_target, weights=gt_pixel_weight) * self.class_weights[cat_dim]
)
seg_loss = total_loss
elif self.loss_type == "BCE":
if gt_pixel_weight is not None:
# Repeat per-pixel weights to match the 3 gt channels (seg and two OU masks):
# (batch, H, W) -> (batch, 3, H, W)
gt_pixel_weight_repeated = gt_pixel_weight.unsqueeze(1).repeat(1, 3, 1, 1)
gt_pixel_weight_permuted = gt_pixel_weight_repeated.permute(0, 2, 3, 1)
else:
gt_pixel_weight_permuted = None
seg_output = seg_output.permute(0, 2, 3, 1)
target = target.permute(0, 2, 3, 1)
return self.loss_func(seg_output, target), aux_loss
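# Rebuild the BCE loss on each call so the per-pixel weight map can be supplied;
# the hard-coded .cuda() below assumes GPU training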
self.loss_func = torch.nn.BCEWithLogitsLoss(
weight=gt_pixel_weight_permuted, pos_weight=torch.tensor(self.class_weights).cuda()
)

seg_loss = self.loss_func(seg_output, target)
else:
raise NotImplementedError("Loss: %s not implemented" % self.loss_type)

return seg_loss, aux_loss
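
To make the BCE branch's reshaping concrete, here is the shape walkthrough implied by the comments above (a batch of 2 crops of 412x412 with 3 gt channels):

gt_pixel_weight = torch.rand(2, 412, 412)            # (batch, H, W)
w = gt_pixel_weight.unsqueeze(1).repeat(1, 3, 1, 1)  # -> (2, 3, 412, 412)
w = w.permute(0, 2, 3, 1)                            # -> (2, 412, 412, 3), matching the permuted logits/targets
assert w.shape == (2, 412, 412, 3)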

def training_step(self, batch, batch_idx):
# print("[train_step] x shape: ", batch["input"].shape)
# print("[train_step] y shape: ", batch["gt_mask"].shape)
x, y = batch["input"], batch["gt_mask"]
aux_target = batch["ou_aux"]
gt_pixel_weight = batch["gt_pixel_weight"]
output, aux_out = self(x)
seg_loss, aux_loss = self.compute_loss(output, y, aux_out=aux_out, aux_target=aux_target)
seg_loss, aux_loss = self.compute_loss(
output, y, aux_out=aux_out, aux_target=aux_target, gt_pixel_weight=gt_pixel_weight
)
loss = seg_loss + self.aux_loss_weight * aux_loss
predicted_labels = torch.argmax(output, dim=1)
self.log(