Skip to content

Commit

Permalink
update sc_correction_aux.py and train_csn.py: fix bug for 02 MSE+EDT+…
Browse files Browse the repository at this point in the history
…AUX model
  • Loading branch information
dummyindex committed Mar 29, 2024
1 parent 44e9405 commit 96815c2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
7 changes: 1 addition & 6 deletions livecellx/model_zoo/segmentation/sc_correction_aux.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,11 @@ def weighted_mse_loss(predict, target, weights=None):
- loss: Scalar tensor representing the weighted MSE loss.
"""
if weights is not None:
# Ensure the weights can be broadcasted to match the input shape
# Weights for channels other than the first are assumed to be 1
expanded_weights = torch.ones_like(predict)
expanded_weights[:, 0, :, :] = weights[:, 0, :, :] # Apply weights to the first channel

# Calculate squared differences
squared_diff = (predict - target) ** 2

# Apply weights
weighted_squared_diff = squared_diff * expanded_weights
weighted_squared_diff = squared_diff * weights

# Calculate mean of the weighted squared differences
loss = weighted_squared_diff.mean()
Expand Down
57 changes: 37 additions & 20 deletions livecellx/model_zoo/segmentation/train_csn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def parse_args():
parser.add_argument("--ou_aux", dest="ou_aux", default=False, action="store_true")
parser.add_argument("--aug-ver", default="v0", type=str, help="The version of the augmentation to use.")
parser.add_argument("--use-gt-pixel-weight", default=False, action="store_true")
parser.add_argument("--aux-loss-weight", default=0.5, type=float)

args = parser.parse_args()

Expand Down Expand Up @@ -188,27 +189,43 @@ def df2dataset(df):
if args.debug:
logger = TensorBoardLogger(save_dir=".", name="test_logs", version=args.model_version)

csn_model_cls = CorrectSegNet
if args.ou_aux:
csn_model_cls = CorrectSegNetAux

model = csn_model_cls(
# train_input_paths=train_input_tuples,
lr=args.lr,
num_workers=1,
batch_size=args.batch_size,
train_transforms=train_transforms,
train_dataset=train_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
kernel_size=kernel_size,
loss_type=args.loss,
class_weights=args.class_weights,
# only for record keeping purposes; handled by the dataset
input_type=args.input_type,
apply_gt_seg_edt=args.apply_gt_seg_edt,
exclude_raw_input_bg=args.exclude_raw_input_bg,
)
model = CorrectSegNetAux(
# train_input_paths=train_input_tuples,
lr=args.lr,
num_workers=1,
batch_size=args.batch_size,
train_transforms=train_transforms,
train_dataset=train_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
kernel_size=kernel_size,
loss_type=args.loss,
class_weights=args.class_weights,
# only for record keeping purposes; handled by the dataset
input_type=args.input_type,
apply_gt_seg_edt=args.apply_gt_seg_edt,
exclude_raw_input_bg=args.exclude_raw_input_bg,
aux_loss_weight=args.aux_loss_weight,
)
else:
model = CorrectSegNet(
# train_input_paths=train_input_tuples,
lr=args.lr,
num_workers=1,
batch_size=args.batch_size,
train_transforms=train_transforms,
train_dataset=train_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
kernel_size=kernel_size,
loss_type=args.loss,
class_weights=args.class_weights,
# only for record keeping purposes; handled by the dataset
input_type=args.input_type,
apply_gt_seg_edt=args.apply_gt_seg_edt,
exclude_raw_input_bg=args.exclude_raw_input_bg,
)

print("logger save dir:", logger.save_dir)
print("logger subdir:", logger.sub_dir)
Expand Down

0 comments on commit 96815c2

Please sign in to comment.