Skip to content

Commit

Permalink
update sc_correction.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dummyindex committed Mar 19, 2024
1 parent 8d68338 commit 297d4be
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions livecellx/model_zoo/segmentation/sc_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
TEST_LOADER_IN_VAL_LOADER_LIST_IDX = 1

LOG_PROGRESS_BAR = False


class CorrectSegNet(LightningModule):
def __init__(
self,
Expand Down Expand Up @@ -156,14 +158,16 @@ def compute_loss(self, output: torch.tensor, target: torch.tensor):
target = target.permute(0, 2, 3, 1)
return self.loss_func(output, target)

def training_step(self, batch, batch_idx):
def training_step(self, batch, batch_idx, subdir_metrics=False):
# print("[train_step] x shape: ", batch["input"].shape)
# print("[train_step] y shape: ", batch["gt_mask"].shape)
x, y = batch["input"], batch["gt_mask"]
output = self(x)
loss = self.compute_loss(output, y)
predicted_labels = torch.argmax(output, dim=1)
self.log("train_loss", loss, batch_size=self.batch_size, on_step=True, on_epoch=True, prog_bar=self.log_progress_bar)
self.log(
"train_loss", loss, batch_size=self.batch_size, on_step=True, on_epoch=True, prog_bar=self.log_progress_bar
)
# monitor more stats during training
# compute on subdirs
if self.global_step % 1000 != 0:
Expand All @@ -172,17 +176,21 @@ def training_step(self, batch, batch_idx):
batch_subdirs = np.array([self.train_dataset.get_subdir(idx.item()) for idx in batch["idx"]])
bin_output = self.compute_bin_output(output)
acc = self.train_accuracy(bin_output.long(), y.long())
self.log("train_acc", acc, prog_bar=self.log_progress_bar, on_step=True, on_epoch=True, batch_size=self.batch_size)
self.log(
"train_acc", acc, prog_bar=self.log_progress_bar, on_step=True, on_epoch=True, batch_size=self.batch_size
)

for subdir in subdir_set:
if not (subdir in batch_subdirs):
continue
subdir_indexer = batch_subdirs == subdir
batched_loss = self.compute_loss(output[subdir_indexer], y[subdir_indexer])
# subdir_loss_map[subdir] = loss[list(subdir_indexer)].mean()
self.log(f"train_loss_{subdir}", batched_loss, prog_bar=self.log_progress_bar)
batched_acc = self.val_accuracy(bin_output[subdir_indexer].long(), y[subdir_indexer].long())
self.log(f"train_acc_{subdir}", batched_acc, prog_bar=self.log_progress_bar)
if subdir_metrics:
# acc by subdir
for subdir in subdir_set:
if not (subdir in batch_subdirs):
continue
subdir_indexer = batch_subdirs == subdir
batched_loss = self.compute_loss(output[subdir_indexer], y[subdir_indexer])
# subdir_loss_map[subdir] = loss[list(subdir_indexer)].mean()
self.log(f"train_loss_{subdir}", batched_loss, prog_bar=self.log_progress_bar)
batched_acc = self.val_accuracy(bin_output[subdir_indexer].long(), y[subdir_indexer].long())
self.log(f"train_acc_{subdir}", batched_acc, prog_bar=self.log_progress_bar)
return loss

def training_epoch_end(self, outputs):
Expand Down Expand Up @@ -261,7 +269,10 @@ def test_step(self, batch, batch_idx):
]
for metric in log_metrics:
self.log(
f"test_{metric}_{subdir}", np.mean(metrics_dict[metric]), prog_bar=self.log_progress_bar, add_dataloader_idx=False
f"test_{metric}_{subdir}",
np.mean(metrics_dict[metric]),
prog_bar=self.log_progress_bar,
add_dataloader_idx=False,
)

def compute_bin_output(self, output):
Expand Down

0 comments on commit 297d4be

Please sign in to comment.