update ou_utils.py: parallelize CSN correct case generation
dummyindex committed Mar 29, 2024
1 parent bc94d33 commit 90e6fdc
Showing 2 changed files with 112 additions and 40 deletions.
47 changes: 42 additions & 5 deletions livecellx/model_zoo/segmentation/csn_configs.py
@@ -1,3 +1,4 @@
from functools import partial
import torch
from torchvision import transforms
from torchvision import transforms
@@ -94,7 +95,10 @@ def gen_train_transform_v2(
return train_transforms


def gauss_noise_tensor(img, sigma=30.0):
def gauss_noise_tensor(
img,
sigma=30.0,
):
assert isinstance(img, torch.Tensor)
dtype = img.dtype
if not img.is_floating_point():
@@ -113,7 +117,7 @@ def gauss_noise_tensor(img, sigma=30.0):


def gen_train_transform_v3(
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float]
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float], gauss_sigma=30
) -> transforms.Compose:
"""Generate the training data transformation.
@@ -137,15 +141,48 @@ def gen_train_transform_v3(
# transforms.Resize((412, 412)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale),
gauss_noise_tensor,
partial(gauss_noise_tensor, sigma=gauss_sigma),
transforms.Resize((412, 412)),
]
)
return train_transforms
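
gen_train_transform_v3 now wraps gauss_noise_tensor with functools.partial so the caller-supplied gauss_sigma is baked into the callable that transforms.Compose invokes. The sketch below illustrates the idea; the noise body is an assumption inferred from the function name and the visible float conversion, not the repository's exact implementation, and the sigma value is illustrative.

from functools import partial

import torch

def gauss_noise_tensor_sketch(img, sigma=30.0):
    # Add zero-mean Gaussian noise with standard deviation sigma to a tensor image.
    assert isinstance(img, torch.Tensor)
    dtype = img.dtype
    if not img.is_floating_point():
        img = img.to(torch.float32)
    out = img + sigma * torch.randn_like(img)
    # Restore the original dtype so downstream transforms see what they expect.
    if out.dtype != dtype:
        out = out.to(dtype)
    return out

# partial(...) freezes sigma, yielding the single-argument callable Compose expects.
add_noise = partial(gauss_noise_tensor_sketch, sigma=30)
noisy = add_noise(torch.zeros(1, 412, 412))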


def gen_train_transform_v4(
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float]
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float], gauss_sigma=30
) -> transforms.Compose:
"""Generate the training data transformation.
Parameters
----------
degrees : float
The range of degrees to rotate the image.
translation_range : Tuple[float, float]
The maximum absolute fractions for horizontal and vertical translation, as expected by RandomAffine.
scale : Tuple[float, float]
The range of scale factors.
gauss_sigma : float, optional
Standard deviation of the additive Gaussian noise, by default 30.
Returns
-------
transforms.Compose
The composed transformation for training data.
"""

train_transforms = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale, shear=10),
partial(gauss_noise_tensor, sigma=gauss_sigma),
transforms.Resize((412, 412)),
transforms.Normalize([0.485], [0.229]),
]
)
return train_transforms
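
A usage sketch for the new gen_train_transform_v4, assuming livecellx is importable; the rotation, translation, scale, and sigma values below are illustrative, and the single-channel input matches the Normalize([0.485], [0.229]) statistics above.

import torch
from livecellx.model_zoo.segmentation.csn_configs import gen_train_transform_v4

train_transform = gen_train_transform_v4(
    degrees=15,
    translation_range=(0.1, 0.1),  # fractions of image size, per RandomAffine
    scale=(0.9, 1.1),
    gauss_sigma=30,
)
dummy_crop = torch.rand(1, 412, 412)  # single-channel crop, as Normalize expects
augmented = train_transform(dummy_crop)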


def gen_train_transform_v5(
degrees: float, translation_range: Tuple[float, float], scale: Tuple[float, float], gauss_sigma=30
) -> transforms.Compose:
"""Generate the training data transformation.
@@ -169,7 +206,7 @@ def gen_train_transform_v4(
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale, shear=10),
gauss_noise_tensor,
transforms.GaussianBlur(kernel_size=3),
transforms.Resize((412, 412)),
transforms.Normalize([0.485], [0.229]),
]
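
In this hunk (the body that now belongs to gen_train_transform_v5), the additive-noise callable is replaced by torchvision's built-in GaussianBlur, which smooths the crop rather than corrupting it with noise. A minimal stand-alone illustration:

import torch
from torchvision import transforms

blur = transforms.GaussianBlur(kernel_size=3)  # small kernel, default sigma range
blurred = blur(torch.rand(1, 412, 412))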
105 changes: 70 additions & 35 deletions livecellx/segment/ou_utils.py
@@ -11,6 +11,7 @@
SingleCellTrajectoryCollection,
)
from livecellx.core.datasets import LiveCellImageDataset
from livecellx.core.parallel import parallelize
from livecellx.preprocess.utils import (
overlay,
enhance_contrast,
@@ -292,6 +293,8 @@ def csn_augment_helper(
"""
if train_path_tuples is None:
train_path_tuples = []
if augmented_data is None:
augmented_data = []
if normalize_img_uint8:
img_crop = normalize_img_to_uint8(img_crop)
combined_gt_binary_mask = combined_gt_label_mask > 0
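
The added None checks make csn_augment_helper allocate fresh accumulator lists when the caller passes None, which is what the parallel path below relies on: each worker call builds and returns its own lists instead of mutating a shared one. A hypothetical minimal illustration of the pattern (not repository code):

def collect(items, results=None):
    # A fresh list per call avoids sharing one accumulator across independent
    # worker processes; each call returns its own results.
    if results is None:
        results = []
    results.extend(items)
    return results

print(collect([1, 2]))  # [1, 2] -- a new list every call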
@@ -402,6 +405,56 @@ def csn_augment_helper(
}


def _gen_sc_csn_correct_data_wrapper(
sc: SingleCellStatic,
filename_pattern,
raw_out_dir,
seg_out_dir,
gt_out_dir,
gt_label_out_dir,
augmented_seg_dir,
raw_transformed_img_dir,
augmented_diff_seg_dir,
):
img_id = sc.timeframe
seg_label = sc.id
# (img_crop, seg_crop, combined_gt_label_mask) = underseg_overlay_gt_masks(seg_label, scs, padding_scale=2)
img_crop = sc.get_img_crop()
seg_crop = sc.get_contour_mask()
# Only 1 gt mask for mask cases, seg_crop is sufficient
combined_gt_label_mask = seg_crop

filename = filename_pattern % (img_id, seg_label)
raw_img_path = raw_out_dir / filename
seg_img_path = seg_out_dir / filename
gt_img_path = gt_out_dir / filename
gt_label_img_path = gt_label_out_dir / filename

scale_factors = [0] # We don't need to erode/dilate the data for correct cases
# call csn augment helper
res_dict = csn_augment_helper(
img_crop=img_crop,
seg_label_crop=seg_crop,
combined_gt_label_mask=combined_gt_label_mask,
scale_factors=scale_factors,
train_path_tuples=None,
augmented_data=None,
img_id=img_id,
seg_label=seg_label,
gt_label=None,
raw_img_path=raw_img_path,
seg_img_path=seg_img_path,
gt_img_path=gt_img_path,
gt_label_img_path=gt_label_img_path,
augmented_seg_dir=augmented_seg_dir,
augmented_diff_seg_dir=augmented_diff_seg_dir,
raw_transformed_img_dir=raw_transformed_img_dir,
df_save_path=None,
filename_pattern="img-%d_scId-%s.tif",
)
return res_dict


def gen_csn_correct_case(scs, out_dir, filename_pattern="img-%d_scId-%s.tif"):
out_subdir = out_dir / "correct_cases"
raw_out_dir = out_subdir / "raw"
@@ -420,46 +473,28 @@ def gen_csn_correct_case(scs, out_dir, filename_pattern="img-%d_scId-%s.tif"):
os.makedirs(raw_transformed_img_dir, exist_ok=True)
os.makedirs(augmented_diff_seg_dir, exist_ok=True)

scale_factors = [0]  # We don't need to erode/dilate the data for correct cases
train_path_tuples = []
augmented_data = []

sc_inputs = []
for sc in tqdm(scs):
img_id = sc.timeframe
seg_label = sc.id
# (img_crop, seg_crop, combined_gt_label_mask) = underseg_overlay_gt_masks(seg_label, scs, padding_scale=2)
img_crop = sc.get_img_crop()
seg_crop = sc.get_contour_mask()
# Only 1 gt mask for mask cases, seg_crop is sufficient
combined_gt_label_mask = seg_crop

filename = filename_pattern % (img_id, seg_label)
raw_img_path = raw_out_dir / filename
seg_img_path = seg_out_dir / filename
gt_img_path = gt_out_dir / filename
gt_label_img_path = gt_label_out_dir / filename

# call csn augment helper
csn_augment_helper(
img_crop=img_crop,
seg_label_crop=seg_crop,
combined_gt_label_mask=combined_gt_label_mask,
scale_factors=scale_factors,
train_path_tuples=train_path_tuples,
augmented_data=augmented_data,
img_id=img_id,
seg_label=seg_label,
gt_label=None,
raw_img_path=raw_img_path,
seg_img_path=seg_img_path,
gt_img_path=gt_img_path,
gt_label_img_path=gt_label_img_path,
augmented_seg_dir=augmented_seg_dir,
augmented_diff_seg_dir=augmented_diff_seg_dir,
raw_transformed_img_dir=raw_transformed_img_dir,
df_save_path=None,
filename_pattern="img-%d_scId-%s.tif",
sc_inputs.append(
{
"sc": sc,
"filename_pattern": filename_pattern,
"raw_out_dir": raw_out_dir,
"seg_out_dir": seg_out_dir,
"gt_out_dir": gt_out_dir,
"gt_label_out_dir": gt_label_out_dir,
"augmented_seg_dir": augmented_seg_dir,
"raw_transformed_img_dir": raw_transformed_img_dir,
"augmented_diff_seg_dir": augmented_diff_seg_dir,
}
)
process_outputs = parallelize(_gen_sc_csn_correct_data_wrapper, sc_inputs)
for output in process_outputs:
train_path_tuples.extend(output["train_path_tuples"])
augmented_data.extend(output["augmented_data"])

pd.DataFrame(
train_path_tuples,
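
The diff does not show livecellx.core.parallel.parallelize itself. Judging only from the call site above, it takes a function plus a list of keyword-argument dicts and returns one result per dict; a minimal sketch of such a helper, with an assumed multiprocessing backend and pool size, might look like this:

from multiprocessing import Pool
from typing import Any, Callable, Dict, List


def _apply_kwargs(args):
    # Unpack (func, kwargs) and call func with the keyword arguments.
    func, kwargs = args
    return func(**kwargs)


def parallelize_sketch(func: Callable, inputs: List[Dict[str, Any]], processes: int = 8) -> List[Any]:
    # Map each kwargs dict onto func in a worker process; results keep input order.
    with Pool(processes=processes) as pool:
        return pool.map(_apply_kwargs, [(func, kwargs) for kwargs in inputs])

With a helper of this shape, process_outputs is a list of the res_dict values returned by csn_augment_helper, and the loop above flattens their train_path_tuples and augmented_data entries before the DataFrame is written.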
