From 1e7f3cc3ec10c5e9b3bc89b14fdeb2949e45a8e2 Mon Sep 17 00:00:00 2001 From: tangy5 <58751975+tangy5@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:34:30 -0500 Subject: [PATCH] Replace 2pt5 model files on monailabel side (#12) Replace the old 2pt5 model files with latest training models networks. model.py vista_image_encoder vista_prompt_encoder --------- Signed-off-by: tangy5 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../lib/configs/vista_point_2pt5.py | 2 +- .../lib/model/vista_point_2pt5/__init__.py | 10 + .../{models_samm2pt5d.py => model_2pt5.py} | 149 ++- .../model/vista_point_2pt5/trainer_2pt5d.py | 966 ------------------ .../vista_2pt5_image_encoder.py | 138 +++ .../vista_2pt5_prompt_encoder.py | 148 +++ 6 files changed, 361 insertions(+), 1052 deletions(-) create mode 100644 monailabel/monaivista/lib/model/vista_point_2pt5/__init__.py rename monailabel/monaivista/lib/model/vista_point_2pt5/{models_samm2pt5d.py => model_2pt5.py} (77%) delete mode 100644 monailabel/monaivista/lib/model/vista_point_2pt5/trainer_2pt5d.py create mode 100644 monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_image_encoder.py create mode 100644 monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_prompt_encoder.py diff --git a/monailabel/monaivista/lib/configs/vista_point_2pt5.py b/monailabel/monaivista/lib/configs/vista_point_2pt5.py index 388e224..0ff4b4d 100644 --- a/monailabel/monaivista/lib/configs/vista_point_2pt5.py +++ b/monailabel/monaivista/lib/configs/vista_point_2pt5.py @@ -15,7 +15,7 @@ import lib.infers import lib.trainers -from lib.model.vista_point_2pt5.models_samm2pt5d import sam_model_registry +from lib.model.vista_point_2pt5.model_2pt5 import sam_model_registry from monailabel.interfaces.config import TaskConfig from monailabel.interfaces.tasks.infer_v2 import InferTask from monailabel.interfaces.tasks.train import TrainTask diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/__init__.py b/monailabel/monaivista/lib/model/vista_point_2pt5/__init__.py new file mode 100644 index 0000000..1e97f89 --- /dev/null +++ b/monailabel/monaivista/lib/model/vista_point_2pt5/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/models_samm2pt5d.py b/monailabel/monaivista/lib/model/vista_point_2pt5/model_2pt5.py similarity index 77% rename from monailabel/monaivista/lib/model/vista_point_2pt5/models_samm2pt5d.py rename to monailabel/monaivista/lib/model/vista_point_2pt5/model_2pt5.py index c5999af..08759ec 100644 --- a/monailabel/monaivista/lib/model/vista_point_2pt5/models_samm2pt5d.py +++ b/monailabel/monaivista/lib/model/vista_point_2pt5/model_2pt5.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. @@ -7,24 +18,25 @@ from functools import partial from typing import Any, Dict, List, Tuple +import monai import torch +from segment_anything.modeling import TwoWayTransformer +from segment_anything.modeling.mask_decoder import MaskDecoder from torch import nn from torch.nn import functional as F -from .segment_anything.modeling import TwoWayTransformer -from .segment_anything.modeling.image_encoder import ImageEncoderViT -from .segment_anything.modeling.mask_decoder import MaskDecoder -from .segment_anything.modeling.prompt_encoder import PromptEncoder +from .vista_2pt5_image_encoder import VistaImageEncoderViT +from .vista_2pt5_prompt_encoder import VistaPromptEncoder -class Samm2pt5D(nn.Module): +class Vista2pt5D(nn.Module): mask_threshold: float = 0.5 image_format: str = "RGB" def __init__( self, - image_encoder: ImageEncoderViT, - prompt_encoder: PromptEncoder, + image_encoder: VistaImageEncoderViT, + prompt_encoder: VistaPromptEncoder, mask_decoder: MaskDecoder, pixel_mean: List[float] = [123.675, 116.28, 103.53], pixel_std: List[float] = [58.395, 57.12, 57.375], @@ -67,7 +79,6 @@ def get_mask_prediction( for image_record, curr_embedding in zip(batched_input, image_embeddings): if "point_coords" in image_record: points = (image_record["point_coords"], image_record["point_labels"]) - # raise NotImplementedError else: points = None sparse_embeddings, dense_embeddings = self.prompt_encoder( @@ -86,7 +97,6 @@ def get_mask_prediction( high_res_masks = self.postprocess_masks( low_res_masks, - # input_size=image_record["image"].shape[-2:], original_size=image_record["original_size"], ) masks = high_res_masks > self.mask_threshold @@ -124,6 +134,8 @@ def forward( input frame of the model. 'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN. + 'labels': (torch.Tensor) Batched labels for class-label prompt, + with shape BxN. 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of the model. 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, @@ -151,7 +163,6 @@ def forward( for image_record, curr_embedding in zip(batched_input, image_embeddings): if "point_coords" in image_record: points = (image_record["point_coords"], image_record["point_labels"]) - # raise NotImplementedError else: points = None sparse_embeddings, dense_embeddings = self.prompt_encoder( @@ -194,7 +205,6 @@ def forward( def postprocess_masks( self, masks: torch.Tensor, - # input_size: Tuple[int, ...], original_size: Tuple[int, ...], ) -> torch.Tensor: """ @@ -261,7 +271,7 @@ def preprocess(self, x: torch.Tensor, is_input=True) -> torch.Tensor: return x -def _build_sam2pt5d( +def _build_vista2pt5d( encoder_in_chans, encoder_embed_dim, encoder_depth, @@ -269,13 +279,15 @@ def _build_sam2pt5d( encoder_global_attn_indexes, checkpoint=None, image_size=1024, + clip_class_label_prompt=False, + patch_embed_3d=False, ): prompt_embed_dim = 256 - image_size = image_size # TODO: Shall we try to adapt model to 512x512 ? 
+ image_size = image_size vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size - sam = Samm2pt5D( - image_encoder=ImageEncoderViT( + sam = Vista2pt5D( + image_encoder=VistaImageEncoderViT( in_chans=encoder_in_chans, depth=encoder_depth, embed_dim=encoder_embed_dim, @@ -289,12 +301,14 @@ def _build_sam2pt5d( global_attn_indexes=encoder_global_attn_indexes, window_size=14, out_chans=prompt_embed_dim, + patch_embed_3d=patch_embed_3d, ), - prompt_encoder=PromptEncoder( + prompt_encoder=VistaPromptEncoder( embed_dim=prompt_embed_dim, image_embedding_size=(image_embedding_size, image_embedding_size), input_image_size=(image_size, image_size), mask_in_chans=16, + clip_class_label_prompt=clip_class_label_prompt, ), mask_decoder=MaskDecoder( num_multimask_outputs=3, # TODO: only predict one binary mask @@ -315,22 +329,26 @@ def _build_sam2pt5d( if checkpoint is not None: with open(checkpoint, "rb") as f: state_dict = torch.load(f) + if image_size == 1024: # we try to use all pretrained weights new_dict = state_dict - if encoder_in_chans != 3: - new_dict.pop("image_encoder.patch_embed.proj.weight") else: new_dict = {} for k, v in state_dict.items(): - # skip weights in prompt_encoder and mask_decoder - if k.startswith("prompt_encoder") or k.startswith("mask_decoder"): - continue # skip weights in position embedding and learned relative positional embeddings - elif "pos_embed" in k or "attn.rel_pos" in k: + # due to the change of input size + if ("pos_embed" in k and k.startswith("image_encoder")) or ( + "attn.rel_pos" in k and k.startswith("image_encoder") + ): continue else: new_dict[k] = v + + if encoder_in_chans != 3: + new_dict.pop("image_encoder.patch_embed.proj.weight") + new_dict.pop("image_encoder.patch_embed.proj.bias") + sam.load_state_dict(new_dict, strict=False) print(f"Load {len(new_dict)} keys from checkpoint {checkpoint}, current model has {len(sam.state_dict())} keys") @@ -355,33 +373,15 @@ def _build_sam2pt5d( f"{sum(mask_decoder_params) * 1.e-6:.2f} M params in mask decoder." 
) - # comment to unfreeze all encoder layers - # for name, param in sam.named_parameters(): - # if name.startswith("image_encoder"): - # if image_size == 1024: - # if "pos_embed" in name or "patch_embed" in name or "blocks.0" in name: - # # we only retrain layers before blocks.1 in image_encoder - # continue - # # if "pos_embed" in name or "patch_embed" in name: - # # # we only retrain pos_embed and patch_embed - # # continue - # else: - # if "pos_embed" in name or "attn.rel_pos" in name or \ - # "patch_embed" in name or "blocks.0" in name or "neck" in name: - # # we only train pos_embed, patch_embed, blocks.0, attn.rel_pos (due res change) - # # and neck (a few conv layers for outputs) in image_encoder - # continue - # - # # we freeze all other layers in image_encoder - # param.requires_grad = False - total_trainable_params = sum(p.numel() if p.requires_grad else 0 for p in sam.parameters()) print(f"{sam.__class__.__name__} has {total_trainable_params * 1.e-6:.2f} M trainable params.") return sam -def build_samm2pt5d_vit_h(checkpoint=None, image_size=1024, encoder_in_chans=3): - return _build_sam2pt5d( +def build_vista2pt5d_vit_h( + checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False +): + return _build_vista2pt5d( encoder_in_chans=encoder_in_chans, encoder_embed_dim=1280, encoder_depth=32, @@ -389,11 +389,15 @@ def build_samm2pt5d_vit_h(checkpoint=None, image_size=1024, encoder_in_chans=3): encoder_global_attn_indexes=[7, 15, 23, 31], checkpoint=checkpoint, image_size=image_size, + clip_class_label_prompt=clip_class_label_prompt, + patch_embed_3d=patch_embed_3d, ) -def build_samm2pt5d_vit_l(checkpoint=None, image_size=1024, encoder_in_chans=3): - return _build_sam2pt5d( +def build_vista2pt5d_vit_l( + checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False +): + return _build_vista2pt5d( encoder_in_chans=encoder_in_chans, encoder_embed_dim=1024, encoder_depth=24, @@ -401,11 +405,15 @@ def build_samm2pt5d_vit_l(checkpoint=None, image_size=1024, encoder_in_chans=3): encoder_global_attn_indexes=[5, 11, 17, 23], checkpoint=checkpoint, image_size=image_size, + clip_class_label_prompt=clip_class_label_prompt, + patch_embed_3d=patch_embed_3d, ) -def build_samm2pt5d_vit_b(checkpoint=None, image_size=1024, encoder_in_chans=3): - return _build_sam2pt5d( +def build_vista2pt5d_vit_b( + checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False +): + return _build_vista2pt5d( encoder_in_chans=encoder_in_chans, encoder_embed_dim=768, encoder_depth=12, @@ -413,48 +421,19 @@ def build_samm2pt5d_vit_b(checkpoint=None, image_size=1024, encoder_in_chans=3): encoder_global_attn_indexes=[2, 5, 8, 11], checkpoint=checkpoint, image_size=image_size, + clip_class_label_prompt=clip_class_label_prompt, + patch_embed_3d=patch_embed_3d, ) sam_model_registry = { - "default": build_samm2pt5d_vit_h, - "vit_h": build_samm2pt5d_vit_h, - "vit_l": build_samm2pt5d_vit_l, - "vit_b": build_samm2pt5d_vit_b, + "default": build_vista2pt5d_vit_h, + "vit_h": build_vista2pt5d_vit_h, + "vit_l": build_vista2pt5d_vit_l, + "vit_b": build_vista2pt5d_vit_b, } + if __name__ == "__main__": - model = build_samm2pt5d_vit_b() + model = build_vista2pt5d_vit_b() model.cuda() - # - # dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (176, 345), - # "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda(), - # "labels": torch.ones(3, 
1).long().cuda()}, - # {"image": torch.rand(3, 128, 365).cuda(), "original_size": (128, 365), - # "point_coords": torch.rand(1, 3, 2).cuda(), "point_labels": torch.ones(1, 3).cuda(), - # "labels": torch.ones(1, 1).long().cuda()} - # ] - # # dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (256, 512), - # # "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda()}] - # outputs = model(dummy_input) - - # test if postprocessing can inverse preprocess - # path = "/home/pengfeig/Downloads/fffabebf-74fd3a1f-673b6b41-96ec0ac9-2ab69818.jpg" - # from PIL import Image - # import numpy as np - # import matplotlib.pyplot as plt - # - # image = np.array(Image.open(path)).transpose(2, 0, 1)[:, :365, :256].astype(np.float32) - # plt.imshow(image.transpose(1, 2, 0).astype(np.uint8)) - # plt.show() - # dummy_tensor = torch.from_numpy(image).cuda() - # tmp = model.preprocess(dummy_tensor) - # plt.imshow(tmp.cpu().numpy().transpose(1, 2, 0).astype(np.uint8)) - # plt.show() - # inverse_tensor = model.postprocess_masks(tmp.unsqueeze(0), (365, 256)).squeeze(0) - # print(torch.sum(torch.abs(inverse_tensor-dummy_tensor))) - # print("dummy_tensor", torch.min(dummy_tensor), torch.max(dummy_tensor)) - # print("inverse_tensor", torch.min(inverse_tensor), torch.max(inverse_tensor)) - # plt.imshow(inverse_tensor.cpu().numpy().transpose(1, 2, 0).astype(np.uint8)) - # plt.show() - # print() diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/trainer_2pt5d.py b/monailabel/monaivista/lib/model/vista_point_2pt5/trainer_2pt5d.py deleted file mode 100644 index 9c6f74a..0000000 --- a/monailabel/monaivista/lib/model/vista_point_2pt5/trainer_2pt5d.py +++ /dev/null @@ -1,966 +0,0 @@ -# Copyright 2020 - 2022 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -import random -import time -from copy import deepcopy - -import numpy as np -import torch -import torch.nn.functional as F -import torch.nn.parallel -import torch.utils.data.distributed -from monai.data import decollate_batch -from tensorboardX import SummaryWriter -from torch.cuda.amp import GradScaler, autocast -from utils.utils import AverageMeter, distributed_all_gather - - -def apply_coords(coords, original_size, sam_image_size) -> np.ndarray: - """ - Expects a numpy array of length 2 in the final dimension. Requires the - original image size in (H, W) format. - """ - old = original_size - new = sam_image_size - coords = deepcopy(coords).astype(float) - # Here, we can apply a same scale factor to h and w, because we first pad the input to a square image along the - # longest side then resize it to sam_image_size. In other words, the scale factor is determined by the longest side. - coords[..., 0] = coords[..., 0] * (new / old) - coords[..., 1] = coords[..., 1] * (new / old) - return coords - - -def apply_coords_torch(coords, original_size, sam_image_size) -> np.ndarray: - """ - Expects a numpy array of length 2 in the final dimension. 
Requires the - original image size in (H, W) format. - """ - old = original_size - new = sam_image_size - coords = deepcopy(coords).float() - # Here, we can apply a same scale factor to h and w, because we first pad the input to a square image along the - # longest side then resize it to sam_image_size. In other words, the scale factor is determined by the longest side. - coords[..., 0] = coords[..., 0] * (new / old) - coords[..., 1] = coords[..., 1] * (new / old) - return coords - - -def sample_points(labelpoints, n_points): - idx = torch.randperm(len(labelpoints), dtype=torch.long, device=labelpoints.device)[:n_points] - return [labelpoints[idx]] - - -def generate_point_prompt_train(batch_labels_, args): - max_point = args.max_points - Np = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) - Nn = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) - # To follow original SAM, with equal probability either a foreground point - # is selected randomly for the target mask - _point = [] - _point_label = [] - b, h, w = batch_labels_.shape - device = batch_labels_.device - for i in range(b): - # if Np == 0 and Nn == 0: - # # _point.append(np.stack([np.array([0, 0])])) - # # _point_label.append(torch.tensor([-1])) - # _point.append(torch.stack([torch.tensor([0, 0])]).to(device)) - # _point_label.append(torch.tensor([-1]).to(device)) - # else: - plabels = batch_labels_[i, ...] - nlabels = (plabels == 0.0).float() - plabelpoints = torch.nonzero(plabels) - nlabelpoints = torch.nonzero(nlabels) - # 1 indicates a foreground point and 0 indicates a background point. - # -1 indicates a dummy non-point as the placeholder. - n_placeholder = Np + Nn - min(len(plabelpoints), Np) - min(len(nlabelpoints), Nn) - # _point.append(np.stack( - # random.choices(plabelpoints, k=min(len(plabelpoints), Np)) + - # random.choices(nlabelpoints, k=min(len(nlabelpoints), Nn)) + - # [np.array([0, 0])] * n_placeholder)) - # _point_label.append(torch.tensor( - # [1] * min(len(plabelpoints), Np) + [0] * min(len(nlabelpoints), Nn) + [-1] * n_placeholder)) - - # Use torch.randperm to generate indices on a GPU tensor - _point.append( - torch.cat( - sample_points(plabelpoints, min(len(plabelpoints), Np)) - + sample_points(nlabelpoints, min(len(nlabelpoints), Nn)) - + [torch.zeros((1, 2), device=device)] * n_placeholder, - dim=0, - ) - ) - _point_label.append( - torch.tensor([1] * min(len(plabelpoints), Np) + [0] * min(len(nlabelpoints), Nn) + [-1] * n_placeholder).to( - device - ) - ) - - # point = np.stack(_point) - point = torch.stack(_point) - point_label = torch.stack(_point_label) - # point_coords = torch.as_tensor(apply_coords(point, max(h, w), args.sam_image_size), dtype=torch.float) - point_coords = apply_coords_torch(point, max(h, w), args.sam_image_size) - - return point_coords, point_label - - -def prepare_sam_training_input(inputs, labels, args, model): - # start_time = time.time() - - unique_labels = torch.unique(labels).as_tensor().long() - - # TODO: (Note!) We don't skip background fow now. - # we skip the slice only having background - skip = False - # if len(unique_labels) == 1: - # # skip = True - # unique_labels = torch.randint(low=1, high=105, size=(args.num_prompt,)) - # else: - # # we skip the background class when multiple labels exist - # unique_labels = unique_labels[1:] - - # TODO: (Note!) 
We cannot make up prompts, since if we random sample a lot of prompts that is not existing in this - # slice, the network will tend to learn nothing (giving every prompt a black mask). - # if len(unique_labels) < args.num_prompt: - # make the number of labels equals to args.num_prompt - # makeup_labels = torch.randint(low=1, high=105, size=(args.num_prompt - len(unique_labels),)) - # unique_labels = torch.cat([unique_labels, makeup_labels], dim=0) - - # random sample args.num_prompt prompts, this will help to manage the GPU memory upper bound. - if len(unique_labels) > args.num_prompt: - idxs = random.sample(range(len(unique_labels)), args.num_prompt) - idxs = torch.tensor(idxs) - unique_labels = unique_labels[idxs] - if len(unique_labels) < args.num_prompt: - while len(unique_labels) < args.num_prompt: - unique_labels = torch.cat([unique_labels, unique_labels], 0) - unique_labels = unique_labels[: args.num_prompt] - - # add one background label to every batch - background_labels = list(set([i for i in range(105)]) - set(unique_labels.cpu().numpy())) - random.shuffle(background_labels) - unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:1]).cuda(args.rank)]) - - # preprocess make the size of label same as low_res_logit - batch_labels_ = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float() - - if args.distributed: - batch_labels = model.module.preprocess(batch_labels_, is_input=False) - else: - batch_labels = model.preprocess(batch_labels_, is_input=False) - - # TODO: we currently only use class-label and points prompt. - - prepared_input = [{"image": inputs, "original_size": tuple(labels.shape)}] - if args.label_prompt: - labels_prompt = unique_labels.unsqueeze(-1) - prepared_input[0].update({"labels": labels_prompt}) - - if args.point_prompt: - point_coords, point_labels = generate_point_prompt_train(batch_labels_, args) - prepared_input[0].update({"point_coords": point_coords, "point_labels": point_labels}) - - if args.label_prompt and args.point_prompt: - # if we use both two kinds of prompts, then we randomly drop one kind. - if random.uniform(0, 1) < args.drop_label_prob: - prepared_input[0].pop("labels") - else: - if random.uniform(0, 1) < args.drop_point_prob: - prepared_input[0].pop("point_coords") - prepared_input[0].pop("point_labels") - - return prepared_input, batch_labels.unsqueeze(1).cuda(args.rank), skip - - -def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args): - model.train() - start_time = time.time() - run_loss = AverageMeter() - # we need to make sure the number of 2.5D input is an odd number. - assert args.roi_z_iter % 2 == 1 - for idx, batch_data in enumerate(loader): - # only take 1 batch - inputs_l = batch_data["image"] - labels_l = batch_data["label"] - # TODO: we only support batch_size = 1 for data loader. - inputs_l = inputs_l.squeeze() - labels_l = labels_l.squeeze() - n_z_before_pad = labels_l.shape[-1] - # pad the z direction, so we can easily extract 2.5D input and predict labels for the center slice - pd = (args.roi_z_iter // 2, args.roi_z_iter // 2) - inputs_l = F.pad(inputs_l, pd, "constant", 0) - labels_l = F.pad(labels_l, pd, "constant", 0) - _loss = torch.tensor(0.0).cuda(args.rank) - for _k in range(min(args.num_patch, n_z_before_pad)): - # Return random integers from `low` (inclusive) to `high` (exclusive). 
- start_idx = int(np.random.randint(low=args.roi_z_iter // 2, high=(args.roi_z_iter // 2 + n_z_before_pad))) - - inputs = inputs_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1].permute( - 2, 0, 1 - ) - # we only need the label for the center slice - labels = labels_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1][ - ..., args.roi_z_iter // 2 - ] - - data, target, skip = prepare_sam_training_input(inputs.cuda(args.rank), labels.cuda(args.rank), args, model) - - for param in model.parameters(): - param.grad = None - - with autocast(enabled=args.amp): - outputs = model(data, is_train=True) - loss = loss_func(outputs[0]["low_res_logits"], target) - - if args.amp: - scaler.scale(loss).backward() - if args.clip is not None: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - if args.clip is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - optimizer.step() - - _loss += loss.detach() - _loss /= min(args.num_patch, n_z_before_pad) - if args.distributed: - loss_list = distributed_all_gather( - [_loss], - out_numpy=True, # is_valid=idx < loader.sampler.valid_length - ) - run_loss.update( - np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size - ) - else: - run_loss.update(_loss.item(), n=args.num_patch) - if args.rank == 0: - print( - "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), - "loss: {:.4f}".format(run_loss.avg), - "time {:.2f}s".format(time.time() - start_time), - ) - start_time = time.time() - for param in model.parameters(): - param.grad = None - return run_loss.avg - - -def generate_point_prompt_train_iterative(batch_labels_, args, points_pos=None, points_neg=None, previous_pred=None): - max_point = args.max_points - Np = ( - points_pos - if points_pos is not None - else min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) - ) - Nn = points_neg if points_neg is not None else min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) - # To follow original SAM, with equal probability either a foreground point - # is selected randomly for the target mask - _point = [] - _point_label = [] - b, h, w = batch_labels_.shape - device = batch_labels_.device - for i in range(b): - plabels = batch_labels_[i, ...] - nlabels = (plabels == 0.0).float() - if previous_pred is not None: - ppred = previous_pred[i, 0, ...] - npred = (previous_pred[i, 0, ...] == 0.0).float() - - # False positive mask (pixels that are predicted as positive but are actually negative) - fp_mask = torch.logical_and(nlabels, ppred) - # False negative mask (pixels that are predicted as negative but are actually positive) - fn_mask = torch.logical_and(plabels, npred) - # we sample positive points from false negative pred. - # we sample negative points from false positive pred. - plabelpoints = torch.nonzero(fn_mask) - nlabelpoints = torch.nonzero(fp_mask) - - else: - plabelpoints = torch.nonzero(plabels) - nlabelpoints = torch.nonzero(nlabels) - # 1 indicates a foreground point and 0 indicates a background point. - # -1 indicates a dummy non-point as the placeholder. 
- n_placeholder = Np + Nn - min(len(plabelpoints), Np) - min(len(nlabelpoints), Nn) - - # Use torch.randperm to generate indices on a GPU tensor - _point.append( - torch.cat( - sample_points(plabelpoints, min(len(plabelpoints), Np)) - + sample_points(nlabelpoints, min(len(nlabelpoints), Nn)) - + [torch.zeros((1, 2), device=device)] * n_placeholder, - dim=0, - ) - ) - _point_label.append( - torch.tensor([1] * min(len(plabelpoints), Np) + [0] * min(len(nlabelpoints), Nn) + [-1] * n_placeholder).to( - device - ) - ) - - point = torch.stack(_point) - point_label = torch.stack(_point_label) - point_coords = apply_coords_torch(point, max(h, w), args.sam_image_size) - - return point_coords, point_label - - -def prepare_sam_iterative_training_input(inputs, labels, args, model): - unique_labels = torch.unique(labels).as_tensor().long() - - # random sample args.num_prompt prompts, this will help to manage the GPU memory upper bound. - if len(unique_labels) > args.num_prompt: - idxs = random.sample(range(len(unique_labels)), args.num_prompt) - idxs = torch.tensor(idxs) - unique_labels = unique_labels[idxs] - if len(unique_labels) < args.num_prompt: - while len(unique_labels) < args.num_prompt: - unique_labels = torch.cat([unique_labels, unique_labels], 0) - unique_labels = unique_labels[: args.num_prompt] - - # add one background label to every batch - background_labels = list(set([i for i in range(105)]) - set(unique_labels.cpu().numpy())) - random.shuffle(background_labels) - unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:1]).cuda(args.rank)]) - - # preprocess make the size of label same as low_res_logit - batch_labels_ = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float() - - if args.distributed: - batch_labels = model.module.preprocess(batch_labels_, is_input=False) - else: - batch_labels = model.preprocess(batch_labels_, is_input=False) - - # TODO: we currently only use class-label and points prompt. - - prepared_input = [{"image": inputs, "original_size": tuple(labels.shape)}] - if args.label_prompt: - labels_prompt = unique_labels.unsqueeze(-1) - prepared_input[0].update({"labels": labels_prompt}) - - if args.point_prompt: - point_coords, point_labels = generate_point_prompt_train_iterative(batch_labels_, args) - prepared_input[0].update({"point_coords": point_coords, "point_labels": point_labels}) - - if args.label_prompt and args.point_prompt: - # if we use both two kinds of prompts, then we randomly drop one kind. - if random.uniform(0, 1) < args.drop_label_prob: - prepared_input[0].pop("labels") - else: - if random.uniform(0, 1) < args.drop_point_prob: - prepared_input[0].pop("point_coords") - prepared_input[0].pop("point_labels") - - return prepared_input, batch_labels.unsqueeze(1).cuda(args.rank), batch_labels_ - - -def train_epoch_iterative(model, loader, optimizer, scaler, epoch, loss_func, args): - model.train() - start_time = time.time() - run_loss = AverageMeter() - # we need to make sure the number of 2.5D input is an odd number. - assert args.roi_z_iter % 2 == 1 - for idx, batch_data in enumerate(loader): - # only take 1 batch - inputs_l = batch_data["image"] - labels_l = batch_data["label"] - # TODO: we only support batch_size = 1 for data loader. 
- inputs_l = inputs_l.squeeze() - labels_l = labels_l.squeeze() - n_z_before_pad = labels_l.shape[-1] - # pad the z direction, so we can easily extract 2.5D input and predict labels for the center slice - pd = (args.roi_z_iter // 2, args.roi_z_iter // 2) - inputs_l = F.pad(inputs_l, pd, "constant", 0) - labels_l = F.pad(labels_l, pd, "constant", 0) - _loss = torch.tensor(0.0).cuda(args.rank) - for _k in range(min(args.num_patch, n_z_before_pad)): - # Return random integers from `low` (inclusive) to `high` (exclusive). - start_idx = int(np.random.randint(low=args.roi_z_iter // 2, high=(args.roi_z_iter // 2 + n_z_before_pad))) - - inputs = inputs_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1].permute( - 2, 0, 1 - ) - # we only need the label for the center slice - labels = labels_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1][ - ..., args.roi_z_iter // 2 - ] - - data, target, target_original = prepare_sam_iterative_training_input( - inputs.cuda(args.rank), labels.cuda(args.rank), args, model - ) - for param in model.parameters(): - param.grad = None - - with autocast(enabled=args.amp): - if args.distributed: - image_embeddings = model.module.get_image_embeddings(data) - else: - image_embeddings = model.get_image_embeddings(data) - - # iterative training - loss = 0 - drop_iter = random.randint(0, args.num_iterative_step - 2) - for i in range(args.num_iterative_step): - with autocast(enabled=args.amp): - if args.distributed: - outputs = model.module.get_mask_prediction(data, image_embeddings) - else: - outputs = model.get_mask_prediction(data, image_embeddings) - loss += loss_func(outputs[0]["low_res_logits"], target) - if i == args.num_iterative_step - 1: - # no need to perform the following operations after the last step - continue - # we also supply the mask prediction from the previous iteration - # as an additional prompt to our model (follow original SAM). - data[0]["mask_inputs"] = outputs[0]["low_res_logits"].detach() - if i == drop_iter: - # for the next iter, no additional points are sampled (follow original SAM). - continue - - previous_point_coords = data[0].get("point_coords", None) - previous_point_labels = data[0].get("point_labels", None) - - if previous_point_coords is None and args.no_more_points_for_cp_only: - # if no point prompt at the first prompt generation, - # we will not add more additional pointa during iterative training. 
- continue - - # sample one pos and on neg point based on previous prediction - previous_pred = (F.sigmoid(outputs[0]["high_res_logits"].detach()) > 0.5).float() - point_coords, point_labels = generate_point_prompt_train_iterative( - target_original, args=args, points_pos=1, points_neg=1, previous_pred=previous_pred - ) - - if previous_point_coords is not None: - data[0]["point_coords"] = torch.cat([previous_point_coords, point_coords], dim=1) - data[0]["point_labels"] = torch.cat([previous_point_labels, point_labels], dim=1) - else: - data[0]["point_coords"] = point_coords - data[0]["point_labels"] = point_labels - - if args.amp: - scaler.scale(loss).backward() - if args.clip is not None: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - if args.clip is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - optimizer.step() - - _loss += loss.detach() / args.num_iterative_step - _loss /= min(args.num_patch, n_z_before_pad) - if args.distributed: - loss_list = distributed_all_gather( - [_loss], - out_numpy=True, # is_valid=idx < loader.sampler.valid_length - ) - run_loss.update( - np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size - ) - else: - run_loss.update(_loss.item(), n=args.num_patch) - if args.rank == 0: - print( - "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), - "loss: {:.4f}".format(run_loss.avg), - "time {:.2f}s".format(time.time() - start_time), - ) - start_time = time.time() - for param in model.parameters(): - param.grad = None - return run_loss.avg - - -def train_epoch_iterative_not_reuse_img_embedding(model, loader, optimizer, scaler, epoch, loss_func, args): - model.train() - start_time = time.time() - run_loss = AverageMeter() - # we need to make sure the number of 2.5D input is an odd number. - assert args.roi_z_iter % 2 == 1 - for idx, batch_data in enumerate(loader): - # only take 1 batch - inputs_l = batch_data["image"] - labels_l = batch_data["label"] - # TODO: we only support batch_size = 1 for data loader. - inputs_l = inputs_l.squeeze() - labels_l = labels_l.squeeze() - n_z_before_pad = labels_l.shape[-1] - # pad the z direction, so we can easily extract 2.5D input and predict labels for the center slice - pd = (args.roi_z_iter // 2, args.roi_z_iter // 2) - inputs_l = F.pad(inputs_l, pd, "constant", 0) - labels_l = F.pad(labels_l, pd, "constant", 0) - _loss = torch.tensor(0.0).cuda(args.rank) - for _k in range(min(args.num_patch, n_z_before_pad)): - # Return random integers from `low` (inclusive) to `high` (exclusive). 
- start_idx = int(np.random.randint(low=args.roi_z_iter // 2, high=(args.roi_z_iter // 2 + n_z_before_pad))) - - inputs = inputs_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1].permute( - 2, 0, 1 - ) - # we only need the label for the center slice - labels = labels_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1][ - ..., args.roi_z_iter // 2 - ] - - data, target, target_original = prepare_sam_iterative_training_input( - inputs.cuda(args.rank), labels.cuda(args.rank), args, model - ) - for param in model.parameters(): - param.grad = None - - # iterative training - loss_accum = 0 - drop_iter = random.randint(0, args.num_iterative_step - 2) - for i in range(args.num_iterative_step): - with autocast(enabled=args.amp): - if args.distributed: - image_embeddings = model.module.get_image_embeddings(data) - outputs = model.module.get_mask_prediction(data, image_embeddings) - else: - image_embeddings = model.get_image_embeddings(data) - outputs = model.get_mask_prediction(data, image_embeddings) - loss = loss_func(outputs[0]["low_res_logits"], target) - if i == args.num_iterative_step - 1: - # no need to perform the following operations after the last step - continue - # we also supply the mask prediction from the previous iteration - # as an additional prompt to our model (follow original SAM). - data[0]["mask_inputs"] = outputs[0]["low_res_logits"].detach() - if i == drop_iter: - # for the next iter, no additional points are sampled (follow original SAM). - continue - - previous_point_coords = data[0].get("point_coords", None) - previous_point_labels = data[0].get("point_labels", None) - - if previous_point_coords is None and args.no_more_points_for_cp_only: - # if no point prompt at the first prompt generation, - # we will not add more additional pointa during iterative training. 
- continue - - # sample one pos and on neg point based on previous prediction - previous_pred = (F.sigmoid(outputs[0]["high_res_logits"].detach()) > 0.5).float() - point_coords, point_labels = generate_point_prompt_train_iterative( - target_original, args=args, points_pos=1, points_neg=1, previous_pred=previous_pred - ) - - if previous_point_coords is not None: - data[0]["point_coords"] = torch.cat([previous_point_coords, point_coords], dim=1) - data[0]["point_labels"] = torch.cat([previous_point_labels, point_labels], dim=1) - else: - data[0]["point_coords"] = point_coords - data[0]["point_labels"] = point_labels - - if args.amp: - scaler.scale(loss).backward() - if args.clip is not None: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - if args.clip is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - optimizer.step() - - loss_accum += loss.detach() - - _loss += loss_accum / args.num_iterative_step - _loss /= min(args.num_patch, n_z_before_pad) - if args.distributed: - loss_list = distributed_all_gather( - [_loss], - out_numpy=True, # is_valid=idx < loader.sampler.valid_length - ) - run_loss.update( - np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size - ) - else: - run_loss.update(_loss.item(), n=args.num_patch) - if args.rank == 0: - print( - "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), - "loss: {:.4f}".format(run_loss.avg), - "time {:.2f}s".format(time.time() - start_time), - ) - start_time = time.time() - for param in model.parameters(): - param.grad = None - return run_loss.avg - - -def generate_point_prompt_val(batch_labels_, args, previous_pred=None): - Np = args.points_val_pos - Nn = args.points_val_neg - # To follow original SAM, with equal probability either a foreground point - # is selected randomly for the target mask - _point = [] - _point_label = [] - b, h, w = batch_labels_.shape - device = batch_labels_.device - for i in range(b): - plabels = batch_labels_[i, ...] - nlabels = (plabels == 0.0).float() - if previous_pred is not None: - ppred = previous_pred[i, 0, ...] - npred = (previous_pred[i, 0, ...] == 0.0).float() - - # False positive mask (pixels that are predicted as positive but are actually negative) - fp_mask = torch.logical_and(nlabels, ppred) - # False negative mask (pixels that are predicted as negative but are actually positive) - fn_mask = torch.logical_and(plabels, npred) - # we sample positive points from false negative pred. - # we sample negative points from false positive pred. - plabelpoints = torch.nonzero(fn_mask) - nlabelpoints = torch.nonzero(fp_mask) - - else: - plabelpoints = torch.nonzero(plabels) - nlabelpoints = torch.nonzero(nlabels) - # 1 indicates a foreground point and 0 indicates a background point. - # -1 indicates a dummy non-point as the placeholder. 
- n_placeholder = Np + Nn - min(len(plabelpoints), Np) - min(len(nlabelpoints), Nn) - - # Use torch.randperm to generate indices on a GPU tensor - _point.append( - torch.cat( - sample_points(plabelpoints, min(len(plabelpoints), Np)) - + sample_points(nlabelpoints, min(len(nlabelpoints), Nn)) - + [torch.zeros((1, 2), device=device)] * n_placeholder, - dim=0, - ) - ) - _point_label.append( - torch.tensor([1] * min(len(plabelpoints), Np) + [0] * min(len(nlabelpoints), Nn) + [-1] * n_placeholder).to( - device - ) - ) - - # point = torch.stack(_point) - point = torch.stack(_point) - point_label = torch.stack(_point_label) - # point_coords = torch.as_tensor(apply_coords(point, max(h, w), args.sam_image_size), dtype=torch.float) - point_coords = apply_coords_torch(point, max(h, w), args.sam_image_size) - - return point_coords, point_label - - -def prepare_sam_val_input(inputs, labels, args, previous_pred=None): - # Don't exclude background in val but will ignore it in metric calculation - unique_labels = torch.unique(labels).as_tensor().long() - - # preprocess make the size of lable same as high_res_logit - batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float() - - # TODO: we currently only use class label prompt. - # prepared_input = [{"image": inputs, "original_size": tuple(labels.shape), - # "labels": labels_prompt}] - # if args.point_prompt: - # point_coords, point_labels = generate_point_prompt_val(batch_labels, args) - # prepared_input[0].update( - # {"point_coords": point_coords, "point_labels": point_labels}) - - prepared_input = [{"image": inputs, "original_size": tuple(labels.shape)}] - if args.label_prompt: - labels_prompt = unique_labels.unsqueeze(-1) - prepared_input[0].update({"labels": labels_prompt}) - - if args.point_prompt: - point_coords, point_labels = generate_point_prompt_val(batch_labels, args, previous_pred) - prepared_input[0].update({"point_coords": point_coords, "point_labels": point_labels}) - - return prepared_input, batch_labels.unsqueeze(1).cuda(args.rank), unique_labels - - -def val_epoch(model, loader, epoch, acc_func, args, iterative=False, post_label=None, post_pred=None): - model.eval() - run_acc = AverageMeter() - start_time = time.time() - with torch.no_grad(): - for idx, batch_data in enumerate(loader): - # only take 1 batch - inputs_l = batch_data["image"] - labels_l = batch_data["label"] - labels_l.shape[-1] - # assert n_z_before_pad >= args.num_patch_val + args.roi_z_iter - - # TODO: we only support batch_size = 1 for data loader. 
- inputs_l = inputs_l.squeeze() - labels_l = labels_l.squeeze() - # pad the z direction, so we can easily extract 2.5D input and predict labels for the center slice - pd = (args.roi_z_iter // 2, args.roi_z_iter // 2) - inputs_l = F.pad(inputs_l, pd, "constant", 0) - labels_l = F.pad(labels_l, pd, "constant", 0) - n_z_after_pad = labels_l.shape[-1] - - acc_sum_total = 0.0 - not_nans_total = 0.0 - # We only loop the center args.num_patch_val slices to save val time - for start_idx in range( - n_z_after_pad // 2 - args.num_patch_val // 2, n_z_after_pad // 2 + args.num_patch_val // 2 - ): - inputs = inputs_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1].permute( - 2, 0, 1 - ) - # we only need the label for the center slice - labels = labels_l[..., start_idx - args.roi_z_iter // 2 : start_idx + args.roi_z_iter // 2 + 1][ - ..., args.roi_z_iter // 2 - ] - - if iterative: - # first inference to get cp only results - args.label_prompt = True - args.point_prompt = False - - data, target, _ = prepare_sam_val_input(inputs.cuda(args.rank), labels.cuda(args.rank), args) - - if len(target) == 1: - # skip the prediction only having bk - continue - - with autocast(enabled=args.amp): - outputs = model(data) - logit = outputs[0]["high_res_logits"] - - y_pred = torch.stack(post_pred(decollate_batch(logit)), 0) - - if iterative: - # second inference to get refined results - args.label_prompt = True - args.point_prompt = True - - point_coords, point_labels = generate_point_prompt_val(target.squeeze(1), args, y_pred) - data[0]["point_coords"] = point_coords - data[0]["point_labels"] = point_labels - data[0]["mask_inputs"] = outputs[0]["low_res_logits"] - - with autocast(enabled=args.amp): - outputs = model(data) - logit = outputs[0]["high_res_logits"] - - y_pred = torch.stack(post_pred(decollate_batch(logit)), 0) - - # TODO: we compute metric for each prompt for simplicity in validation. - # Hacking into DiceMetric to compute dice for each single case. 
- acc_batch = acc_func._compute_tensor(y_pred=y_pred, y=target) - acc_sum, not_nans = torch.sum(acc_batch).item(), len(acc_batch) - acc_sum_total += acc_sum - not_nans_total += not_nans - - acc, not_nans = acc_sum_total / not_nans_total, not_nans_total - f_name = batch_data["image"].meta["filename_or_obj"] - print(f"Rank: {args.rank}, Case: {f_name}, Acc: {acc:.4f}, N_prompts: {int(not_nans)} ") - - acc = torch.tensor(acc).cuda(args.rank) - not_nans = torch.tensor(not_nans).cuda(args.rank) - - if args.distributed: - acc_list, not_nans_list = distributed_all_gather( - [acc, not_nans], out_numpy=True # , is_valid=idx < loader.sampler.valid_length - ) - for al, nl in zip(acc_list, not_nans_list): - run_acc.update(al, n=nl) - - else: - run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy()) - - if args.rank == 0: - avg_acc = np.mean(run_acc.avg) - print( - "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx + 1, len(loader)), - "acc", - avg_acc, - "time {:.2f}s".format(time.time() - start_time), - ) - start_time = time.time() - return run_acc.avg - - -def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None): - state_dict = model.state_dict() if not args.distributed else model.module.state_dict() - save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict} - if optimizer is not None: - save_dict["optimizer"] = optimizer.state_dict() - if scheduler is not None: - save_dict["scheduler"] = scheduler.state_dict() - filename = os.path.join(args.logdir, filename) - torch.save(save_dict, filename) - print("Saving checkpoint", filename) - - -def run_training( - model, - train_loader, - val_loader, - optimizer, - loss_func, - acc_func, - args, - scheduler=None, - start_epoch=0, - post_label=None, - post_pred=None, -): - writer = None - if args.logdir is not None and args.rank == 0: - writer = SummaryWriter(log_dir=args.logdir) - if args.rank == 0: - print("Writing Tensorboard logs to ", args.logdir) - scaler = None - if args.amp: - scaler = GradScaler() - val_acc_max = 0.0 - # best_val_MA = - np.inf - best_epoch = -1 - val_MA = None - best_log = {} - # best_MA_log = {} - for epoch in range(start_epoch, args.max_epochs): - if args.distributed: - # train_loader.sampler.set_epoch(epoch) - torch.distributed.barrier() - print(args.rank, time.ctime(), "Epoch:", epoch) - epoch_time = time.time() - if args.rank == 0: - if scheduler is not None: - print("Current lr:", scheduler.get_last_lr()) - else: - print("Current lr:", optimizer.param_groups[0]["lr"]) - - if args.label_prompt: - if epoch < args.label_prompt_warm_up_epoch: - # during warm up, we drop class label prompt embedding with less prob, - # since class label prompt embedding layer is trained from scratch. 
- args.drop_label_prob = 0.2 - args.drop_point_prob = 0.5 - # args.point_prompt = False - else: - # after warmp up, we evenly drop two kinds of prompts - args.drop_label_prob = 0.5 - args.drop_point_prob = 0.5 - # args.point_prompt = True - print( - "rank:", - args.rank, - "label_prompt (train):", - args.label_prompt, - ", label_drop_prob:", - args.drop_label_prob, - "| point_prompt (train):", - args.point_prompt, - ", point_drop_prob:", - args.drop_point_prob, - ) - - # we don't perform iterative training for the first args.iterative_training_warm_up_epoch epochs - if epoch > args.iterative_training_warm_up_epoch: - if args.reuse_img_embedding: - if args.rank == 0: - print("Iterative Training: Reuse image embedding!") - train_loss = train_epoch_iterative( - model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args - ) - else: - if args.rank == 0: - print("Iterative Training: Don't reuse image embedding!") - train_loss = train_epoch_iterative_not_reuse_img_embedding( - model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args - ) - else: - if args.rank == 0: - print("Single-step Training") - train_loss = train_epoch( - model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args - ) - - if args.rank == 0: - print( - "Final training {}/{}".format(epoch, args.max_epochs - 1), - "loss: {:.4f}".format(train_loss), - "time {:.2f}s".format(time.time() - epoch_time), - ) - if args.rank == 0 and writer is not None: - writer.add_scalar("train_loss", train_loss, epoch) - - if (epoch + 1) % args.val_every == 0: - if args.distributed: - torch.distributed.barrier() - if args.rank == 0: - print("Start validation") - print("label_prompt (val):", args.label_prompt, "point_prompt (val):", args.point_prompt) - epoch_time = time.time() - val_avg_acc = val_epoch( - model, - val_loader, - iterative=False, - epoch=epoch, - acc_func=acc_func, - args=args, - post_label=post_label, - post_pred=post_pred, - ) - - val_avg_acc = np.mean(val_avg_acc) - if val_MA is None: - val_MA = val_avg_acc - else: - val_MA = 0.9 * val_MA + 0.1 * val_avg_acc - if args.rank == 0: - print( - "Final validation {}/{},".format(epoch, args.max_epochs - 1), - f"Acc {val_avg_acc:.4f},", - f"mv Acc {val_MA:.4f},", - "Previous Best validation at epoch {} is {:.4f},".format(best_epoch, val_acc_max), - "time {:.2f}s".format(time.time() - epoch_time), - ) - if writer is not None: - writer.add_scalar("val_acc", val_avg_acc, epoch) - if val_avg_acc > val_acc_max: - print("new best ({:.6f} --> {:.6f}). 
".format(val_acc_max, val_avg_acc)) - val_acc_max = val_avg_acc - best_log[epoch] = float(val_acc_max) - best_epoch = epoch - if args.rank == 0 and args.logdir is not None and args.save_checkpoint: - save_checkpoint( - model, - epoch, - args, - best_acc=val_acc_max, - filename="model_best.pt", - optimizer=optimizer, - scheduler=scheduler, - ) - with open(os.path.join(args.logdir, "train.log"), "w") as f: - json.dump(best_log, f) - if args.rank == 0 and args.logdir is not None and args.save_checkpoint: - save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt") - - if scheduler is not None: - scheduler.step() - - if args.rank == 0 and writer is not None: - writer.close() - - print("Training Finished !, Best Accuracy: ", val_acc_max, "at epoch", best_epoch) - - return val_acc_max diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_image_encoder.py b/monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_image_encoder.py new file mode 100644 index 0000000..166794e --- /dev/null +++ b/monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_image_encoder.py @@ -0,0 +1,138 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple, Type + +import torch +import torch.nn as nn +from segment_anything.modeling.image_encoder import ImageEncoderViT, PatchEmbed + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py +class VistaImageEncoderViT(ImageEncoderViT): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + patch_embed_3d: bool = False, + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + patch_embed_3d (bool): If True, use 3D Patch Embedding. + """ + super().__init__( + img_size, + patch_size, + in_chans, + embed_dim, + depth, + num_heads, + mlp_ratio, + out_chans, + qkv_bias, + norm_layer, + act_layer, + use_abs_pos, + use_rel_pos, + rel_pos_zero_init, + window_size, + global_attn_indexes, + ) + + self.img_size = img_size + + if in_chans > 3 and patch_embed_3d: + print("ImageEncoderViT: Using 3D PatchEmbed") + self.patch_embed = PatchEmbed2pt5D( + kernel_size=(patch_size, patch_size, in_chans // 3), + stride=(patch_size, patch_size, in_chans // 3), + in_chans=3, + embed_dim=embed_dim, + ) + else: + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + +class PatchEmbed2pt5D(nn.Module): + """ + Image to Patch Embedding by 3D Conv. + """ + + def __init__( + self, + kernel_size: Tuple[int, int, int] = (16, 16, 1), + stride: Tuple[int, int, int] = (16, 16, 1), + padding: Tuple[int, int, int] = (0, 0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # got restore RGB channel dim and the depth dim + c = x.shape[1] + x = torch.stack(x.chunk(c // 3, dim=1), dim=-1) + x = self.proj(x) + # remove dummy depth dim to make it 2d + x = x.squeeze(-1) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_prompt_encoder.py b/monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_prompt_encoder.py new file mode 100644 index 0000000..16958e2 --- /dev/null +++ b/monailabel/monaivista/lib/model/vista_point_2pt5/vista_2pt5_prompt_encoder.py @@ -0,0 +1,148 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, Optional, Tuple, Type + +import numpy as np +import torch +from segment_anything.modeling.common import LayerNorm2d +from segment_anything.modeling.prompt_encoder import PromptEncoder +from torch import nn + + +class VistaPromptEncoder(PromptEncoder): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + n_classes: int = 512, + clip_class_label_prompt: bool = False, + ) -> None: + """ + Encodes prompts for input to Segment Anything Model's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + n_classes (int): The number of pre-defined classes. + clip_class_label_prompt (bool): Using clip txt features + as class label prompt. + """ + super().__init__(embed_dim, image_embedding_size, input_image_size, mask_in_chans, activation) + + self.clip_class_label_prompt = clip_class_label_prompt + # Add support for onehot vector embedding for pre-defined classes + if self.clip_class_label_prompt: + raise NotImplementedError + else: + self.label_embeddings = nn.Embedding(n_classes, embed_dim) + self.no_label_embed = nn.Embedding(1, embed_dim) + + def _embed_labels(self, labels: torch.Tensor) -> torch.Tensor: + """Embeds onehot vector inputs.""" + if self.clip_class_label_prompt: + raise NotImplementedError + else: + # Add support for onehot vector embedding for pre-defined classes + label_embedding = self.label_embeddings(labels) + return label_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + labels: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + elif labels is not None: + return labels.shape[0] + else: + return 1 + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + class_labels: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + class_labels (torch.Tensor or none): labels to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. 
+ torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks, class_labels) + + # Add support for onehot vector embedding for pre-defined classes + if class_labels is not None: + label_embeddings = self._embed_labels(class_labels) + else: + label_embeddings = self.no_label_embed.weight.reshape(1, 1, -1).expand(bs, -1, -1) + + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + + # Add support for onehot vector embedding for pre-defined classes + sparse_embeddings = torch.cat([sparse_embeddings, label_embeddings], dim=1) + + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings
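
Usage notes (outside the diff):

The builders registered in sam_model_registry now expose the two options introduced by this patch, clip_class_label_prompt and patch_embed_3d. A minimal construction sketch, assuming the monaivista lib root is on the Python path; the 27-channel input (9 slices replicated to RGB) is only an illustrative choice, not a value fixed by the patch:

    import torch
    from lib.model.vista_point_2pt5.model_2pt5 import sam_model_registry

    # Build the ViT-B variant with the 2.5-D (multi-slice) patch embedding.
    # checkpoint is left unset here; a local SAM ViT-B checkpoint can be passed to reuse weights.
    model = sam_model_registry["vit_b"](
        checkpoint=None,
        image_size=1024,
        encoder_in_chans=27,             # illustrative; should be a multiple of 3 when patch_embed_3d=True
        clip_class_label_prompt=False,   # the CLIP text-feature path currently raises NotImplementedError
        patch_embed_3d=True,             # use PatchEmbed2pt5D instead of the 2-D PatchEmbed
    )
    if torch.cuda.is_available():
        model = model.cuda()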
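
PatchEmbed2pt5D regroups a stacked 2.5-D input of shape (B, 3*D, H, W) into (B, 3, H, W, D), applies a Conv3d whose kernel spans the full depth, squeezes the dummy depth dimension, and returns (B, H/16, W/16, embed_dim) tokens for the ViT blocks. A quick shape check as a sketch (D = 9 is again illustrative):

    import torch
    from lib.model.vista_point_2pt5.vista_2pt5_image_encoder import PatchEmbed2pt5D

    embed = PatchEmbed2pt5D(kernel_size=(16, 16, 9), stride=(16, 16, 9), in_chans=3, embed_dim=768)
    x = torch.rand(1, 27, 1024, 1024)   # (B, 3*D, H, W) with D = 9 slices
    tokens = embed(x)                   # chunk -> (1, 3, 1024, 1024, 9), Conv3d, squeeze, permute
    print(tokens.shape)                 # torch.Size([1, 64, 64, 768])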
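
VistaPromptEncoder prepends one class-label token per prompt to the sparse embeddings; when no class label is supplied, the learned no_label_embed token is used instead. A small sketch with two hypothetical class ids (n_classes defaults to 512):

    import torch
    from lib.model.vista_point_2pt5.vista_2pt5_prompt_encoder import VistaPromptEncoder

    enc = VistaPromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    class_labels = torch.tensor([[3], [7]])   # two prompts; the ids are placeholders
    sparse, dense = enc(points=None, boxes=None, masks=None, class_labels=class_labels)
    print(sparse.shape, dense.shape)          # torch.Size([2, 1, 256]) torch.Size([2, 256, 64, 64])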