
Commit

Replace 2pt5 model files on monailabel side (#12)
Replace the old 2pt5 model files with the latest trained model networks:
model.py
vista_image_encoder
vista_prompt_encoder

---------

Signed-off-by: tangy5 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
tangy5 and pre-commit-ci[bot] authored Jul 25, 2023
1 parent 9fc976b commit 1e7f3cc
Showing 6 changed files with 361 additions and 1,052 deletions.
2 changes: 1 addition & 1 deletion monailabel/monaivista/lib/configs/vista_point_2pt5.py
@@ -15,7 +15,7 @@

import lib.infers
import lib.trainers
from lib.model.vista_point_2pt5.models_samm2pt5d import sam_model_registry
from lib.model.vista_point_2pt5.model_2pt5 import sam_model_registry
from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.tasks.infer_v2 import InferTask
from monailabel.interfaces.tasks.train import TrainTask
10 changes: 10 additions & 0 deletions monailabel/monaivista/lib/model/vista_point_2pt5/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
monailabel/monaivista/lib/model/vista_point_2pt5/model_2pt5.py
@@ -1,3 +1,14 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

@@ -7,24 +18,25 @@
from functools import partial
from typing import Any, Dict, List, Tuple

import monai
import torch
from segment_anything.modeling import TwoWayTransformer
from segment_anything.modeling.mask_decoder import MaskDecoder
from torch import nn
from torch.nn import functional as F

from .segment_anything.modeling import TwoWayTransformer
from .segment_anything.modeling.image_encoder import ImageEncoderViT
from .segment_anything.modeling.mask_decoder import MaskDecoder
from .segment_anything.modeling.prompt_encoder import PromptEncoder
from .vista_2pt5_image_encoder import VistaImageEncoderViT
from .vista_2pt5_prompt_encoder import VistaPromptEncoder


class Samm2pt5D(nn.Module):
class Vista2pt5D(nn.Module):
mask_threshold: float = 0.5
image_format: str = "RGB"

def __init__(
self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
image_encoder: VistaImageEncoderViT,
prompt_encoder: VistaPromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = [123.675, 116.28, 103.53],
pixel_std: List[float] = [58.395, 57.12, 57.375],
@@ -67,7 +79,6 @@ def get_mask_prediction(
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
# raise NotImplementedError
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -86,7 +97,6 @@ def get_mask_prediction(

high_res_masks = self.postprocess_masks(
low_res_masks,
# input_size=image_record["image"].shape[-2:],
original_size=image_record["original_size"],
)
masks = high_res_masks > self.mask_threshold
@@ -124,6 +134,8 @@ def forward(
input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts,
with shape BxN.
'labels': (torch.Tensor) Batched labels for class-label prompt,
with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
@@ -151,7 +163,6 @@ def forward(
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
# raise NotImplementedError
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -194,7 +205,6 @@ def forward(
def postprocess_masks(
self,
masks: torch.Tensor,
# input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
) -> torch.Tensor:
"""
@@ -261,21 +271,23 @@ def preprocess(self, x: torch.Tensor, is_input=True) -> torch.Tensor:
return x


def _build_sam2pt5d(
def _build_vista2pt5d(
encoder_in_chans,
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
image_size=1024,
clip_class_label_prompt=False,
patch_embed_3d=False,
):
prompt_embed_dim = 256
image_size = image_size # TODO: Shall we try to adapt model to 512x512 ?
image_size = image_size
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Samm2pt5D(
image_encoder=ImageEncoderViT(
sam = Vista2pt5D(
image_encoder=VistaImageEncoderViT(
in_chans=encoder_in_chans,
depth=encoder_depth,
embed_dim=encoder_embed_dim,
@@ -289,12 +301,14 @@ def _build_sam2pt5d(
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
patch_embed_3d=patch_embed_3d,
),
prompt_encoder=PromptEncoder(
prompt_encoder=VistaPromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
clip_class_label_prompt=clip_class_label_prompt,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3, # TODO: only predict one binary mask
@@ -315,22 +329,26 @@ def _build_sam2pt5d(
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)

if image_size == 1024:
# we try to use all pretrained weights
new_dict = state_dict
if encoder_in_chans != 3:
new_dict.pop("image_encoder.patch_embed.proj.weight")
else:
new_dict = {}
for k, v in state_dict.items():
# skip weights in prompt_encoder and mask_decoder
if k.startswith("prompt_encoder") or k.startswith("mask_decoder"):
continue
# skip weights in position embedding and learned relative positional embeddings
elif "pos_embed" in k or "attn.rel_pos" in k:
# due to the change of input size
if ("pos_embed" in k and k.startswith("image_encoder")) or (
"attn.rel_pos" in k and k.startswith("image_encoder")
):
continue
else:
new_dict[k] = v

if encoder_in_chans != 3:
new_dict.pop("image_encoder.patch_embed.proj.weight")
new_dict.pop("image_encoder.patch_embed.proj.bias")

sam.load_state_dict(new_dict, strict=False)
print(f"Load {len(new_dict)} keys from checkpoint {checkpoint}, current model has {len(sam.state_dict())} keys")

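The hunk above changes how pretrained SAM weights are reused: at the native 1024 input size every key is kept, while at other input sizes the prompt-encoder and mask-decoder weights and the image encoder's position/relative-position embeddings are dropped, and the RGB patch-embedding projection is removed whenever the input has a channel count other than 3. The following is a minimal standalone sketch of that filtering idea; it is not part of the commit, and the function name, checkpoint path, and parameter values are illustrative.

import torch


def filter_sam_state_dict(state_dict, image_size=1024, encoder_in_chans=3):
    """Keep only the pretrained SAM weights that still fit a resized / re-channeled model (sketch)."""
    if image_size == 1024:
        # Same resolution as SAM: try to reuse every pretrained weight.
        new_dict = dict(state_dict)
    else:
        new_dict = {}
        for k, v in state_dict.items():
            # Prompt encoder and mask decoder are retrained, so skip their weights.
            if k.startswith("prompt_encoder") or k.startswith("mask_decoder"):
                continue
            # Position embeddings and relative-position tables depend on the input size.
            if k.startswith("image_encoder") and ("pos_embed" in k or "attn.rel_pos" in k):
                continue
            new_dict[k] = v

    if encoder_in_chans != 3:
        # The patch-embedding projection was trained for 3-channel RGB input.
        new_dict.pop("image_encoder.patch_embed.proj.weight", None)
        new_dict.pop("image_encoder.patch_embed.proj.bias", None)
    return new_dict


if __name__ == "__main__":
    ckpt = torch.load("sam_vit_b_01ec64.pth", map_location="cpu")  # illustrative path
    filtered = filter_sam_state_dict(ckpt, image_size=512, encoder_in_chans=9)
    print(f"kept {len(filtered)} of {len(ckpt)} keys")

Loading the filtered dictionary with strict=False, as the commit does, then tolerates the intentionally missing keys.
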
@@ -355,106 +373,67 @@ def _build_sam2pt5d(
f"{sum(mask_decoder_params) * 1.e-6:.2f} M params in mask decoder."
)

# comment to unfreeze all encoder layers
# for name, param in sam.named_parameters():
# if name.startswith("image_encoder"):
# if image_size == 1024:
# if "pos_embed" in name or "patch_embed" in name or "blocks.0" in name:
# # we only retrain layers before blocks.1 in image_encoder
# continue
# # if "pos_embed" in name or "patch_embed" in name:
# # # we only retrain pos_embed and patch_embed
# # continue
# else:
# if "pos_embed" in name or "attn.rel_pos" in name or \
# "patch_embed" in name or "blocks.0" in name or "neck" in name:
# # we only train pos_embed, patch_embed, blocks.0, attn.rel_pos (due res change)
# # and neck (a few conv layers for outputs) in image_encoder
# continue
#
# # we freeze all other layers in image_encoder
# param.requires_grad = False

total_trainable_params = sum(p.numel() if p.requires_grad else 0 for p in sam.parameters())
print(f"{sam.__class__.__name__} has {total_trainable_params * 1.e-6:.2f} M trainable params.")
return sam


def build_samm2pt5d_vit_h(checkpoint=None, image_size=1024, encoder_in_chans=3):
return _build_sam2pt5d(
def build_vista2pt5d_vit_h(
checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
):
return _build_vista2pt5d(
encoder_in_chans=encoder_in_chans,
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
image_size=image_size,
clip_class_label_prompt=clip_class_label_prompt,
patch_embed_3d=patch_embed_3d,
)


def build_samm2pt5d_vit_l(checkpoint=None, image_size=1024, encoder_in_chans=3):
return _build_sam2pt5d(
def build_vista2pt5d_vit_l(
checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
):
return _build_vista2pt5d(
encoder_in_chans=encoder_in_chans,
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
image_size=image_size,
clip_class_label_prompt=clip_class_label_prompt,
patch_embed_3d=patch_embed_3d,
)


def build_samm2pt5d_vit_b(checkpoint=None, image_size=1024, encoder_in_chans=3):
return _build_sam2pt5d(
def build_vista2pt5d_vit_b(
checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
):
return _build_vista2pt5d(
encoder_in_chans=encoder_in_chans,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
image_size=image_size,
clip_class_label_prompt=clip_class_label_prompt,
patch_embed_3d=patch_embed_3d,
)


sam_model_registry = {
"default": build_samm2pt5d_vit_h,
"vit_h": build_samm2pt5d_vit_h,
"vit_l": build_samm2pt5d_vit_l,
"vit_b": build_samm2pt5d_vit_b,
"default": build_vista2pt5d_vit_h,
"vit_h": build_vista2pt5d_vit_h,
"vit_l": build_vista2pt5d_vit_l,
"vit_b": build_vista2pt5d_vit_b,
}
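
The registry maps each ViT variant name to its renamed builder, so callers pick a backbone size by key. Below is a minimal usage sketch under the builder signature shown above; the checkpoint filename and keyword values are illustrative examples, not defaults required by the commit.

from lib.model.vista_point_2pt5.model_2pt5 import sam_model_registry

# Build a ViT-B Vista2pt5D model; all keyword values below are example settings.
model = sam_model_registry["vit_b"](
    checkpoint="sam_vit_b_01ec64.pth",  # pretrained SAM weights to adapt (illustrative path)
    image_size=512,                     # input size other than SAM's native 1024
    encoder_in_chans=9,                 # e.g. a multi-slice 2.5D input stack (illustrative)
    clip_class_label_prompt=False,      # new flag forwarded to VistaPromptEncoder
    patch_embed_3d=False,               # new flag forwarded to VistaImageEncoderViT
)
model.eval()

As in the original SAM registry layout, the "default" key resolves to the ViT-H builder.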


if __name__ == "__main__":
model = build_samm2pt5d_vit_b()
model = build_vista2pt5d_vit_b()
model.cuda()
#
# dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (176, 345),
# "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda(),
# "labels": torch.ones(3, 1).long().cuda()},
# {"image": torch.rand(3, 128, 365).cuda(), "original_size": (128, 365),
# "point_coords": torch.rand(1, 3, 2).cuda(), "point_labels": torch.ones(1, 3).cuda(),
# "labels": torch.ones(1, 1).long().cuda()}
# ]
# # dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (256, 512),
# # "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda()}]
# outputs = model(dummy_input)

# test if postprocessing can inverse preprocess
# path = "/home/pengfeig/Downloads/fffabebf-74fd3a1f-673b6b41-96ec0ac9-2ab69818.jpg"
# from PIL import Image
# import numpy as np
# import matplotlib.pyplot as plt
#
# image = np.array(Image.open(path)).transpose(2, 0, 1)[:, :365, :256].astype(np.float32)
# plt.imshow(image.transpose(1, 2, 0).astype(np.uint8))
# plt.show()
# dummy_tensor = torch.from_numpy(image).cuda()
# tmp = model.preprocess(dummy_tensor)
# plt.imshow(tmp.cpu().numpy().transpose(1, 2, 0).astype(np.uint8))
# plt.show()
# inverse_tensor = model.postprocess_masks(tmp.unsqueeze(0), (365, 256)).squeeze(0)
# print(torch.sum(torch.abs(inverse_tensor-dummy_tensor)))
# print("dummy_tensor", torch.min(dummy_tensor), torch.max(dummy_tensor))
# print("inverse_tensor", torch.min(inverse_tensor), torch.max(inverse_tensor))
# plt.imshow(inverse_tensor.cpu().numpy().transpose(1, 2, 0).astype(np.uint8))
# plt.show()
# print()
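
The commented-out block above doubles as a reference for the batched input that Vista2pt5D.forward expects: per the updated docstring, each record carries an image, its original_size, point prompts, and the new class-label "labels" tensor. The sketch below reconstructs one such call from that commented example; it is not part of the commit, the shapes and CUDA usage mirror the comments, and any extra forward arguments not visible in this diff would still need to be supplied.

import torch

from lib.model.vista_point_2pt5.model_2pt5 import build_vista2pt5d_vit_b

model = build_vista2pt5d_vit_b().cuda().eval()  # no checkpoint: randomly initialised, as in the __main__ above

batched_input = [
    {
        "image": torch.rand(3, 176, 345).cuda(),      # image already transformed to the model's input frame
        "original_size": (176, 345),                   # size the predicted masks are restored to
        "point_coords": torch.rand(3, 5, 2).cuda(),    # BxNx2 point prompts
        "point_labels": torch.ones(3, 5).cuda(),       # BxN point labels
        "labels": torch.ones(3, 1).long().cuda(),      # class-label prompt added in this commit
    }
]

with torch.no_grad():
    outputs = model(batched_input)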
