Skip to content

Commit

Permalink
Merge pull request #7 from juglab/v0.2.3
Browse files Browse the repository at this point in the history
v0.2.3
  • Loading branch information
lmanan authored Jun 15, 2021
2 parents b7b965e + ff56e4f commit 1f8fe3e
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 1,969 deletions.
6 changes: 3 additions & 3 deletions EmbedSeg/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def begin_evaluating(test_configs, verbose=True, mask_region = None, mask_intens
if(test_configs['name']=='2d'):
test(verbose = verbose, grid_x = test_configs['grid_x'], grid_y = test_configs['grid_y'],
pixel_x = test_configs['pixel_x'], pixel_y = test_configs['pixel_y'],
one_hot = test_configs['dataset']['kwargs']['one_hot'], avg_bg = avg_bg)
one_hot = test_configs['dataset']['kwargs']['one_hot'], avg_bg = avg_bg, n_sigma=n_sigma)
elif(test_configs['name']=='3d'):
test_3d(verbose=verbose,
grid_x=test_configs['grid_x'], grid_y=test_configs['grid_y'], grid_z=test_configs['grid_z'],
Expand All @@ -70,7 +70,7 @@ def begin_evaluating(test_configs, verbose=True, mask_region = None, mask_intens



def test(verbose, grid_y=1024, grid_x=1024, pixel_y=1, pixel_x=1, one_hot = False, avg_bg = 0):
def test(verbose, grid_y=1024, grid_x=1024, pixel_y=1, pixel_x=1, one_hot = False, avg_bg = 0, n_sigma = 2):
"""
:param verbose: if True, then average prevision is printed out for each image
:param grid_y:
Expand Down Expand Up @@ -126,7 +126,7 @@ def test(verbose, grid_y=1024, grid_x=1024, pixel_y=1, pixel_x=1, one_hot = Fals

center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, sample_spatial_embedding_y, sigma_x, sigma_y, \
color_sample_dic, color_embedding_dic = prepare_embedding_for_test_image(instance_map = instance_map, output = output, grid_x = grid_x, grid_y = grid_y,
pixel_x = pixel_x, pixel_y =pixel_y, predictions =predictions)
pixel_x = pixel_x, pixel_y =pixel_y, predictions =predictions, n_sigma = n_sigma)

base, _ = os.path.splitext(os.path.basename(sample['im_name'][0]))
imageFileNames.append(base)
Expand Down
100 changes: 20 additions & 80 deletions EmbedSeg/train.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import os
import shutil

import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from EmbedSeg.criterions import get_loss
from EmbedSeg.datasets import get_dataset
from EmbedSeg.models import get_model
from EmbedSeg.utils.utils import AverageMeter, Cluster, Cluster_3d, Logger, Visualizer, prepare_embedding_for_train_image

torch.backends.cudnn.benchmark = True
from matplotlib.colors import ListedColormap
import numpy as np
Expand Down Expand Up @@ -137,7 +134,7 @@ def train_3d(virtual_batch_multiplier, one_hot, n_sigma, args):


def train_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma,
args): # this is without virtual batches!
zslice, args): # this is without virtual batches!

# define meters
loss_meter = AverageMeter()
Expand All @@ -163,50 +160,22 @@ def train_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, gr
loss_meter.update(loss.item())
if display and i % display_it == 0:
with torch.no_grad():
visualizer.display(im[0], key='image', title='Image')
visualizer.display(im[0, 0, zslice], key='image', title='Image')
predictions = cluster.cluster_with_gt(output[0], instances[0], n_sigma=n_sigma)
if one_hot:
instance = invert_one_hot(instances[0].cpu().detach().numpy())
visualizer.display(instance, key='groundtruth', title='Ground Truth') # TODO
instance_ids = np.arange(instances.size(1)) # instances[0] --> DYX
else:
visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth') # TODO
visualizer.display(instances[0, zslice].cpu(), key='groundtruth', title='Ground Truth') # TODO
instance_ids = instances[0].unique()
instance_ids = instance_ids[instance_ids != 0]

if display_embedding:
center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, \
sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic = \
prepare_embedding_for_train_image(one_hot=one_hot, grid_x=grid_x, grid_y=grid_y,
pixel_x=pixel_x, pixel_y=pixel_y,
predictions=predictions, instance_ids=instance_ids,
center_images=center_images,
output=output, instances=instances, n_sigma=n_sigma)
if one_hot:
visualizer.display(torch.max(instances[0], dim=0)[0], key='center', title='Center',
center_x=center_x,
center_y=center_y,
samples_x=samples_x, samples_y=samples_y,
sample_spatial_embedding_x=sample_spatial_embedding_x,
sample_spatial_embedding_y=sample_spatial_embedding_y,
sigma_x=sigma_x, sigma_y=sigma_y,
color_sample=color_sample_dic, color_embedding=color_embedding_dic)
else:
visualizer.display(instances[0] > 0, key='center', title='Center', center_x=center_x,
center_y=center_y,
samples_x=samples_x, samples_y=samples_y,
sample_spatial_embedding_x=sample_spatial_embedding_x,
sample_spatial_embedding_y=sample_spatial_embedding_y,
sigma_x=sigma_x, sigma_y=sigma_y,
color_sample=color_sample_dic, color_embedding=color_embedding_dic)
visualizer.display(predictions.cpu(), key='prediction', title='Prediction') # TODO
visualizer.display(predictions.cpu()[zslice, ...], key='prediction', title='Prediction') # TODO

return loss_meter.avg





def val(virtual_batch_multiplier, one_hot, n_sigma, args):
# define meters
loss_meter, iou_meter = AverageMeter(), AverageMeter()
Expand Down Expand Up @@ -248,7 +217,7 @@ def val_vanilla(display, display_embedding, display_it, one_hot, grid_x, grid_y,
if one_hot:
instance = invert_one_hot(instances[0].cpu().detach().numpy())
visualizer.display(instance, key='groundtruth', title='Ground Truth') # TODO
instance_ids = np.arange(instances[0].size(1))
instance_ids = np.arange(instances.size(1))
else:
visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth') # TODO
instance_ids = instances[0].unique()
Expand Down Expand Up @@ -284,8 +253,6 @@ def val_vanilla(display, display_embedding, display_it, one_hot, grid_x, grid_y,

return loss_meter.avg, iou_meter.avg



def val_3d(virtual_batch_multiplier, one_hot, n_sigma, args):
# define meters
loss_meter, iou_meter = AverageMeter(), AverageMeter()
Expand All @@ -306,7 +273,7 @@ def val_3d(virtual_batch_multiplier, one_hot, n_sigma, args):
return loss_meter.avg * virtual_batch_multiplier, iou_meter.avg


def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma, args):
def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma, zslice, args):
# define meters
loss_meter, iou_meter = AverageMeter(), AverageMeter()
# put model into eval mode
Expand All @@ -322,44 +289,19 @@ def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid
loss = loss.mean()
if display and i % display_it == 0:
with torch.no_grad():
visualizer.display(im[0], key='image', title='Image')
visualizer.display(im[0, 0, zslice], key='image', title='Image')
predictions = cluster.cluster_with_gt(output[0], instances[0], n_sigma=n_sigma)
if one_hot:
instance = invert_one_hot(instances[0].cpu().detach().numpy())
visualizer.display(instance, key='groundtruth', title='Ground Truth') # TODO
instance_ids = np.arange(instances[0].size(1))
instance_ids = np.arange(instances.size(1))
else:
visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth') # TODO
visualizer.display(instances[0, zslice].cpu(), key='groundtruth', title='Ground Truth') # TODO
instance_ids = instances[0].unique()
instance_ids = instance_ids[instance_ids != 0]
if (display_embedding):
center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, \
sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic = \
prepare_embedding_for_train_image(one_hot=one_hot, grid_x=grid_x, grid_y=grid_y,
pixel_x=pixel_x, pixel_y=pixel_y,
predictions=predictions, instance_ids=instance_ids,
center_images=center_images,
output=output, instances=instances, n_sigma=n_sigma)
if one_hot:
visualizer.display(torch.max(instances[0], dim=0)[0].cpu(), key='center', title='Center',
# torch.max returns a tuple
center_x=center_x,
center_y=center_y,
samples_x=samples_x, samples_y=samples_y,
sample_spatial_embedding_x=sample_spatial_embedding_x,
sample_spatial_embedding_y=sample_spatial_embedding_y,
sigma_x=sigma_x, sigma_y=sigma_y,
color_sample=color_sample_dic, color_embedding=color_embedding_dic)
else:
visualizer.display(instances[0] > 0, key='center', title='Center', center_x=center_x,
center_y=center_y,
samples_x=samples_x, samples_y=samples_y,
sample_spatial_embedding_x=sample_spatial_embedding_x,
sample_spatial_embedding_y=sample_spatial_embedding_y,
sigma_x=sigma_x, sigma_y=sigma_y,
color_sample=color_sample_dic, color_embedding=color_embedding_dic)

visualizer.display(predictions.cpu(), key='prediction', title='Prediction') # TODO

visualizer.display(predictions.cpu()[zslice, ...], key='prediction', title='Prediction') # TODO

loss_meter.update(loss.item())

Expand All @@ -375,13 +317,14 @@ def invert_one_hot(image):
return instance


def save_checkpoint(state, is_best, epoch, save_dir, save_checkpoint_frequency=None, name='checkpoint.pth'):
    """Persist the current training state to disk.

    Always overwrites the rolling "latest" checkpoint at ``save_dir/name``.
    Optionally keeps a periodic, epoch-stamped snapshot, and copies the
    latest checkpoint to ``best_iou_model.pth`` when this epoch is the best
    seen so far.

    :param state: serializable object (typically a dict of model/optimizer
        state dicts and logger data) passed straight to ``torch.save``
    :param is_best: if True, also copy this checkpoint to
        ``best_iou_model.pth`` in ``save_dir``
    :param epoch: current epoch index; used to name periodic snapshots
    :param save_checkpoint_frequency: save an extra ``"{epoch}_{name}"``
        snapshot every this many epochs (in addition to the latest/best
        files); ``None`` (default) disables periodic snapshots
    :param save_dir: directory where checkpoint files are written
        (assumed to already exist -- TODO confirm callers create it)
    :param name: filename of the rolling latest checkpoint
    """
    print('=> saving checkpoint')
    file_name = os.path.join(save_dir, name)
    torch.save(state, file_name)
    # Periodic snapshot: keeps intermediate weights from being clobbered by
    # the rolling checkpoint on later epochs.
    if save_checkpoint_frequency is not None and epoch % int(save_checkpoint_frequency) == 0:
        file_name2 = os.path.join(save_dir, str(epoch) + "_" + name)
        torch.save(state, file_name2)
    if is_best:
        shutil.copyfile(file_name, os.path.join(
            save_dir, 'best_iou_model.pth'))
Expand Down Expand Up @@ -409,7 +352,6 @@ def begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict,

# train dataloader


train_dataset = get_dataset(train_dataset_dict['name'], train_dataset_dict['kwargs'])
train_dataset_it = torch.utils.data.DataLoader(train_dataset, batch_size=train_dataset_dict['batch_size'],
shuffle=True, drop_last=True,
Expand Down Expand Up @@ -459,8 +401,6 @@ def lambda_(epoch):
configs['pixel_x'], configs['one_hot'])

# Visualizer


visualizer = Visualizer(('image', 'groundtruth', 'prediction', 'center'), color_map) # 5 keys

# Logger
Expand Down Expand Up @@ -519,10 +459,10 @@ def lambda_(epoch):
train_loss = train_vanilla_3d(display=configs['display'],
display_embedding=configs['display_embedding'],
display_it=configs['display_it'], one_hot=configs['one_hot'],
n_sigma=loss_dict['lossOpts']['n_sigma'], grid_x=configs['grid_x'],
n_sigma=loss_dict['lossOpts']['n_sigma'], zslice = configs['display_zslice'], grid_x=configs['grid_x'],
grid_y=configs['grid_y'], grid_z=configs['grid_z'],
pixel_x=configs['pixel_x'], pixel_y=configs['pixel_y'],
pixel_z=configs['pixel_z'], args=loss_dict['lossW'])
pixel_z=configs['pixel_z'], args=loss_dict['lossW'], )

if (val_dataset_dict['virtual_batch_multiplier'] > 1):
val_loss, val_iou = val_3d(virtual_batch_multiplier=val_dataset_dict['virtual_batch_multiplier'],
Expand All @@ -532,7 +472,7 @@ def lambda_(epoch):
val_loss, val_iou = val_vanilla_3d(display=configs['display'],
display_embedding=configs['display_embedding'],
display_it=configs['display_it'], one_hot=configs['one_hot'],
n_sigma=loss_dict['lossOpts']['n_sigma'], grid_x=configs['grid_x'],
n_sigma=loss_dict['lossOpts']['n_sigma'], zslice = configs['display_zslice'], grid_x=configs['grid_x'],
grid_y=configs['grid_y'], grid_z=configs['grid_z'],
pixel_x=configs['pixel_x'], pixel_y=configs['pixel_y'], pixel_z=configs['pixel_z'],
args=loss_dict['lossW'])
Expand All @@ -558,6 +498,6 @@ def lambda_(epoch):
'optim_state_dict': optimizer.state_dict(),
'logger_data': logger.data,
}
save_checkpoint(state, is_best, epoch, save_dir=configs['save_dir'])
save_checkpoint(state, is_best, epoch, save_dir=configs['save_dir'], save_checkpoint_frequency=configs['save_checkpoint_frequency'])


11 changes: 9 additions & 2 deletions EmbedSeg/utils/create_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ def create_configs(save_dir,
anisotropy_factor = None,
l_y = 1,
l_x = 1,

save_checkpoint_frequency = None,
display_zslice = None
):
"""
Creates `configs` dictionary from parameters.
Expand Down Expand Up @@ -337,6 +338,9 @@ def create_configs(save_dir,
Pixel size in y
pixel_x: int
Pixel size in x
save_checkpoint_frequency: int
Save model weights after 'n' epochs (in addition to last and best model weights)
Default is None
"""
if (n_z is None):
l_z = None
Expand All @@ -358,7 +362,10 @@ def create_configs(save_dir,
pixel_z = l_z,
pixel_y = l_y,
pixel_x = l_x,
one_hot=one_hot)
one_hot=one_hot,
save_checkpoint_frequency=save_checkpoint_frequency,
display_zslice = display_zslice
)
print(
"`configs` dictionary successfully created with: "
"\n -- n_epochs equal to {}, "
Expand Down
12 changes: 6 additions & 6 deletions EmbedSeg/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,9 @@ def prepare_embedding_for_train_image(one_hot, grid_x, grid_y, pixel_x, pixel_y,
sample_spatial_embedding_y[id.item()] = add_samples(samples_spatial_embeddings, 1, grid_y - 1,
pixel_y)

centre_mask = in_mask & center_images[0]
if (centre_mask.sum().eq(1)):
center = xym_s[centre_mask.expand_as(xym_s)].view(2, 1, 1)
center_mask = in_mask & center_images[0].byte()
if (center_mask.sum().eq(1)):
center = xym_s[center_mask.expand_as(xym_s)].view(2, 1, 1)
else:
xy_in = xym_s[in_mask.expand_as(xym_s)].view(2, -1)
center = xy_in.mean(1).view(2, 1, 1) # 2 x 1 x 1
Expand All @@ -444,7 +444,7 @@ def prepare_embedding_for_train_image(one_hot, grid_x, grid_y, pixel_x, pixel_y,
sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic


def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel_x, pixel_y, predictions):
def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel_x, pixel_y, predictions, n_sigma):
instance_ids = instance_map.unique()
instance_ids = instance_ids[instance_ids != 0]

Expand All @@ -454,7 +454,7 @@ def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel
height, width = instance_map.size(0), instance_map.size(1)
xym_s = xym[:, 0:height, 0:width].contiguous()
spatial_emb = torch.tanh(output[0, 0:2]).cpu() + xym_s
sigma = output[0, 2:2 + 2] # 2/3 Y X replace last + 2 with n_sigma parameter IMP TODO
sigma = output[0, 2:2 + n_sigma]
color_sample = sns.color_palette("dark")
color_embedding = sns.color_palette("bright")
color_sample_dic = {}
Expand Down Expand Up @@ -495,7 +495,7 @@ def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel
center_y[id.item()] = degrid(center[1], grid_y - 1, pixel_y)

# sigma
s = sigma[in_mask.expand_as(sigma)].view(2, -1).mean(1) # TODO view(2, -1) should become nsigma, -1
s = sigma[in_mask.expand_as(sigma)].view(n_sigma, -1).mean(1)
s = torch.exp(s * 10)
sigma_x_tmp = 0.5 / s[0]
sigma_y_tmp = 0.5 / s[1]
Expand Down
Loading

0 comments on commit 1f8fe3e

Please sign in to comment.