add Monkey model and revise OCR LLM code #782

Merged
merged 1 commit on Dec 3, 2024
57 changes: 57 additions & 0 deletions configs/llm/monkey/monkey.yaml
@@ -0,0 +1,57 @@
model:
  name: MonkeyQwenForCausalLM
  batch_size: 1
  seq_length: 2048
  hidden_size: 4096
  num_layers: 32
  num_heads: 32
  vocab_size: 151936
  intermediate_size: 11008
  rms_norm_eps: 1.0e-6
  emb_dropout_prob: 0.0
  eos_token_id: 151643
  pad_token_id: 151643
  compute_dtype: "float16"
  layernorm_compute_type: "float32"
  softmax_compute_type: "float16"
  rotary_dtype: "float16"
  param_init_type: "float16"
  ln_param_init_type: "float32"
  use_past: True
  use_flash_attention: False
  use_past_shard: False
  offset: 0
  checkpoint_name_or_path: "/path/to/monkey.ckpt"
  repetition_penalty: 1.5
  max_decode_length: 2048
  top_k: 0
  top_p: 0.8
  do_sample: False
  max_new_tokens: 250
  temperature: 0.7
  num_beams: 1
  length_penalty: 1
  num_patches: 1280

  # configuration items copied from Qwen
  rotary_pct: 1.0
  rotary_emb_base: 10000
  kv_channels: 128

  visual:
    heads: 16
    image_size: 896
    image_start_id: 151857
    layers: 48
    mlp_ratio: 4.9231
    output_dim: 4096
    patch_size: 14
    width: 1664
    lora_repeat_num: 4
    positional_embedding_size: 1024
    model_type: open_clip

processor:
  tokenizer:
    vocab_file: "/path/to/qwen.tiktoken"
    pad_token: "<|endoftext|>"
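Note: `num_patches: 1280` is what `BaseLLMModel` (revised below in this PR) expands into the image placeholder prepended to each query — which appears to correspond to five 448×448 views at 256 visual tokens each. A minimal sketch of that expansion, using the tags hard-coded in `base_llm_model.py`:

```python
# Sketch: how num_patches from monkey.yaml becomes the image placeholder.
num_patches = 1280  # model.num_patches above
image_prefix = "<img>" + "<imgpad>" * num_patches + "</img>"

query = "Provide the ocr results of this image."
prompt = image_prefix + query  # what add_image_token_pad_in_query produces
```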
3 changes: 1 addition & 2 deletions configs/llm/vary/vary_toy.yaml
@@ -30,15 +30,14 @@ model:
  max_new_tokens: 1024
  temperature: 1.0
  num_beams: 1
  num_patches: 256

  # configuration items copied from Qwen
  rotary_pct: 1.0
  rotary_emb_base: 10000
  kv_channels: 128

processor:
  return_tensors: ms
  tokenizer:
    vocab_file: "/path/to/qwen.tiktoken"
    pad_token: "<|endoftext|>"
  type: QwenProcessor
69 changes: 67 additions & 2 deletions mindocr/data/transforms/llm_transform.py
@@ -28,6 +28,11 @@ def f(im):
    return f


def load_image(image_file):
    image = Image.open(image_file).convert("RGB")
    return image


image_processor_high = alb_wrapper(
    alb.Compose(
        [
@@ -270,7 +275,7 @@ def __init__(self, image_resolution=224):
        self.batch_totensor = BatchToTensor()
        self.batch_normalizer = BatchNormalize()

    def preprocess(self, images):
    def __call__(self, images):
        if not self._bhwc_check(images):
            images = self.bchw2bhwc(images)
        images = self.batch_pilizer(images)
@@ -299,4 +304,64 @@ def _bhwc_check(image_batch):
        return False


image_processor = VaryCLIPImageProcessor().preprocess
class VarySAMImageProcessor:
    def __init__(self):
        self.image_processor_high = alb_wrapper(
            alb.Compose(
                [
                    alb.Resize(1024, 1024),
                    alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
                ]
            )
        )

    def __call__(self, images):
        images = self.image_processor_high(images)
        return images


class VaryImageProcessor:
    def __init__(self):
        self.sam_processor = VarySAMImageProcessor()
        self.clip_processor = VaryCLIPImageProcessor()

    def __call__(self, images):
        if isinstance(images, str):
            images = load_image(images)
        image_clip = self.clip_processor(images)
        image_sam = self.sam_processor(images)
        return image_clip, image_sam


class MonkeyImageProcessor:
    def __init__(self):
        self.resize1 = vision.c_transforms.Resize((896, 896), Inter.PILCUBIC)
        self.resize2 = vision.Resize((448, 448), Inter.BICUBIC)
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = vision.Normalize(mean=mean, std=std, is_hwc=False)

    @staticmethod
    def sliding_window(images, window_size=(448, 448), stride=448):
        windows = []
        for i in range(2):
            for j in range(2):
                window = images[
                    :, :, i * stride : i * stride + window_size[0], j * stride : j * stride + window_size[1]
                ]
                windows.append(window)
        return windows

    def __call__(self, images):
        if isinstance(images, str):
            images = load_image(images)
        images = self.resize1(images)  # hwc -> hwc
        images = images / 255.0  # hwc -> hwc
        images_trans = images.transpose(2, 0, 1)  # hwc -> chw
        images_norm = self.normalize(images_trans)  # chw -> chw
        images_nchw = np.expand_dims(images_norm, 0)  # chw -> nchw
        windows = self.sliding_window(images_nchw, window_size=(448, 448), stride=448)  # nchw -> List[nchw]
        images_norm = images_norm.transpose(1, 2, 0)  # chw -> hwc
        images_448 = self.resize2(images_norm)  # hwc -> hwc
        images_448 = np.expand_dims(images_448.transpose(2, 0, 1), 0)  # hwc -> nchw
        return windows, images_448
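For intuition, a minimal numpy-only sketch of the geometry `MonkeyImageProcessor.__call__` produces (the zero array is a stand-in; the real pipeline also normalizes and resizes bicubically). The 2×2 sliding window with stride 448 splits the 896×896 input into four non-overlapping tiles, and `__call__` additionally returns a 448×448 global view, so the model sees four local tiles plus one global view:

```python
import numpy as np

# Stand-in for a resized, normalized 896x896 image in NCHW layout.
images_nchw = np.zeros((1, 3, 896, 896), dtype=np.float32)

# A 2x2 sliding window with stride 448 yields four non-overlapping tiles.
windows = [
    images_nchw[:, :, i * 448 : i * 448 + 448, j * 448 : j * 448 + 448]
    for i in range(2)
    for j in range(2)
]

assert len(windows) == 4
assert all(w.shape == (1, 3, 448, 448) for w in windows)
```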
1 change: 1 addition & 0 deletions mindocr/nlp/llm/__init__.py
@@ -1,3 +1,4 @@
from ._registry import register_llm
from .builder import build_llm_model
from .monkey_qwen_model import MonkeyQwenForCausalLM
from .vary_qwen_model import VaryQwenForCausalLM
59 changes: 54 additions & 5 deletions mindocr/nlp/llm/base_llm_model.py
@@ -7,6 +7,7 @@
from mindocr.nlp.generation import GeneratorMixin
from mindocr.nlp.llm.builder import build_llm_model
from mindocr.nlp.llm.configs import BaseConfig, LLMConfig
from mindocr.utils.conversation import Conversation


class BaseLLMModel(nn.Cell, GeneratorMixin):
@@ -24,6 +25,13 @@ class BaseLLMModel(nn.Cell, GeneratorMixin):
    def __init__(self, config: BaseConfig, **kwargs):
        super(BaseLLMModel, self).__init__(**kwargs)
        self.config = config
        self.conversation = Conversation()
        self.image_path = None
        self.IMAGE_START_TAG = "<img>"
        self.IMAGE_END_TAG = "</img>"
        self.IMAGE_PAD_TAG = "<imgpad>"
        self.num_patches = config.num_patches
        self.image_prefix = f"{self.IMAGE_START_TAG}{self.IMAGE_PAD_TAG * self.num_patches}{self.IMAGE_END_TAG}"

    def load_checkpoint(self, config):
        """
@@ -42,11 +50,7 @@ def load_checkpoint(self, config):
        if os.path.exists(checkpoint_name_or_path):
            param = load_checkpoint(checkpoint_name_or_path)
        else:
            raise ValueError(
                f"{checkpoint_name_or_path} is not a supported default model"
                f" or a valid path to checkpoint,"
                f" please select from {self._support_list}."
            )
            raise ValueError(f"{checkpoint_name_or_path} is not a valid path to checkpoint.")

        load_param_into_net(self, param)

@@ -81,3 +85,48 @@ def from_pretrained(cls, pretrained_model_name_or_dir: str, **kwargs):
        config_args = cls._get_config_args(pretrained_model_name_or_dir, **kwargs)
        model = build_llm_model(config_args.model)
        return model

    def reset(self):
        self.conversation.messages = list()

    def add_image_token_pad_in_query(self, query: str):
        query = self.image_prefix + query
        return query

    def chat(
        self,
        query: str,
        image_path: str = None,
    ) -> str:
        """
        If `image_path` is provided, the conversation will be reset.
        example:
            inputs:
                query: Provide the ocr results of this image.
                image_path: xx/xxx.png
            outputs:
                response: the modalities of irradiation could be modified...
                history: [
                    ("user", "Provide the ocr results of this image."),
                    ("assistant", "the modalities of irradiation could be modified..."),
                ]
        """
        if image_path is not None:
            query = self.add_image_token_pad_in_query(query=query)
            self.image_path = image_path
            self.reset()

        self.conversation.add_message(role="user", message=query)
        prompt = self.conversation.get_prompt()

        inputs = self.tokenizer([prompt], max_length=self.seq_length)
        input_ids = inputs["input_ids"]
        outputs = self.generate(input_ids=input_ids, image_path=self.image_path)
        outputs = self.tokenizer.decode(outputs, skip_special_tokens=False)
        response = outputs[0][len(prompt) :]

        for special_token in self.tokenizer.special_tokens:
            response = response.replace(special_token, "")
        self.conversation.add_message(role="assistant", message=response)

        return response
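A hedged usage sketch of the new `chat` flow (the config, checkpoint, and image paths are placeholders, not files shipped with this PR): a call carrying `image_path` resets the history and prepends the image tokens; a follow-up call without it continues the same conversation.

```python
from mindocr.nlp.llm import build_llm_model

# The yaml must point at a real checkpoint and tokenizer vocab first.
model = build_llm_model("configs/llm/monkey/monkey.yaml")

# First turn: image_path resets the history and prepends the image tokens.
response = model.chat(
    query="Provide the ocr results of this image.",
    image_path="/path/to/document.png",
)

# Follow-up turn: no image_path, so the conversation history is kept.
summary = model.chat(query="Summarize the recognized text.")
```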
11 changes: 10 additions & 1 deletion mindocr/nlp/llm/builder.py
@@ -1,4 +1,5 @@
from ._registry import is_llm, is_llm_class, list_llms, llm_class_entrypoint, llm_entrypoint
from .configs import LLMConfig

__all__ = ["build_llm_model"]

@@ -11,8 +12,16 @@ def build_llm_model(config):
    >>> llm_model = build_llm_model(dict(name='VaryQwenForCausalLM'))
    >>> print(llm_model)
    """
    if isinstance(config, dict):
        config = LLMConfig(**config)
    elif isinstance(config, str):
        config = LLMConfig(config)
    else:
        raise TypeError(f"config must be str or dict, but got {type(config)}")
    if "model" in config:
        config = LLMConfig(**config["model"], **config["processor"])
    if "name" not in config:
        raise ValueError("name must in `config`.")
        raise ValueError("`name` must be in `config`.")
    name = config["name"]
    if is_llm(name):
        create_fn = llm_entrypoint(name)
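With this change `build_llm_model` accepts either entry form; a sketch (the yaml path is illustrative):

```python
from mindocr.nlp.llm import build_llm_model

# dict form, as in the docstring above: must carry at least `name`.
model = build_llm_model({"name": "VaryQwenForCausalLM"})

# str form: the yaml's nested model/processor sections are flattened into
# a single LLMConfig before the entrypoint is resolved.
model = build_llm_model("configs/llm/monkey/monkey.yaml")
```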
7 changes: 5 additions & 2 deletions mindocr/nlp/llm/configs.py
@@ -222,9 +222,7 @@ def __init__(
        self.num_heads = num_heads
        self.max_position_embedding = max_position_embedding if max_position_embedding else seq_length
        self.intermediate_size = intermediate_size
        self.multiple_of = multiple_of
        self.n_kv_heads = n_kv_heads
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.rms_norm_eps = rms_norm_eps
        self.qkv_concat = qkv_concat
        self.param_init_type = convert_mstype(param_init_type)
@@ -266,6 +264,11 @@ def __init__(self, **kwargs):
        super(VaryConfig, self).__init__(**kwargs)


class MonkeyConfig(QwenConfig):
    def __init__(self, **kwargs):
        super(MonkeyConfig, self).__init__(**kwargs)


class SAMConfig(BaseConfig):
    def __init__(
        self,
11 changes: 10 additions & 1 deletion mindocr/nlp/llm/convert_weight.py
@@ -51,6 +51,12 @@ def _name_replace(name: str):
    name = name.replace("layer_norm2.bias", "ln_2.beta")
    name = name.replace("self_attn", "attn")
    name = name.replace("post_layernorm", "vision_model.post_layernorm")
    if "visual" in name:
        name = name.replace("attention_norm", "ln_1")
        name = name.replace("ffn_norm", "ln_2")
        if "ln_" in name:
            name = name.replace("weight", "gamma").replace("bias", "beta")
        name = name.replace("feed_forward.w2", "mlp.c_proj")

    # sam
    name = name.replace("norm1.weight", "norm1.gamma")
@@ -82,6 +88,9 @@ def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16):
    state_dict = torch.load(torch_ckpt_path, map_location="cpu")
    ckpt_weights = []
    for k, v in state_dict.items():
        if "lora_scale" in k:
            continue

        value = pt2ms(v, dtype)

        msname = _name_replace(k)
@@ -111,4 +120,4 @@ def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16):

    args = parser.parse_args()

    convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float16)
    convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float32)
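For reference, a sketch of driving the converter programmatically rather than via the CLI (paths are placeholders, and this assumes the script's argparse section sits under a `__main__` guard). Note the `__main__` path now exports `float32` instead of `float16`, and `lora_scale` entries are skipped during conversion:

```python
import mindspore as ms

from mindocr.nlp.llm.convert_weight import convert_pt_to_ms

# Convert a PyTorch Monkey checkpoint to a MindSpore .ckpt (paths illustrative).
convert_pt_to_ms("/path/to/monkey_torch.pth", "/path/to/monkey.ckpt", ms.float32)
```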