add Monkey and revise OCR LLM
panshaowu committed Dec 3, 2024
1 parent 71065ce commit 0357d7b
Showing 16 changed files with 1,090 additions and 254 deletions.
57 changes: 57 additions & 0 deletions configs/llm/monkey/monkey.yaml
@@ -0,0 +1,57 @@
model:
  name: MonkeyQwenForCausalLM
  batch_size: 1
  seq_length: 2048
  hidden_size: 4096
  num_layers: 32
  num_heads: 32
  vocab_size: 151936
  intermediate_size: 11008
  rms_norm_eps: 1.0e-6
  emb_dropout_prob: 0.0
  eos_token_id: 151643
  pad_token_id: 151643
  compute_dtype: "float16"
  layernorm_compute_type: "float32"
  softmax_compute_type: "float16"
  rotary_dtype: "float16"
  param_init_type: "float16"
  ln_param_init_type: "float32"
  use_past: True
  use_flash_attention: False
  use_past_shard: False
  offset: 0
  checkpoint_name_or_path: "/path/to/monkey.ckpt"
  repetition_penalty: 1.5
  max_decode_length: 2048
  top_k: 0
  top_p: 0.8
  do_sample: False
  max_new_tokens: 250
  temperature: 0.7
  num_beams: 1
  length_penalty: 1
  num_patches: 1280

  # configuration items copied from Qwen
  rotary_pct: 1.0
  rotary_emb_base: 10000
  kv_channels: 128

  visual:
    heads: 16
    image_size: 896
    image_start_id: 151857
    layers: 48
    mlp_ratio: 4.9231
    output_dim: 4096
    patch_size: 14
    width: 1664
    lora_repeat_num: 4
    positional_embedding_size: 1024
    model_type: open_clip

processor:
  tokenizer:
    vocab_file: "/path/to/qwen.tiktoken"
    pad_token: "<|endoftext|>"
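A quick cross-check on num_patches: 1280. It matches the five 448x448 views that MonkeyImageProcessor (later in this diff) produces per image, four sliding-window crops plus one global resize, if each view contributes 256 visual tokens after the resampler. The per-view token count is an assumption, not something stated in this diff:

views = 4 + 1          # sliding-window crops plus one global 448x448 view (see MonkeyImageProcessor)
tokens_per_view = 256  # assumed resampler output length per view
assert views * tokens_per_view == 1280  # consistent with num_patches above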
3 changes: 1 addition & 2 deletions configs/llm/vary/vary_toy.yaml
@@ -30,15 +30,14 @@ model:
  max_new_tokens: 1024
  temperature: 1.0
  num_beams: 1
  num_patches: 256

  # configuration items copied from Qwen
  rotary_pct: 1.0
  rotary_emb_base: 10000
  kv_channels: 128

processor:
  return_tensors: ms
  tokenizer:
    vocab_file: "/path/to/qwen.tiktoken"
    pad_token: "<|endoftext|>"
  type: QwenProcessor
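A hedged check for the added num_patches: 256. Vary's CLIP branch (VaryCLIPImageProcessor, image_resolution=224) would yield 256 patch tokens for a ViT-L/14-style encoder, since (224 // 14) ** 2 == 256; the 14-pixel patch size is an assumption not shown in this diff.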
69 changes: 67 additions & 2 deletions mindocr/data/transforms/llm_transform.py
@@ -28,6 +28,11 @@ def f(im):
    return f


def load_image(image_file):
    image = Image.open(image_file).convert("RGB")
    return image


image_processor_high = alb_wrapper(
    alb.Compose(
        [
@@ -270,7 +275,7 @@ def __init__(self, image_resolution=224):
        self.batch_totensor = BatchToTensor()
        self.batch_normalizer = BatchNormalize()

    def preprocess(self, images):
    def __call__(self, images):
        if not self._bhwc_check(images):
            images = self.bchw2bhwc(images)
        images = self.batch_pilizer(images)
@@ -299,4 +304,64 @@ def _bhwc_check(image_batch):
        return False


image_processor = VaryCLIPImageProcessor().preprocess
class VarySAMImageProcessor:
    def __init__(self):
        self.image_processor_high = alb_wrapper(
            alb.Compose(
                [
                    alb.Resize(1024, 1024),
                    alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
                ]
            )
        )

    def __call__(self, images):
        images = self.image_processor_high(images)
        return images


class VaryImageProcessor:
    def __init__(self):
        self.sam_processor = VarySAMImageProcessor()
        self.clip_processor = VaryCLIPImageProcessor()

    def __call__(self, images):
        if isinstance(images, str):
            images = load_image(images)
        image_clip = self.clip_processor(images)
        image_sam = self.sam_processor(images)
        return image_clip, image_sam
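A minimal usage sketch for the combined processor (the file name is a placeholder); it returns one tensor per vision branch:

processor = VaryImageProcessor()
image_clip, image_sam = processor("doc_page.png")  # 224x224 CLIP input and 1024x1024 SAM input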


class MonkeyImageProcessor:
    def __init__(self):
        self.resize1 = vision.c_transforms.Resize((896, 896), Inter.PILCUBIC)
        self.resize2 = vision.Resize((448, 448), Inter.BICUBIC)
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = vision.Normalize(mean=mean, std=std, is_hwc=False)

    @staticmethod
    def sliding_window(images, window_size=(448, 448), stride=448):
        windows = []
        for i in range(2):
            for j in range(2):
                window = images[
                    :, :, i * stride : i * stride + window_size[0], j * stride : j * stride + window_size[1]
                ]
                windows.append(window)
        return windows

    def __call__(self, images):
        if isinstance(images, str):
            images = load_image(images)
        images = self.resize1(images)  # hwc -> hwc
        images = images / 255.0  # hwc -> hwc
        images_trans = images.transpose(2, 0, 1)  # hwc -> chw
        images_norm = self.normalize(images_trans)  # chw -> chw
        images_nchw = np.expand_dims(images_norm, 0)  # chw -> nchw
        windows = self.sliding_window(images_nchw, window_size=(448, 448), stride=448)  # nchw -> List[nchw]
        images_norm = images_norm.transpose(1, 2, 0)  # chw -> hwc
        images_448 = self.resize2(images_norm)  # hwc -> hwc
        images_448 = np.expand_dims(images_448.transpose(2, 0, 1), 0)  # hwc -> nchw
        return windows, images_448
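To make the window geometry concrete: resize1 yields an 896x896 image, the 2x2 sliding window with stride 448 cuts it into four non-overlapping 448x448 crops, and resize2 adds a fifth global 448x448 view. A rough sketch with random data standing in for a real image (expected shapes inferred from the code above):

import numpy as np

processor = MonkeyImageProcessor()
dummy = np.random.randint(0, 256, (1024, 768, 3)).astype(np.uint8)  # any hwc image
windows, global_view = processor(dummy)
# expected: len(windows) == 4, each (1, 3, 448, 448); global_view also (1, 3, 448, 448)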
1 change: 1 addition & 0 deletions mindocr/nlp/llm/__init__.py
@@ -1,3 +1,4 @@
from ._registry import register_llm
from .builder import build_llm_model
from .monkey_qwen_model import MonkeyQwenForCausalLM
from .vary_qwen_model import VaryQwenForCausalLM
59 changes: 54 additions & 5 deletions mindocr/nlp/llm/base_llm_model.py
@@ -7,6 +7,7 @@
from mindocr.nlp.generation import GeneratorMixin
from mindocr.nlp.llm.builder import build_llm_model
from mindocr.nlp.llm.configs import BaseConfig, LLMConfig
from mindocr.utils.conversation import Conversation


class BaseLLMModel(nn.Cell, GeneratorMixin):
@@ -24,6 +25,13 @@ class BaseLLMModel(nn.Cell, GeneratorMixin):
    def __init__(self, config: BaseConfig, **kwargs):
        super(BaseLLMModel, self).__init__(**kwargs)
        self.config = config
        self.conversation = Conversation()
        self.image_path = None
        self.IMAGE_START_TAG = "<img>"
        self.IMAGE_END_TAG = "</img>"
        self.IMAGE_PAD_TAG = "<imgpad>"
        self.num_patches = config.num_patches
        self.image_prefix = f"{self.IMAGE_START_TAG}{self.IMAGE_PAD_TAG * self.num_patches}{self.IMAGE_END_TAG}"

    def load_checkpoint(self, config):
        """
@@ -42,11 +50,7 @@ def load_checkpoint(self, config):
        if os.path.exists(checkpoint_name_or_path):
            param = load_checkpoint(checkpoint_name_or_path)
        else:
            raise ValueError(
                f"{checkpoint_name_or_path} is not a supported default model"
                f" or a valid path to checkpoint,"
                f" please select from {self._support_list}."
            )
            raise ValueError(f"{checkpoint_name_or_path} is not a valid path to checkpoint.")

        load_param_into_net(self, param)

@@ -81,3 +85,48 @@ def from_pretrained(cls, pretrained_model_name_or_dir: str, **kwargs):
        config_args = cls._get_config_args(pretrained_model_name_or_dir, **kwargs)
        model = build_llm_model(config_args.model)
        return model

    def reset(self):
        self.conversation.messages = list()

    def add_image_token_pad_in_query(self, query: str):
        query = self.image_prefix + query
        return query

    def chat(
        self,
        query: str,
        image_path: str = None,
    ) -> str:
        """
        If `image_path` is provided, the conversation will be reset.
        example:
            inputs:
                query: Provide the ocr results of this image.
                image_path: xx/xxx.png.
            outputs:
                response: the modalities of irradiation could be modified...
                history: [
                    ("user", "Provide the ocr results of this image."),
                    ("assistant", "the modalities of irradiation could be modified..."),
                ]
        """
        if image_path is not None:
            query = self.add_image_token_pad_in_query(query=query)
            self.image_path = image_path
            self.reset()

        self.conversation.add_message(role="user", message=query)
        prompt = self.conversation.get_prompt()

        inputs = self.tokenizer([prompt], max_length=self.seq_length)
        input_ids = inputs["input_ids"]
        outputs = self.generate(input_ids=input_ids, image_path=self.image_path)
        outputs = self.tokenizer.decode(outputs, skip_special_tokens=False)
        response = outputs[0][len(prompt) :]

        for special_token in self.tokenizer.special_tokens:
            response = response.replace(special_token, "")
        self.conversation.add_message(role="assistant", message=response)

        return response
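An end-to-end sketch of the new chat API (config and image paths are placeholders). When image_path is given, add_image_token_pad_in_query prepends the <img>...<imgpad>...</img> prefix built in __init__ and the conversation is reset; a follow-up call without image_path continues the same history:

from mindocr.nlp.llm import build_llm_model

model = build_llm_model("configs/llm/monkey/monkey.yaml")
response = model.chat(query="Provide the ocr results of this image.", image_path="xx/xxx.png")
follow_up = model.chat(query="Summarize the recognized text.")  # reuses the running conversation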
11 changes: 10 additions & 1 deletion mindocr/nlp/llm/builder.py
@@ -1,4 +1,5 @@
from ._registry import is_llm, is_llm_class, list_llms, llm_class_entrypoint, llm_entrypoint
from .configs import LLMConfig

__all__ = ["build_llm_model"]

@@ -11,8 +12,16 @@ def build_llm_model(config):
    >>> llm_model = build_llm_model(dict(name='VaryQwenForCausalLM'))
    >>> print(llm_model)
    """
    if isinstance(config, dict):
        config = LLMConfig(**config)
    elif isinstance(config, str):
        config = LLMConfig(config)
    else:
        raise TypeError(f"config must be str or dict, but got {type(config)}")
    if "model" in config:
        config = LLMConfig(**config["model"], **config["processor"])
    if "name" not in config:
        raise ValueError("name must in `config`.")
        raise ValueError("`name` must be in `config`.")
    name = config["name"]
    if is_llm(name):
        create_fn = llm_entrypoint(name)
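The accepted call forms, per the validation above (the YAML path is illustrative):

build_llm_model(dict(name="VaryQwenForCausalLM"))  # dict form: LLMConfig(**config)
build_llm_model("configs/llm/monkey/monkey.yaml")  # str form: LLMConfig(path); model/processor sections are merged
build_llm_model(42)                                # raises TypeError: config must be str or dict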
7 changes: 5 additions & 2 deletions mindocr/nlp/llm/configs.py
@@ -222,9 +222,7 @@ def __init__(
        self.num_heads = num_heads
        self.max_position_embedding = max_position_embedding if max_position_embedding else seq_length
        self.intermediate_size = intermediate_size
        self.multiple_of = multiple_of
        self.n_kv_heads = n_kv_heads
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.rms_norm_eps = rms_norm_eps
        self.qkv_concat = qkv_concat
        self.param_init_type = convert_mstype(param_init_type)
@@ -266,6 +264,11 @@ def __init__(self, **kwargs):
        super(VaryConfig, self).__init__(**kwargs)


class MonkeyConfig(QwenConfig):
    def __init__(self, **kwargs):
        super(MonkeyConfig, self).__init__(**kwargs)


class SAMConfig(BaseConfig):
    def __init__(
        self,
11 changes: 10 additions & 1 deletion mindocr/nlp/llm/convert_weight.py
@@ -51,6 +51,12 @@ def _name_replace(name: str):
name = name.replace("layer_norm2.bias", "ln_2.beta")
name = name.replace("self_attn", "attn")
name = name.replace("post_layernorm", "vision_model.post_layernorm")
if "visual" in name:
name = name.replace("attention_norm", "ln_1")
name = name.replace("ffn_norm", "ln_2")
if "ln_" in name:
name = name.replace("weight", "gamma").replace("bias", "beta")
name = name.replace("feed_forward.w2", "mlp.c_proj")

# sam
name = name.replace("norm1.weight", "norm1.gamma")
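For example, a vision-tower key like visual.blocks.0.attention_norm.weight would now map to visual.blocks.0.ln_1.gamma (the key name is illustrative and assumes no earlier rule rewrites those segments):

_name_replace("visual.blocks.0.attention_norm.weight")  # -> "visual.blocks.0.ln_1.gamma"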
@@ -82,6 +88,9 @@ def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16):
    state_dict = torch.load(torch_ckpt_path, map_location="cpu")
    ckpt_weights = []
    for k, v in state_dict.items():
        if "lora_scale" in k:
            continue

        value = pt2ms(v, dtype)

        msname = _name_replace(k)
@@ -111,4 +120,4 @@

args = parser.parse_args()

convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float16)
convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float32)
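Assuming the argparse flag names mirror the attributes used above (the parser definition is not shown in this hunk), a conversion run would look like:

python mindocr/nlp/llm/convert_weight.py --torch_ckpt_path /path/to/monkey.pth --mindspore_ckpt_path /path/to/monkey.ckpt

Note that the default conversion dtype changes here from float16 to float32.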