diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 93e3e8416ad..5d27cae2739 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -220,6 +220,8 @@ Backbones
    ViTSAM
    XCiT
    ViTEVA02
+   PromptedViT
+   PromptedViTEVA02
 
 .. module:: mmpretrain.models.necks
 
diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py
index 60e37fb7b6e..de102e32124 100644
--- a/mmpretrain/models/backbones/__init__.py
+++ b/mmpretrain/models/backbones/__init__.py
@@ -57,73 +57,21 @@ from .vision_transformer import VisionTransformer
 from .vit_eva02 import ViTEVA02
 from .vit_sam import ViTSAM
+from .vpt import PromptedViT, PromptedViTEVA02
 from .xcit import XCiT
 
 __all__ = [
-    'LeNet5',
-    'AlexNet',
-    'VGG',
-    'RegNet',
-    'ResNet',
-    'ResNeXt',
-    'ResNetV1d',
-    'ResNeSt',
-    'ResNet_CIFAR',
-    'SEResNet',
-    'SEResNeXt',
-    'ShuffleNetV1',
-    'ShuffleNetV2',
-    'MobileNetV2',
-    'MobileNetV3',
-    'VisionTransformer',
-    'SwinTransformer',
-    'TNT',
-    'TIMMBackbone',
-    'T2T_ViT',
-    'Res2Net',
-    'RepVGG',
-    'Conformer',
-    'MlpMixer',
-    'DistilledVisionTransformer',
-    'PCPVT',
-    'SVT',
-    'EfficientNet',
-    'EfficientNetV2',
-    'ConvNeXt',
-    'HRNet',
-    'ResNetV1c',
-    'ConvMixer',
-    'EdgeNeXt',
-    'CSPDarkNet',
-    'CSPResNet',
-    'CSPResNeXt',
-    'CSPNet',
-    'RepLKNet',
-    'RepMLPNet',
-    'PoolFormer',
-    'RIFormer',
-    'DenseNet',
-    'VAN',
-    'InceptionV3',
-    'MobileOne',
-    'EfficientFormer',
-    'SwinTransformerV2',
-    'MViT',
-    'DeiT3',
-    'HorNet',
-    'MobileViT',
-    'DaViT',
-    'BEiTViT',
-    'RevVisionTransformer',
-    'MixMIMTransformer',
-    'TinyViT',
-    'LeViT',
-    'Vig',
-    'PyramidVig',
-    'XCiT',
-    'ViTSAM',
-    'ViTEVA02',
-    'HiViT',
-    'SparseResNet',
-    'SparseConvNeXt',
+    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
+    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
+    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
+    'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
+    'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
+    'EfficientNet', 'EfficientNetV2', 'ConvNeXt', 'HRNet', 'ResNetV1c',
+    'ConvMixer', 'EdgeNeXt', 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet',
+    'RepLKNet', 'RepMLPNet', 'PoolFormer', 'RIFormer', 'DenseNet', 'VAN',
+    'InceptionV3', 'MobileOne', 'EfficientFormer', 'SwinTransformerV2', 'MViT',
+    'DeiT3', 'HorNet', 'MobileViT', 'DaViT', 'BEiTViT', 'RevVisionTransformer',
+    'MixMIMTransformer', 'TinyViT', 'LeViT', 'Vig', 'PyramidVig', 'XCiT',
+    'ViTSAM', 'ViTEVA02', 'HiViT', 'SparseResNet', 'SparseConvNeXt',
+    'PromptedViT', 'PromptedViTEVA02'
 ]
diff --git a/mmpretrain/models/backbones/vpt.py b/mmpretrain/models/backbones/vpt.py
new file mode 100644
index 00000000000..abd794668b2
--- /dev/null
+++ b/mmpretrain/models/backbones/vpt.py
@@ -0,0 +1,335 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+
+from mmpretrain.models.backbones import VisionTransformer, ViTEVA02
+from mmpretrain.models.utils import build_norm_layer, resize_pos_embed
+from mmpretrain.registry import MODELS
+
+
+def init_prompt(prompt_init, prompt):
+    """Initialize the prompt tokens in-place according to ``prompt_init``."""
+    if prompt_init == 'uniform':
+        nn.init.uniform_(prompt, -0.08, 0.08)
+    elif prompt_init == 'zero':
+        nn.init.zeros_(prompt)
+    elif prompt_init == 'kaiming':
+        nn.init.kaiming_normal_(prompt)
+    elif prompt_init == 'token':
+        # same as 'zero' here; callers additionally mark the prompt as
+        # uninitialized (see ``prompt_initialized``)
+        nn.init.zeros_(prompt)
+    else:
+        nn.init.normal_(prompt, std=0.02)
+
+
+@MODELS.register_module()
+class PromptedViT(VisionTransformer):
+    """Vision Transformer with Visual Prompt Tuning.
+
+    A PyTorch implementation of `Visual Prompt Tuning
+    <https://arxiv.org/abs/2203.12119>`_.
+
+    Args:
+        prompt_length (int): The number of prompt tokens. Defaults to 1.
+        deep_prompt (bool): Whether to use deep prompts (a separate prompt
+            for every layer). Defaults to True.
+        prompt_init (str): The initialization method of the prompt.
+            Defaults to 'normal'.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor, including patch tokens and
+              class tokens, with shape (B, L, C).
+            - ``"avg_all"``: The global average of the feature map, cls_token
+              and prompt tokens, with shape (B, C).
+            - ``"avg_prompt"``: The global average of the prompt tokens, with
+              shape (B, C).
+            - ``"avg_prompt_clstoken"``: The global average of cls_token and
+              the prompt tokens, with shape (B, C).
+
+            Defaults to ``"avg_all"``.
+        *args(list, optional): Other args for VisionTransformer.
+        **kwargs(dict, optional): Other args for VisionTransformer.
+ """ + + num_extra_tokens = 1 # class token + OUT_TYPES = { + 'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', + 'avg_prompt_clstoken' + } + + def __init__(self, + prompt_length: int = 1, + deep_prompt: bool = True, + out_type: str = 'avg_all', + prompt_init: str = 'normal', + norm_cfg: dict = dict(type='LN'), + *args, + **kwargs): + super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs) + + self.prompt_layers = len(self.layers) if deep_prompt else 1 + prompt = torch.empty(self.prompt_layers, prompt_length, + self.embed_dims) + init_prompt(prompt_init, prompt) + self.prompt_initialized = False if prompt_init == 'token' else True + self.prompt = nn.Parameter(prompt, requires_grad=True) + + self.prompt_length = prompt_length + self.deep_prompt = deep_prompt + self.num_extra_tokens = self.num_extra_tokens + prompt_length + + if self.out_type in { + 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken' + }: + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages + self.frozen_stages = len(self.layers) + self._freeze_stages() + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + # reshape to [layers, batch, tokens, embed_dims] + prompt = self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1) + x = torch.cat([x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]], dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if self.deep_prompt and i != len(self.layers) - 1: + x = torch.cat([ + x[:, :1, :], prompt[i, :, :, :], + x[:, self.prompt_length + 1:, :] + ], + dim=1) + + # final_norm should be False here + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(x[:, self.prompt_length + 1:].mean(dim=1)) + if self.out_type == 'avg_all': + return self.ln2(x.mean(dim=1)) + if self.out_type == 'avg_prompt': + return self.ln2(x[:, 1:self.prompt_length + 1].mean(dim=1)) + if self.out_type == 'avg_prompt_clstoken': + return self.ln2(x[:, :self.prompt_length + 1].mean(dim=1)) + + +def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution): + B, N, _ = x.shape + H, W = patch_resolution + extra_token_num = N - H * W + + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(dim=0) + + if self.rope: + if extra_token_num > 0: + q_t = q[:, :, extra_token_num:, :] + ro_q_t = self.rope(q_t, patch_resolution) + q = torch.cat((q[:, :, :extra_token_num, :], ro_q_t), + -2).type_as(v) + + k_t = k[:, :, extra_token_num:, :] + ro_k_t = self.rope(k_t, patch_resolution) + k = torch.cat((k[:, :, :extra_token_num, :], ro_k_t), + -2).type_as(v) + else: + q 
= self.rope(q, patch_resolution) + k = self.rope(k, patch_resolution) + + q = q * self.scale + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1).type_as(x) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +@MODELS.register_module() +class PromptedViTEVA02(ViTEVA02): + """EVA02 Vision Transformer with Prompt. + + A PyTorch implement of : `Visual Prompt Tuning + `_ + + Args: + prompt_length (int): the length of prompt parameters. Defaults to 1. + deep_prompt (bool): Whether to use deep prompt, Defaults to True. + prompt_init (str): The Initialisation method. Defaults to 'normal'. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + - ``"avg_all"``: The global averaged feature map & cls_tocken + & prompt tensor with shape (B, C). + - ``"avg_prompt"``: The global averaged prompt tensor with + shape (B, C). + - ``"avg_prompt_clstoken"``: The global averaged cls_tocken + & prompt tensor with shape (B, C). + + Defaults to ``"avg_all"``. + *args(list, optional): Other args for ViTEVA02. + **kwargs(dict, optional): Other args for ViTEVA02. + """ + + num_extra_tokens = 1 # class token + OUT_TYPES = { + 'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', + 'avg_prompt_clstoken' + } + + # 'avg_all' : avg of 'prompt' & 'cls_token' & 'featmap' + # 'avg_prompt' avg of 'prompt' + # 'avg_prompt_clstoken' avg of 'cls_token' and 'prompt' + def __init__(self, + prompt_length=1, + deep_prompt=True, + out_type='avg_all', + prompt_init: str = 'normal', + norm_cfg=dict(type='LN'), + *args, + **kwargs): + super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs) + + self.prompt_layers = len(self.layers) if deep_prompt else 1 + prompt = torch.empty(self.prompt_layers, prompt_length, + self.embed_dims) + if prompt_init == 'uniform': + nn.init.uniform_(prompt, -0.08, 0.08) + elif prompt_init == 'zero': + nn.init.zeros_(prompt) + elif prompt_init == 'kaiming': + nn.init.kaiming_normal_(prompt) + elif prompt_init == 'token': + nn.init.zeros_(prompt) + self.prompt_initialized = False + else: + nn.init.normal_(prompt, std=0.02) + self.prompt = nn.Parameter(prompt, requires_grad=True) + self.prompt_length = prompt_length + self.deep_prompt = deep_prompt + + if self.out_type in { + 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken' + }: + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages + self.frozen_stages = len(self.layers) + self._freeze_stages() + + @patch('mmpretrain.models.backbones.vit_eva02.AttentionWithRoPE.forward', + new_AttentionWithRoPE_forward_fn) + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + # reshape to [layers, batch, tokens, embed_dims] + prompt = 
self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1) + x = torch.cat([x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]], dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, patch_resolution) + + if self.deep_prompt and i != len(self.layers) - 1: + x = torch.cat([ + x[:, :1, :], prompt[i, :, :, :], + x[:, self.prompt_length + 1:, :] + ], + dim=1) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(x[:, self.prompt_length:].mean(dim=1)) + if self.out_type == 'avg_all': + return self.ln2(x.mean(dim=1)) + if self.out_type == 'avg_prompt': + return self.ln2(x[:, 1:self.prompt_length + 1].mean(dim=1)) + if self.out_type == 'avg_prompt_clstoken': + return self.ln2(x[:, :self.prompt_length + 1].mean(dim=1))
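
Usage note (not part of the diff): the snippet below is a minimal sketch of how the new backbone could be exercised once this patch is applied. The concrete values (arch='b', img_size=224, prompt_length=5) are illustrative only, and no pretrained weights are loaded here.

import torch

from mmpretrain.registry import MODELS

# Build PromptedViT through the registry, like any other mmpretrain backbone.
# __init__ freezes every transformer stage, so within the backbone only the
# prompt tokens (and the extra ln2 norm used by the avg-style out_types)
# remain trainable.
backbone = MODELS.build(
    dict(
        type='PromptedViT',
        arch='b',            # ViT-Base; illustrative choice
        img_size=224,
        patch_size=16,
        prompt_length=5,     # illustrative prompt length
        deep_prompt=True,
        out_type='avg_all'))

x = torch.randn(2, 3, 224, 224)
feats = backbone(x)          # tuple with one tensor per entry in out_indices
print(feats[0].shape)        # torch.Size([2, 768]) for 'avg_all' with ViT-Base

num_trainable = sum(p.numel() for p in backbone.parameters()
                    if p.requires_grad)
print(num_trainable)         # only the prompt and ln2 parameters

PromptedViTEVA02 is built the same way (type='PromptedViTEVA02'); its forward temporarily patches AttentionWithRoPE.forward so the rotary embedding skips the prompt tokens as well as the class token.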