diff --git a/configs/llm/monkey/monkey.yaml b/configs/llm/monkey/monkey.yaml new file mode 100644 index 000000000..385aa5305 --- /dev/null +++ b/configs/llm/monkey/monkey.yaml @@ -0,0 +1,57 @@ +model: + name: MonkeyQwenForCausalLM + batch_size: 1 + seq_length: 2048 + hidden_size: 4096 + num_layers: 32 + num_heads: 32 + vocab_size: 151936 + intermediate_size: 11008 + rms_norm_eps: 1.0e-6 + emb_dropout_prob: 0.0 + eos_token_id: 151643 + pad_token_id: 151643 + compute_dtype: "float16" + layernorm_compute_type: "float32" + softmax_compute_type: "float16" + rotary_dtype: "float16" + param_init_type: "float16" + ln_param_init_type: "float32" + use_past: True + use_flash_attention: False + use_past_shard: False + offset: 0 + checkpoint_name_or_path: "/path/to/monkey.ckpt" + repetition_penalty: 1.5 + max_decode_length: 2048 + top_k: 0 + top_p: 0.8 + do_sample: False + max_new_tokens: 250 + temperature: 0.7 + num_beams: 1 + length_penalty: 1 + num_patches: 1280 + + # configuration items copied from Qwen + rotary_pct: 1.0 + rotary_emb_base: 10000 + kv_channels: 128 + + visual: + heads: 16 + image_size: 896 + image_start_id: 151857 + layers: 48 + mlp_ratio: 4.9231 + output_dim: 4096 + patch_size: 14 + width: 1664 + lora_repeat_num: 4 + positional_embedding_size: 1024 + model_type: open_clip + +processor: + tokenizer: + vocab_file: "/path/to/qwen.tiktoken" + pad_token: "<|endoftext|>" diff --git a/configs/llm/vary/vary_toy.yaml b/configs/llm/vary/vary_toy.yaml index 1b6b59b9b..999f35930 100644 --- a/configs/llm/vary/vary_toy.yaml +++ b/configs/llm/vary/vary_toy.yaml @@ -30,6 +30,7 @@ model: max_new_tokens: 1024 temperature: 1.0 num_beams: 1 + num_patches: 256 # configuration items copied from Qwen rotary_pct: 1.0 @@ -37,8 +38,6 @@ model: kv_channels: 128 processor: - return_tensors: ms tokenizer: vocab_file: "/path/to/qwen.tiktoken" pad_token: "<|endoftext|>" - type: QwenProcessor diff --git a/mindocr/data/transforms/llm_transform.py b/mindocr/data/transforms/llm_transform.py index 92bfaa35d..d467824f6 100644 --- a/mindocr/data/transforms/llm_transform.py +++ b/mindocr/data/transforms/llm_transform.py @@ -28,6 +28,11 @@ def f(im): return f +def load_image(image_file): + image = Image.open(image_file).convert("RGB") + return image + + image_processor_high = alb_wrapper( alb.Compose( [ @@ -270,7 +275,7 @@ def __init__(self, image_resolution=224): self.batch_totensor = BatchToTensor() self.batch_normalizer = BatchNormalize() - def preprocess(self, images): + def __call__(self, images): if not self._bhwc_check(images): images = self.bchw2bhwc(images) images = self.batch_pilizer(images) @@ -299,4 +304,64 @@ def _bhwc_check(image_batch): return False -image_processor = VaryCLIPImageProcessor().preprocess +class VarySAMImageProcessor: + def __init__(self): + self.image_processor_high = alb_wrapper( + alb.Compose( + [ + alb.Resize(1024, 1024), + alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ] + ) + ) + + def __call__(self, images): + images = self.image_processor_high(images) + return images + + +class VaryImageProcessor: + def __init__(self): + self.sam_processor = VarySAMImageProcessor() + self.clip_processor = VaryCLIPImageProcessor() + + def __call__(self, images): + if isinstance(images, str): + images = load_image(images) + image_clip = self.clip_processor(images) + image_sam = self.sam_processor(images) + return image_clip, image_sam + + +class MonkeyImageProcessor: + def __init__(self): + self.resize1 = vision.c_transforms.Resize((896, 896), Inter.PILCUBIC) + self.resize2 = vision.Resize((448, 448), Inter.BICUBIC) + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + self.normalize = vision.Normalize(mean=mean, std=std, is_hwc=False) + + @staticmethod + def sliding_window(images, window_size=(448, 448), stride=448): + windows = [] + for i in range(2): + for j in range(2): + window = images[ + :, :, i * stride : i * stride + window_size[0], j * stride : j * stride + window_size[1] + ] + windows.append(window) + return windows + + def __call__(self, images): + if isinstance(images, str): + images = load_image(images) + images = self.resize1(images) # hwc -> hwc + images = images / 255.0 # hwc -> hwc + images_trans = images.transpose(2, 0, 1) # hwc -> chw + images_norm = self.normalize(images_trans) # chw -> chw + images_nchw = np.expand_dims(images_norm, 0) # chw -> nchw + windows = self.sliding_window(images_nchw, window_size=(448, 448), stride=448) # nchw -> List[nchw] + images_norm = images_norm.transpose(1, 2, 0) + images_448 = self.resize2(images_norm) # hwc -> hwc + images_448 = np.expand_dims(images_448.transpose(2, 0, 1), 0) # hwc -> nchw + return windows, images_448 diff --git a/mindocr/nlp/llm/__init__.py b/mindocr/nlp/llm/__init__.py index ce614c47a..f0a006d5c 100644 --- a/mindocr/nlp/llm/__init__.py +++ b/mindocr/nlp/llm/__init__.py @@ -1,3 +1,4 @@ from ._registry import register_llm from .builder import build_llm_model +from .monkey_qwen_model import MonkeyQwenForCausalLM from .vary_qwen_model import VaryQwenForCausalLM diff --git a/mindocr/nlp/llm/base_llm_model.py b/mindocr/nlp/llm/base_llm_model.py index d18279cdc..6a3ac658e 100644 --- a/mindocr/nlp/llm/base_llm_model.py +++ b/mindocr/nlp/llm/base_llm_model.py @@ -7,6 +7,7 @@ from mindocr.nlp.generation import GeneratorMixin from mindocr.nlp.llm.builder import build_llm_model from mindocr.nlp.llm.configs import BaseConfig, LLMConfig +from mindocr.utils.conversation import Conversation class BaseLLMModel(nn.Cell, GeneratorMixin): @@ -24,6 +25,13 @@ class BaseLLMModel(nn.Cell, GeneratorMixin): def __init__(self, config: BaseConfig, **kwargs): super(BaseLLMModel, self).__init__(**kwargs) self.config = config + self.conversation = Conversation() + self.image_path = None + self.IMAGE_START_TAG = "" + self.IMAGE_END_TAG = "" + self.IMAGE_PAD_TAG = "" + self.num_patches = config.num_patches + self.image_prefix = f"{self.IMAGE_START_TAG}{self.IMAGE_PAD_TAG * self.num_patches}{self.IMAGE_END_TAG}" def load_checkpoint(self, config): """ @@ -42,11 +50,7 @@ def load_checkpoint(self, config): if os.path.exists(checkpoint_name_or_path): param = load_checkpoint(checkpoint_name_or_path) else: - raise ValueError( - f"{checkpoint_name_or_path} is not a supported default model" - f" or a valid path to checkpoint," - f" please select from {self._support_list}." - ) + raise ValueError(f"{checkpoint_name_or_path} is not a valid path to checkpoint.") load_param_into_net(self, param) @@ -81,3 +85,48 @@ def from_pretrained(cls, pretrained_model_name_or_dir: str, **kwargs): config_args = cls._get_config_args(pretrained_model_name_or_dir, **kwargs) model = build_llm_model(config_args.model) return model + + def reset(self): + self.conversation.messages = list() + + def add_image_token_pad_in_query(self, query: str): + query = self.image_prefix + query + return query + + def chat( + self, + query: str, + image_path: str = None, + ) -> str: + """ + If `image_path` is provided, the conversation will be reset. + example: + inputs: + query: Provide the ocr results of this image. + image_path: xx/xxx.png. + outputs: + response: the modalities of irradiation could be modified... + history: [ + ("user", "Provide the ocr results of this image."), + ("assistant", "the modalities of irradiation could be modified..."), + ] + """ + if image_path is not None: + query = self.add_image_token_pad_in_query(query=query) + self.image_path = image_path + self.reset() + + self.conversation.add_message(role="user", message=query) + prompt = self.conversation.get_prompt() + + inputs = self.tokenizer([prompt], max_length=self.seq_length) + input_ids = inputs["input_ids"] + outputs = self.generate(input_ids=input_ids, image_path=self.image_path) + outputs = self.tokenizer.decode(outputs, skip_special_tokens=False) + response = outputs[0][len(prompt) :] + + for special_token in self.tokenizer.special_tokens: + response = response.replace(special_token, "") + self.conversation.add_message(role="assistant", message=response) + + return response diff --git a/mindocr/nlp/llm/builder.py b/mindocr/nlp/llm/builder.py index 490311b09..033a8ef78 100644 --- a/mindocr/nlp/llm/builder.py +++ b/mindocr/nlp/llm/builder.py @@ -1,4 +1,5 @@ from ._registry import is_llm, is_llm_class, list_llms, llm_class_entrypoint, llm_entrypoint +from .configs import LLMConfig __all__ = ["build_llm_model"] @@ -11,8 +12,16 @@ def build_llm_model(config): >>> llm_model = build_llm_model(dict(name='VaryQwenForCausalLM')) >>> print(llm_model) """ + if isinstance(config, dict): + config = LLMConfig(**config) + elif isinstance(config, str): + config = LLMConfig(config) + else: + raise TypeError(f"config must be str or dict, but got {type(config)}") + if "model" in config: + config = LLMConfig(**config["model"], **config["processor"]) if "name" not in config: - raise ValueError("name must in `config`.") + raise ValueError("`name` must in `config`.") name = config["name"] if is_llm(name): create_fn = llm_entrypoint(name) diff --git a/mindocr/nlp/llm/configs.py b/mindocr/nlp/llm/configs.py index 76dba33cd..33daf7f4c 100644 --- a/mindocr/nlp/llm/configs.py +++ b/mindocr/nlp/llm/configs.py @@ -222,9 +222,7 @@ def __init__( self.num_heads = num_heads self.max_position_embedding = max_position_embedding if max_position_embedding else seq_length self.intermediate_size = intermediate_size - self.multiple_of = multiple_of self.n_kv_heads = n_kv_heads - self.ffn_dim_multiplier = ffn_dim_multiplier self.rms_norm_eps = rms_norm_eps self.qkv_concat = qkv_concat self.param_init_type = convert_mstype(param_init_type) @@ -266,6 +264,11 @@ def __init__(self, **kwargs): super(VaryConfig, self).__init__(**kwargs) +class MonkeyConfig(QwenConfig): + def __init__(self, **kwargs): + super(MonkeyConfig, self).__init__(**kwargs) + + class SAMConfig(BaseConfig): def __init__( self, diff --git a/mindocr/nlp/llm/convert_weight.py b/mindocr/nlp/llm/convert_weight.py index eabe1d223..5b1c11f13 100644 --- a/mindocr/nlp/llm/convert_weight.py +++ b/mindocr/nlp/llm/convert_weight.py @@ -51,6 +51,12 @@ def _name_replace(name: str): name = name.replace("layer_norm2.bias", "ln_2.beta") name = name.replace("self_attn", "attn") name = name.replace("post_layernorm", "vision_model.post_layernorm") + if "visual" in name: + name = name.replace("attention_norm", "ln_1") + name = name.replace("ffn_norm", "ln_2") + if "ln_" in name: + name = name.replace("weight", "gamma").replace("bias", "beta") + name = name.replace("feed_forward.w2", "mlp.c_proj") # sam name = name.replace("norm1.weight", "norm1.gamma") @@ -82,6 +88,9 @@ def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16): state_dict = torch.load(torch_ckpt_path, map_location="cpu") ckpt_weights = [] for k, v in state_dict.items(): + if "lora_scale" in k: + continue + value = pt2ms(v, dtype) msname = _name_replace(k) @@ -111,4 +120,4 @@ def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16): args = parser.parse_args() - convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float16) + convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float32) diff --git a/mindocr/nlp/llm/monkey_qwen_model.py b/mindocr/nlp/llm/monkey_qwen_model.py new file mode 100644 index 000000000..46cf936be --- /dev/null +++ b/mindocr/nlp/llm/monkey_qwen_model.py @@ -0,0 +1,227 @@ +import mindspore as ms +from mindspore import ops + +from mindocr.data.transforms.llm_transform import MonkeyImageProcessor +from mindocr.nlp.llm import register_llm +from mindocr.nlp.llm.configs import MonkeyConfig +from mindocr.nlp.llm.qwen_model import QwenForCausalLM, QwenModel +from mindocr.nlp.llm.vary_clip_model import VisionTransformer +from mindocr.nlp.utils.layers import Linear +from mindocr.utils.conversation import Conversation + + +class MonkeyModel(QwenModel): + def __init__(self, config): + super().__init__(config) + self.image_start_token_pos = 0 + self.num_patches = 1280 + + self.visual = VisionTransformer( + input_resolution=config.visual.get("image_size", 896), # image_size in transformers + patch_size=config.visual.get("patch_size", 14), # patch_size in transformers + width=config.visual.get("width", 1664), # hidden_size + layers=config.visual.get("layers", 48), # num_hidden_layers + heads=config.visual.get("heads", 16), # num_attention_heads + output_dim=config.visual.get("output_dim", 4096), # projection_dim in transformers + vision_select_layer=-2, + param_init_type=config.param_init_type, + ln_param_init_type=config.ln_param_init_type, + positional_embedding_size=config.visual.get("positional_embedding_size", 1024), + mlp_ratio=config.visual.get("mlp_ratio", 4.9231), + model_type=config.visual.get("model_type", "open_clip"), + compute_dtype=config.compute_dtype, + layernorm_compute_type=config.layernorm_compute_type, + ) + + def construct( + self, + input_ids, + init_reset=True, + batch_valid_length=None, + batch_index=None, + zactivate_len=None, + windows=None, + image=None, + ): + """construct""" + if input_ids is not None: + input_shape = input_ids.shape + input_ids = input_ids.view(-1, input_shape[-1]) + + # 1. wte + bs, seq_len = self.shape(input_ids) + hidden_states = self.wte(input_ids) + + # 2. drop + hidden_states = self.drop(hidden_states) + + # image embedding + if seq_len > 1 and image is not None and windows is not None: + patch_list = [] + lora_idx = 0 + for image_patch in windows: + patch = self.visual(image_patch, idx=lora_idx) + patch_list.append(patch) + lora_idx += 1 + global_feat = self.visual(image) + + local_feat = ops.cat(patch_list, axis=1) + image_features = ops.cat([local_feat, global_feat], axis=1) + + if seq_len > 1 and image_features is not None: + new_input_embeds = [] + num_patches = self.num_patches + image_start_token_pos = self.image_start_token_pos + for i in range(bs): + cur_input_embeds = hidden_states[i] + per_cur_image_features = image_features[i] + cur_input_embeds = ops.cat( + ( + cur_input_embeds[: image_start_token_pos + 1], + per_cur_image_features, + cur_input_embeds[image_start_token_pos + num_patches + 1 :], + ), + axis=0, + ) + + new_input_embeds.append(cur_input_embeds) + + hidden_states = ops.stack(new_input_embeds, axis=0) + + # 3. rotary_emb + if not self.use_past: + freqs_cis = self.freqs_mgr() + mask = self.casual_mask(input_ids) # mask: [bs, seq, seq] + mask = self.casual_mask.post_process(mask) + kvcache_inputs = None + else: + if self.is_first_iteration: + freqs_cis = self.freqs_mgr(seq_len) + mask = self.casual_mask(input_ids) # mask: [bs, seq, seq] + else: + freqs_cis = self.freqs_mgr.increment(batch_valid_length, bs) + if self.is_dynamic and self.is_flexible_shape and not self.use_kvcache_op: + mask = self.casual_mask.increment_slice( + self.kvcache_preprocess.range, + self.kvcache_preprocess.max_cache_length // bs, + batch_valid_length, + zactivate_len, + ) + else: + mask = self.casual_mask.increment(self.kvcache_preprocess.range, batch_valid_length, zactivate_len) + mask = self.casual_mask.post_process(mask) + + kvcache_inputs = self.kvcache_preprocess(bs, batch_valid_length, batch_index, zactivate_len) + + # 4. hidden_states + for i in range(self.num_hidden_layers): + hidden_states = self.layers[i](hidden_states, freqs_cis, mask, kvcache_inputs=kvcache_inputs) + + # 5. ln_f + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +@register_llm +class MonkeyQwenForCausalLM(QwenForCausalLM): + def __init__(self, config): + config = MonkeyConfig(**config) + super().__init__(config) + self.transformer = MonkeyModel(config) + self.lm_head = Linear( + config.hidden_size, + config.vocab_size, + has_bias=False, + param_init_type=config.param_init_type, + compute_dtype=config.compute_dtype, + ) + self.conversation = Conversation(generate_mode=True) + self.image_processor = MonkeyImageProcessor() + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + image_path = kwargs.get("image_path") + if image_path is None: + windows, image = None, None + else: + windows, image = self.image_processor(image_path) + windows = ms.Tensor(windows, ms.float16) + image = ms.Tensor(image, ms.float16) + return { + "input_ids": ms.Tensor(input_ids, ms.int32), + "windows": windows, + "image": image, + } + + def construct( + self, + input_ids, + labels=None, + input_position=None, + position_ids=None, + attention_mask=None, + input_embeds=None, + init_reset=True, + batch_valid_length=None, + batch_index=None, + zactivate_len=None, + windows=None, + image=None, + ): + """construct""" + bsz, seqlen = input_ids.shape + if self.use_past: + if not isinstance(batch_valid_length, ms.Tensor): + batch_valid_length = self.ones((bsz,), ms.int32) + if self.training: + tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1)) + else: + tokens = input_ids + + if batch_valid_length is not None: + batch_valid_length = self.reshape(batch_valid_length, (-1,)) + if not self.is_first_iteration: + batch_valid_length = self.sub_batch_valid_len(batch_valid_length, 1) + + output = self.transformer( + tokens, + init_reset=init_reset, + batch_valid_length=batch_valid_length, + batch_index=batch_index, + zactivate_len=zactivate_len, + windows=windows, + image=image, + ) + pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None + if pre_gather: + output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1) + logits = self.lm_head(output) + + input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), ms.float32) + if labels is None: + labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1)) + else: + if labels.ndim > 1: + if self.training: + labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1)) + label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), ms.float32) + input_mask = self.mul(input_mask, label_mask) + + if not self.training: + if not pre_gather: + logits = self.reshape(logits, (bsz, seqlen, -1)) + logits = self.cast(logits, ms.float32) + # makes cast effective to avoid allgather issue in Mindspore1.10 + input_mask = self.add(input_mask, 1) + return logits, tokens, input_mask + + if logits.ndim > 2: + logits = self.reshape(logits, (-1, logits.shape[-1])) + logits = self.cast(logits, ms.float32) + labels = self.reshape(labels, (-1,)) + input_mask = self.reshape(input_mask, (-1,)) + loss = self.loss(logits, labels, input_mask) + return loss + + def add_image_token_pad_in_query(self, query): + query = self.image_prefix + " " + query + " " + return query diff --git a/mindocr/nlp/llm/qwen_model.py b/mindocr/nlp/llm/qwen_model.py index 1e1bee7b8..0cb98c4a0 100644 --- a/mindocr/nlp/llm/qwen_model.py +++ b/mindocr/nlp/llm/qwen_model.py @@ -6,18 +6,20 @@ import mindspore as ms import mindspore.common.dtype as mstype from mindspore import Parameter, Tensor, nn, ops -from mindspore._c_expression import MSContext from mindspore.common.initializer import initializer -from mindspore.nn.layer.flash_attention import FlashAttention from mindocr.nlp.llm.base_llm_model import BaseLLMModel from mindocr.nlp.llm.configs import QwenConfig +from mindocr.nlp.llm.qwen_tokenizer import QwenTokenizer +from mindocr.nlp.utils.flash_attention import FlashAttention from mindocr.nlp.utils.kvcache_mgr import KVCacheMgr, KVCachePreprocess from mindocr.nlp.utils.layers import Linear from mindocr.nlp.utils.loss import CrossEntropyLoss def is_910a(): + from mindspore._c_expression import MSContext + device = MSContext.get_instance().get_ascend_soc_version() return device in ["910a", "ascend910"] @@ -143,99 +145,6 @@ def construct(self, x): return ops.silu(x) -class LlamaFeedForward(nn.Cell): - r""" - LLaMA FeedForward. - - .. math:: - (xW_1 * xW_3)W_2 - - Inputs: - - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`. - Float tensor. - - Outputs: - Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or - [batch * seq_length, hidden_size]`. - - Raises: - ValueError: `hidden_dim` is not a multiple of the model parallel way. - ValueError: `dim` is not a multiple of the model parallel way. - """ - - def __init__( - self, - dim, - intermediate_size=None, - hidden_dim=None, - multiple_of=256, - hidden_act=LlamaSiLU, - ffn_dim_multiplier=None, - compute_dtype=mstype.float16, - param_init_type=mstype.float32, - is_dynamic=False, - ): - super().__init__() - - if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)): - raise TypeError( - f"For FeedForward cell, the hidden_act should str type or nn.Cell type, but got {hidden_act}." - ) - - if intermediate_size is not None: - hidden_dim = intermediate_size - else: - if ffn_dim_multiplier is not None: - hidden_dim = int((ffn_dim_multiplier + 0.01) * hidden_dim) - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.dtype = compute_dtype - self.hidden_act = hidden_act - self.dim = dim - self.hidden_dim = hidden_dim - - self.mul = ops.Mul() - self.cast = ops.Cast() - self.w1 = Linear( - in_channels=dim, - out_channels=hidden_dim, - activation=hidden_act, - has_bias=False, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - skip_redistribution=is_dynamic, - ) - - self.w2 = Linear( - in_channels=hidden_dim, - out_channels=dim, - has_bias=False, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - skip_redistribution=is_dynamic, - ) - - self.w3 = Linear( - in_channels=dim, - out_channels=hidden_dim, - has_bias=False, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - skip_redistribution=is_dynamic, - ) - - def construct(self, x): - """Forward process of the FeedForward""" - x = self.cast(x, self.dtype) - # [bs, seq, hidden_dim] or [bs * seq, hidden_dim] - gate = self.w1(x) # dp,1 -> dp, mp - hidden = self.w3(x) # dp,1 -> dp, mp - hidden = self.mul(hidden, gate) # dp,mp -> dp, mp - output = self.w2(hidden) # dp,mp -> dp, 1 - return output - - class LlamaRotaryEmbedding(nn.Cell): r""" Rotary Position Embedding. @@ -467,7 +376,16 @@ def __init__( ) if self.use_flash_attention: - self.flash_attention = FlashAttention(self.head_dim, n_heads, next_block_num=0, high_precision=True) + self.flash_attention = FlashAttention( + head_num=self.n_head, + pre_tokens=65536, + next_tokens=0, + input_layout="BNSD", + keep_prob=1.0, + scale_value=1.0 / (self.head_dim**0.5), + sparse_mode=0, + use_attention_mask=True, + ) if self.use_past: self.kvcache_mgr = KVCacheMgr( @@ -682,6 +600,7 @@ def __init__(self, config=None): self.mul = ops.Mul() self.sub_batch_valid_len = ops.Sub() self.gather = ops.Gather(1) + self.tokenizer = QwenTokenizer(**config.tokenizer) def construct( self, @@ -851,7 +770,7 @@ def construct( # 2. drop hidden_states = self.drop(hidden_states) - # 2. rotary_emb + # 3. rotary_emb bs, seq_len = self.shape(input_ids) if not self.use_past: freqs_cis = self.freqs_mgr() @@ -897,8 +816,6 @@ def __init__( n_heads: int = 8, n_kv_heads: Optional[int] = None, intermediate_size: Optional[int] = None, - multiple_of: int = 256, - ffn_dim_multiplier: Optional[int] = None, norm_eps: float = 1e-5, qkv_concat=False, compute_dtype=mstype.float16, @@ -964,16 +881,6 @@ def __init__( use_rope_slice=use_rope_slice, use_flash_attention=use_flash_attention, ) - self.feed_forward = LlamaFeedForward( - dim=self.hidden_size, - intermediate_size=intermediate_size, - hidden_dim=4 * self.hidden_size, - multiple_of=multiple_of, - ffn_dim_multiplier=ffn_dim_multiplier, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - is_dynamic=is_dynamic, - ) self.feed_forward = QwenFeedForward( dim=self.hidden_size, intermediate_size=intermediate_size, diff --git a/mindocr/nlp/llm/qwen_tokenizer.py b/mindocr/nlp/llm/qwen_tokenizer.py index 8ad4224d2..abe66e65d 100644 --- a/mindocr/nlp/llm/qwen_tokenizer.py +++ b/mindocr/nlp/llm/qwen_tokenizer.py @@ -12,6 +12,16 @@ ENDOFTEXT = "<|endoftext|>" IMSTART = "<|im_start|>" IMEND = "<|im_end|>" +REF_START_TAG = "" +REF_END_TAG = "" +BOX_START_TAG = "" +BOX_END_TAG = "" +QUAD_START_TAG = "" +QUAD_END_TAG = "" +IMAGE_START_TAG = "" +IMAGE_END_TAG = "" +IMAGE_PAD_TAG = "" + EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) SPECIAL_TOKENS = ( ( @@ -20,7 +30,17 @@ IMEND, ) + EXTRAS - + ("", "", "", "", "", "", "", "", "") + + ( + REF_START_TAG, + REF_END_TAG, + BOX_START_TAG, + BOX_END_TAG, + QUAD_START_TAG, + QUAD_END_TAG, + IMAGE_START_TAG, + IMAGE_END_TAG, + IMAGE_PAD_TAG, + ) ) diff --git a/mindocr/nlp/llm/vary_clip_model.py b/mindocr/nlp/llm/vary_clip_model.py index 5c3d22f88..5adeb1c3f 100644 --- a/mindocr/nlp/llm/vary_clip_model.py +++ b/mindocr/nlp/llm/vary_clip_model.py @@ -2,15 +2,157 @@ import mindspore as ms from mindspore import Parameter, Tensor, nn, ops +from mindspore.common.initializer import initializer from mindocr.nlp.utils.layers import LayerNorm, Linear +class LoraAdapter(nn.Cell): + def __init__( + self, + d_model, + out_feat, + r=16, + param_init_type=ms.float32, + compute_dtype=ms.float32, + ): + super().__init__() + self.d_model = d_model + self.out_feat = out_feat + + self.lora_a = Linear( + self.d_model, + r, + has_bias=False, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ) + self.lora_b = Linear( + r, + self.out_feat, + has_bias=False, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ) + + def construct(self, x): + down = self.lora_a(x) + up = self.lora_b(down) + output = up + return output + + class QuickGELU(nn.Cell): def construct(self, x: Tensor): return x * ops.sigmoid(1.702 * x) +class VisualAttention(nn.Cell): + """self-attention layer class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, embed_dim, num_heads, lora_repeat_num=4, param_init_type=ms.float32, compute_dtype=ms.float32): + super(VisualAttention, self).__init__() + self.embed_dim = embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. + self.in_proj = Linear(embed_dim, 3 * embed_dim, param_init_type=param_init_type, compute_dtype=compute_dtype) + self.in_proj_lora = [] + for _ in range(lora_repeat_num): + self.in_proj_lora.append( + LoraAdapter( + d_model=embed_dim, + out_feat=3 * embed_dim, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ) + ) + self.in_proj_lora = nn.CellList(self.in_proj_lora) + + self.out_proj = Linear(embed_dim, embed_dim, param_init_type=param_init_type, compute_dtype=compute_dtype) + self.out_proj_lora = [] + for _ in range(lora_repeat_num): + self.out_proj_lora.append( + LoraAdapter( + d_model=embed_dim, + out_feat=embed_dim, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ) + ) + self.out_proj_lora = nn.CellList(self.out_proj_lora) + self.norm_factor = self.hidden_size_per_attention_head**0.5 + + def construct(self, query, idx=None): + # query/key/value: [sq, b, h] + sq, b, _ = query.shape + + sk = sq + mixed_x_layer = self.in_proj(query) + if idx is not None: + lora_res = self.in_proj_lora[idx](query) + mixed_x_layer += lora_res + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.shape[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, axis=3) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(1, 0, 2) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(1, 0, 2) + + q_scaled = query_layer / self.norm_factor + attention_probs = ops.BatchMatMul(transpose_b=True)(q_scaled, key_layer) + attention_probs = ops.softmax(attention_probs, axis=-1) + + value_layer = value_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(1, 0, 2) + + # matmul: [b * np, sq, hn] + context_layer = ops.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, self.hidden_size_per_attention_head + ) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3) + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.shape[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.out_proj(context_layer) + if idx is not None: + lora_res = self.out_proj_lora[idx](context_layer) + output += lora_res + + return output + + class CLIPAttention(nn.Cell): """Multi-head attention module for CLIP""" @@ -85,31 +227,78 @@ def __init__( attn_mask: Tensor = None, param_init_type=ms.float32, ln_param_init_type=ms.float32, + compute_dtype=ms.float32, + layernorm_compute_type=ms.float32, + mlp_ratio=4.0, + lora_repeat_num=4, + model_type="clip", ): super().__init__() - self.attn = CLIPAttention(d_model, n_head, param_init_type=param_init_type) - self.ln_1 = LayerNorm((d_model,), eps=1e-5, param_init_type=ln_param_init_type) + self.use_clip = model_type == "clip" + if self.use_clip: + self.attn = CLIPAttention(d_model, n_head, param_init_type=param_init_type) + else: + self.attn = VisualAttention(d_model, n_head, param_init_type=param_init_type, compute_dtype=compute_dtype) + self.ln_1 = LayerNorm((d_model,), eps=1e-6, param_init_type=ln_param_init_type) + mlp_width = int(d_model * mlp_ratio) self.mlp = nn.SequentialCell( OrderedDict( [ - ("c_fc", Linear(d_model, d_model * 4, param_init_type=param_init_type)), - ("gelu", QuickGELU()), - ("c_proj", Linear(d_model * 4, d_model, param_init_type=param_init_type)), + ( + "c_fc", + Linear( + d_model, + mlp_width, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ), + ), + ("gelu", QuickGELU() if self.use_clip else nn.GELU(approximate=False)), + ( + "c_proj", + Linear( + mlp_width, + d_model, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ), + ), ] ) ) - self.ln_2 = LayerNorm((d_model,), eps=1e-5, param_init_type=ln_param_init_type) + self.ln_2 = LayerNorm((d_model,), eps=1e-6, param_init_type=ln_param_init_type) self.attn_mask = Parameter(attn_mask) if attn_mask is not None else None - def construct(self, x: Tensor): + self.mlp_lora = [] + if not self.use_clip: + for _ in range(lora_repeat_num): + self.mlp_lora.append( + LoraAdapter( + d_model=d_model, + out_feat=d_model, + r=32, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ) + ) + self.mlp_lora = nn.CellList(self.mlp_lora) + self.layernorm_compute_type = layernorm_compute_type + + def construct(self, x: Tensor, idx=None): residual0 = x x_type = x.dtype - x = self.ln_1(x.to(ms.float32)).to(x_type) - x = residual0 + self.attn(x) + x = self.ln_1(x.to(self.layernorm_compute_type)).to(x_type) + if self.use_clip: + x = residual0 + self.attn(x) + else: + x = residual0 + self.attn(x, idx) residual1 = x - x = self.ln_2(x.to(ms.float32)).to(x_type) + x = self.ln_2(x.to(self.layernorm_compute_type)).to(x_type) x = residual1 + self.mlp(x) + + if not self.use_clip and idx is not None: + x += self.mlp_lora[idx](residual1) return x @@ -124,6 +313,10 @@ def __init__( attn_mask: Tensor = None, param_init_type=ms.float32, ln_param_init_type=ms.float32, + compute_dtype=ms.float32, + layernorm_compute_type=ms.float32, + mlp_ratio=4.0, + model_type="clip", ): super().__init__() self.width = width @@ -131,22 +324,105 @@ def __init__( self.resblocks = nn.CellList( [ ResidualAttentionBlock( - width, heads, attn_mask, param_init_type=param_init_type, ln_param_init_type=ln_param_init_type + width, + heads, + attn_mask, + param_init_type=param_init_type, + ln_param_init_type=ln_param_init_type, + mlp_ratio=mlp_ratio, + model_type=model_type, + compute_dtype=compute_dtype, + layernorm_compute_type=layernorm_compute_type, ) for _ in range(layers) ] ) - def construct(self, x: Tensor): + def construct(self, x: Tensor, idx=None): encoder_states = () hidden_state = x - for block in self.resblocks: + hidden_state_list = list() + for i, block in enumerate(self.resblocks): encoder_states += (hidden_state,) - hidden_state = block(hidden_state) + dtype = hidden_state.dtype + if i > 20: # After the 20th layer, the error of FP16 becomes unacceptable. + hidden_state = hidden_state.to(ms.float32) + hidden_state = block(hidden_state, idx) + if i > 20: + hidden_state = hidden_state.to(dtype) + hidden_state_list.append(hidden_state) encoder_states += (hidden_state,) return encoder_states +class Resampler(nn.Cell): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__( + self, + grid_size, + embed_dim, + num_heads, + positional_embedding_size=1024, + kv_dim=None, + param_init_type=ms.float32, + ln_param_init_type=ms.float32, + compute_dtype=ms.float32, + layernorm_compute_type=ms.float32, + ): + super().__init__() + self.num_queries = grid_size**2 + self.embed_dim = embed_dim + self.num_heads = num_heads + self.layernorm_compute_type = layernorm_compute_type + + self.pos_embed = Parameter(initializer("zeros", (positional_embedding_size, embed_dim), param_init_type)) + self.pos_embed_unsqueeze = Parameter(initializer("zeros", (256, embed_dim), param_init_type)) + + self.query = Parameter(initializer("zeros", (self.num_queries, embed_dim), param_init_type)) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = Linear( + kv_dim, + embed_dim, + has_bias=False, + param_init_type=param_init_type, + compute_dtype=compute_dtype, + ) + else: + self.kv_proj = None + + self.attn = nn.MultiheadAttention(embed_dim, num_heads, dtype=param_init_type) + self.ln_q = LayerNorm(embed_dim, eps=1e-6, param_init_type=ln_param_init_type) + self.ln_kv = LayerNorm(embed_dim, eps=1e-6, param_init_type=ln_param_init_type) + + def construct(self, x, attn_mask=None): + pos_embed = self.pos_embed + + if self.kv_proj is not None: + x = self.kv_proj(x) + x = self.ln_kv(x.to(self.layernorm_compute_type)).to(x.dtype).permute(1, 0, 2) + + n = x.shape[1] + q = self.ln_q(self.query.to(self.layernorm_compute_type)).to(self.query.dtype) + out = self.attn( + self._repeat(q, n) + self.pos_embed_unsqueeze.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + return out.permute(1, 0, 2) + + @staticmethod + def _repeat(query, n: int): + return query.unsqueeze(1).tile((1, n, 1)) + + class VisionTransformer(nn.Cell): """CLIP module for Vary system""" @@ -161,34 +437,101 @@ def __init__( vision_select_layer: int, param_init_type=ms.float32, ln_param_init_type=ms.float32, + compute_dtype=ms.float32, + layernorm_compute_type=ms.float32, + positional_embedding_size=None, + mlp_ratio=4.0, + model_type="clip", ): super().__init__() + assert model_type in ("clip", "open_clip") + self.use_clip = model_type == "clip" self.input_resolution = input_resolution self.output_dim = output_dim self.vision_select_layer = vision_select_layer - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - has_bias=False, - pad_mode="pad", - weight_init="uniform", - bias_init="uniform", - dtype=param_init_type, - ) + self.layernorm_compute_type = layernorm_compute_type scale = width**-0.5 - self.class_embedding = Parameter((scale * ops.randn(width)).astype(param_init_type)) + if positional_embedding_size is None: + positional_embedding_size = (input_resolution // patch_size) ** 2 + 1 self.positional_embedding = Parameter( - (scale * ops.randn(((input_resolution // patch_size) ** 2 + 1, width))).astype(param_init_type) + (scale * ops.randn((positional_embedding_size, width))).astype(param_init_type) ) self.ln_pre = LayerNorm((width,), eps=1e-5, param_init_type=ln_param_init_type) self.transformer = Transformer( - width, layers, heads, param_init_type=param_init_type, ln_param_init_type=ln_param_init_type + width, + layers, + heads, + param_init_type=param_init_type, + ln_param_init_type=ln_param_init_type, + compute_dtype=compute_dtype, + layernorm_compute_type=layernorm_compute_type, + mlp_ratio=mlp_ratio, + model_type=model_type, ) + if self.use_clip: + self.class_embedding = Parameter((scale * ops.randn(width)).astype(param_init_type)) + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + has_bias=False, + pad_mode="pad", + weight_init="uniform", + bias_init="uniform", + dtype=param_init_type, + ) + else: + self.attn_pool = Resampler( + grid_size=16, + embed_dim=output_dim, + num_heads=output_dim // 128, + positional_embedding_size=positional_embedding_size, + kv_dim=width, + param_init_type=param_init_type, + ln_param_init_type=ln_param_init_type, + compute_dtype=compute_dtype, + layernorm_compute_type=layernorm_compute_type, + ) + self.ln_post = LayerNorm((output_dim,), param_init_type=ln_param_init_type) + self.proj = Parameter(initializer("normal", (output_dim, output_dim), param_init_type)) + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + has_bias=False, + dtype=param_init_type, + ) - def construct(self, x: Tensor): + def construct(self, x: Tensor, idx=None): + if self.use_clip: + x = self._clip_construct(x) + else: + x = self._open_clip_construct(x, idx) + return x + + def _open_clip_construct(self, x, idx=None): + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + self.positional_embedding + + x = self.ln_pre(x.to(self.layernorm_compute_type)).to(x.dtype) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, idx=idx)[-1] + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x.to(self.layernorm_compute_type)).to(x.dtype) + x = ops.matmul(x, self.proj) + return x + + def _clip_construct(self, x: Tensor): x_type = x.dtype x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape((x.shape[0], x.shape[1], -1)) # shape = [*, width, grid**2] diff --git a/mindocr/nlp/llm/vary_qwen_model.py b/mindocr/nlp/llm/vary_qwen_model.py index 8001bd036..96f4424d3 100644 --- a/mindocr/nlp/llm/vary_qwen_model.py +++ b/mindocr/nlp/llm/vary_qwen_model.py @@ -1,20 +1,20 @@ import mindspore as ms from mindspore import ops +from mindocr.data.transforms.llm_transform import VaryImageProcessor from mindocr.nlp.llm import register_llm from mindocr.nlp.llm.configs import SAMConfig, VaryConfig from mindocr.nlp.llm.qwen_model import QwenForCausalLM, QwenModel from mindocr.nlp.llm.vary_clip_model import build_model from mindocr.nlp.llm.vary_sam_model import SAMEncoder from mindocr.nlp.utils.layers import Linear -from mindocr.utils.conversation import Conversation class VaryQwenModel(QwenModel): def __init__(self, config): super(VaryQwenModel, self).__init__(config) - config = SAMConfig(ln_param_init_type=config.ln_param_init_type) - self.vision_tower_high = SAMEncoder(config) + sam_config = SAMConfig(ln_param_init_type=config.ln_param_init_type) + self.vision_tower_high = SAMEncoder(sam_config) self.vision_tower_high.to_float(ms.float16) self.vision_tower = build_model( @@ -26,7 +26,7 @@ def __init__(self, config): self.mm_projector_vary = Linear(1024, 1024, param_init_type=config.param_init_type) self.image_start_token_pos = 22 - self.num_patches = 256 + self.num_patches = config.num_patches def construct( self, @@ -35,18 +35,18 @@ def construct( batch_valid_length=None, batch_index=None, zactivate_len=None, - image=None, - image_high=None, + image_clip=None, + image_sam=None, ): # 1. wte bs, seq_len = self.shape(input_ids) inputs_embeds = self.wte(input_ids) - if seq_len > 1 and image is not None and image_high is not None: - sam_out = self.vision_tower_high(image_high) + if seq_len > 1 and image_clip is not None and image_sam is not None: + sam_out = self.vision_tower_high(image_sam) sam_out = self.mm_projector_vary(sam_out) - clip_out = self.vision_tower(image) + clip_out = self.vision_tower(image_clip) clip_out = self.mm_projector(clip_out) image_features = ops.concat((clip_out, sam_out), -1) @@ -116,18 +116,20 @@ def __init__(self, config): config = VaryConfig(**config) super(VaryQwenForCausalLM, self).__init__(config) self.transformer = VaryQwenModel(config=config) - self.conversation = None - - self.image_past = None - self.image_high_past = None + self.image_processor = VaryImageProcessor() def prepare_inputs_for_generation(self, input_ids, **kwargs): - image = kwargs.get("image") - image_high = kwargs.get("image_high") + image_path = kwargs.get("image_path") + if image_path is None: + image_clip, image_sam = None, None + else: + image_clip, image_sam = self.image_processor(image_path) + image_clip = ms.Tensor(image_clip, ms.float16) + image_sam = ms.Tensor(image_sam, ms.float16) return { "input_ids": ms.Tensor(input_ids, ms.int32), - "image": ms.Tensor(image, ms.float16) if image is not None else None, - "image_high": ms.Tensor(image_high, ms.float16) if image_high is not None else None, + "image_clip": image_clip, + "image_sam": image_sam, } def construct( @@ -142,8 +144,8 @@ def construct( batch_valid_length=None, batch_index=None, zactivate_len=None, - image=None, - image_high=None, + image_clip=None, + image_sam=None, ): """construct""" bsz, seqlen = input_ids.shape @@ -166,8 +168,8 @@ def construct( batch_valid_length=batch_valid_length, batch_index=batch_index, zactivate_len=zactivate_len, - image=image, - image_high=image_high, + image_clip=image_clip, + image_sam=image_sam, ) pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None if pre_gather: @@ -199,55 +201,3 @@ def construct( input_mask = self.reshape(input_mask, (-1,)) loss = self.loss(logits, labels, input_mask) return loss - - def chat( - self, - tokenizer, - query: str, - image=None, - image_high=None, - ) -> str: - """ - example: - inputs: - query: Provide the ocr results of this image. - image: np.array. - image_high: np.array. - outputs: - response: the modalities of irradiation could be modified... - history: [ - ("user", "Provide the ocr results of this image."), - ("assistant", "the modalities of irradiation could be modified..."), - ] - - """ - if self.conversation is None: - self.conversation = Conversation() - - if image is not None and image_high is not None: - num_patch = 256 - im_start_token = "" - im_end_token = "" - im_patch_token = "" - query = im_start_token + im_patch_token * num_patch + im_end_token + query - self.image_past = image - self.image_high_past = image_high - - self.conversation.add_message(role="user", message=query) - prompt = self.conversation.get_prompt() - - inputs = tokenizer([prompt], max_length=self.seq_length) - input_ids = inputs["input_ids"] - outputs = self.generate(input_ids=input_ids, image=self.image_past, image_high=self.image_high_past) - outputs = tokenizer.decode(outputs, skip_special_tokens=False) - response = outputs[0][len(prompt) :] - - for special_token in tokenizer.special_tokens: - response = response.replace(special_token, "") - self.conversation.add_message(role="assistant", message=response) - - return response - - def reset(self): - if self.conversation is not None: - self.conversation.messages = list() diff --git a/mindocr/nlp/utils/flash_attention.py b/mindocr/nlp/utils/flash_attention.py new file mode 100644 index 000000000..c708fa2ee --- /dev/null +++ b/mindocr/nlp/utils/flash_attention.py @@ -0,0 +1,205 @@ +"""Flash Attention Layer""" +import mindspore.common.dtype as mstype +from mindspore import ops +from mindspore.common.tensor import Tensor +from mindspore.nn.cell import Cell +from mindspore.ops import functional as F +from mindspore.ops.operations.nn_ops import FlashAttentionScore + + +class FlashAttention(Cell): + """Flash Attention Layer. + + This function contains the flash attention primitives used in FlashAttention (see paper) + `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness ` + + Specifically, it includes the following: + + 1. An interface for calling flashattention operation. + 2. Two configuration parameters for enabling local block sparse of flashattention. + + B -- Batch size + S1 -- Sequence length of query. The value ranges from 1 to 32768 and is a multiple of 16. + S2 -- Sequence length of key and value. The value ranges from 1 to 32768 and is a multiple of 16. + N1 -- Num heads of query + N2 -- Num heads of key and value, and N2 must be a factor of N1 + D -- Head size. Support value: 64, 80, 96, 120, 128 and 256. + H1 -- Hidden size of query, which equals to N1 * D + H2 -- Hidden size of key and value, which equals to N2 * D + Args: + head_num (int): The head num of query. + keep_prob (float): The keep probability of dropout. Default: 1.0. + scale_value (float): The scale factor of score. Default: 1.0. + pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward. + When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647. + next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward. + When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647. + input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH" or "BNSD". + Default: "BSH". + sparse_mode (int): Indicates sparse mode. Default 0. + - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed, + and preTokens and nextTokens(internally assigned as INT_MAX) are ignored. If passed in, the full attn_mask + matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and nextTokens needs to + be calculated. + - 1: Represents allMask, that is, passing in the complete attn_mask matrix. + - 2: Representing the leftUpCausal mode corresponds to the lower triangle scenario divided by the left + vertex, and the optimized attn_mask matrix (2048*2048) is required. + - 3: Representing the rightDownCausal model corresponds to the lower triangle scene divided by the lower + right vertex, and the optimized attn_mask matrix (2048*2048) is required. + - 4: Represents the band scenario, that is, the part between counting preTokens and nextTokens, and the + optimized attn_mask matrix (2048*2048) is required.. + - 5: Represents the prefix scenario, that is, on the basis of rightDownCasual, a matrix with length S1 and + width N is added to the left side. The value of N is obtained by the new input prefix, and the N value of + each Batch axis is different. Not implemented yet. + - 6: Represents the global scenario, not implemented yet. + - 7: Represents the dilated scenario, not implemented yet. + - 8: Represents the block_local scenario, not implemented yet. + use_attention_mask (bool): The value is True if attention_mask is passed. Default: False. + use_alibi_mask (bool): The value is True if alibi_mask is passed. Default: False. + use_mqa (bool): Specifies whether using MQA. Default: False. + dp (int): Data parallel num. + mp (int): Model parallel num. + sp (int): Sequence parallel num. + + Inputs: + - **query** (Tensor[float16, bfloat16]) - The query tensor. + Input tensor of shape :math:`(B, S1, H1)` or `(B, N1, S1, D)`. + - **key** (Tensor[float16, bfloat16]) - The key tensor. + Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`. + - **value** (Tensor[float16, bfloat16]) - The value tensor. + Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`. + - **attn_mask** (Union[Tensor[uint8], None]) - The attention mask tensor. For each element, 0 indicates + retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`, `(B, 1, S1, S2)`, `(S1, S2)` + or (2048, 2048). + - **alibi_mask** (Union[Tensor[float16, bfloat16], None]) - The position embedding code. If S is greater than + 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of the lower triangle for + memory optimization. + Input tensor of shape :math: `(B, N1, S1, S2)`, `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)` + or (1024, 1024). + - **padding_mask** (None) - Reserved parameter. Not implemented yet. + - **prefix** (Union[Tensor[int64], None]) - N value of each Batch in the prefix sparse calculation scenario. + Not implemented yet. Input tensor of shape :math:`(B,)`. + + Outputs: + - **attention_out** (Tensor[float16, bfloat16]) - The output of attention, its shape, and data type + are the same as the query. + + Supported Platforms: + ``Atlas 800T A2`` + + Examples: + >>> import numpy as np + >>> import math + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> bsz, head_num, seq_len, head_size = 1, 16, 4096, 128 + >>> hidden_size = head_num * head_size + >>> query = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16) + >>> key = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16) + >>> value = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16) + >>> attn_mask = Tensor(np.ones((bsz, 1, seq_len, seq_len)), mstype.uint8) + >>> model = FlashAttention(head_num, + keep_prob=1.0, + scale_value=1.0 / math.sqrt(head_dim), + pre_tokens=2147483647, + next_tokens=2147483647, + input_layout="BSH", + sparse_mode=0, + use_attention_mask=True, + use_alibi_mask=False, + use_mqa=False, + dp=1, + mp=1, + sp=1 + ... ) + >>> output = model(query, key, value, attn_mask) + >>> print(output.shape) + (1, 16, 2048) + """ + + def __init__( + self, + head_num, + keep_prob=1.0, + scale_value=1.0, + pre_tokens=2147483647, + next_tokens=2147483647, + input_layout="BSH", + sparse_mode=0, + use_attention_mask=True, + use_alibi_mask=False, + use_mqa=False, + dp=1, + mp=1, + sp=1, + ): + super(FlashAttention, self).__init__() + self.head_num = head_num + self.enable_dropout = keep_prob < 1.0 + self.input_layout = input_layout + self.sparse_mode = sparse_mode + self.use_alibi_mask = use_alibi_mask + self.use_attention_mask = use_attention_mask + self.use_mqa = use_mqa + self.dp = dp + self.mp = mp + self.sp = sp + + fa_strategies = self._generate_flash_attention_strategy(dp, mp, sp) + self.flash_attention = FlashAttentionScore( + head_num=head_num, + keep_prob=keep_prob, + scale_value=scale_value, + pre_tokens=pre_tokens, + next_tokens=next_tokens, + inner_precise=0, + input_layout=self.input_layout, + sparse_mode=self.sparse_mode, + ).shard(fa_strategies) + if self.use_alibi_mask: + self.alibi_rescale_factor = Tensor([1.0 / scale_value], dtype=mstype.float16) + self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, sp, 1), (1,))) + if self.enable_dropout: + self.keep_prob_tensor = Tensor(keep_prob, dtype=mstype.float16) + self.drop_gen_mask = ops.DropoutGenMask() + + def _generate_flash_attention_strategy(self, dp, mp, sp): + """get FA generate strategies""" + kv_head_split_num = 1 if self.use_mqa else mp + if self.input_layout == "BSH": + fa_strategies = ((dp, sp, mp), (dp, 1, kv_head_split_num), (dp, 1, kv_head_split_num)) + else: + fa_strategies = ((dp, mp, sp, 1), (dp, kv_head_split_num, 1, 1), (dp, kv_head_split_num, 1, 1)) + if self.use_alibi_mask: + fa_strategies += ((dp, mp, sp, 1),) + if self.enable_dropout: + fa_strategies += ((dp, mp, sp, 1),) + if self.use_attention_mask: + if self.sparse_mode in [0, 1]: + fa_strategies += ((dp, 1, sp, 1),) + else: + raise RuntimeError(f"sparse_mode: {self.sparse_mode} is not support currently") + + return fa_strategies + + def construct(self, query, key, value, attn_mask=None, alibi_mask=None, prefix=None, padding_mask=None): + """Forward process of the AttentionMaskMF""" + if self.input_layout == "BSH": + bsz, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + else: + bsz, _, q_seq_len, _ = query.shape + _, _, kv_seq_len, _ = key.shape + if self.enable_dropout: + drop_mask_bits = F.reshape( + self.drop_gen_mask((bsz, self.head_num, q_seq_len, kv_seq_len), self.keep_prob_tensor), + (bsz, self.head_num, q_seq_len, kv_seq_len // 8), + ) + else: + drop_mask_bits = None + if self.use_alibi_mask: + alibi_mask = self.alibi_rescale_mul(alibi_mask, F.cast(self.alibi_rescale_factor, alibi_mask.dtype)) + _, _, _, output = self.flash_attention( + query, key, value, alibi_mask, drop_mask_bits, padding_mask, attn_mask, prefix + ) + return output diff --git a/mindocr/utils/conversation.py b/mindocr/utils/conversation.py index b0002db36..91d9859f8 100644 --- a/mindocr/utils/conversation.py +++ b/mindocr/utils/conversation.py @@ -12,6 +12,7 @@ def __init__( roles: Tuple[str, str] = ("user", "assistant"), messages: List[Tuple[str, str]] = None, sep: str = "<|im_end|>", + generate_mode: bool = False, ): self.system = ( "<|im_start|>{system}\n{message}{sep}\n".format( @@ -25,11 +26,14 @@ def __init__( self.roles = roles self.messages = list() if messages is None else messages self.sep = sep + self.generate_mode = generate_mode def get_messages(self): return self.messages def get_prompt(self): + if self.generate_mode: + return self.messages[-1][1] ret = self.system if self.system else "" for role, message in self.messages: ret += "<|im_start|>{role}\n{message}{sep}\n".format(role=role, message=message, sep=self.sep) diff --git a/tools/infer/text/predict_llm.py b/tools/infer/text/predict_llm.py index 9eb02fb20..d4b309625 100644 --- a/tools/infer/text/predict_llm.py +++ b/tools/infer/text/predict_llm.py @@ -2,22 +2,12 @@ import logging import os -from PIL import Image - import mindspore as ms -from mindocr.data.transforms.llm_transform import image_processor, image_processor_high -from mindocr.nlp.llm.configs import LLMConfig -from mindocr.nlp.llm.qwen_tokenizer import QwenTokenizer -from mindocr.nlp.llm.vary_qwen_model import VaryQwenForCausalLM +from mindocr.nlp.llm import build_llm_model from mindocr.utils.logger import set_logger -def load_image(image_file): - image = Image.open(image_file).convert("RGB") - return image - - def str2bool(v): if isinstance(v, bool): return v @@ -31,10 +21,13 @@ def str2bool(v): def parse_args(): parser = argparse.ArgumentParser(description="Inference Config Args") - parser.add_argument("--image_dir", type=str, required=True, help="image path") + parser.add_argument("--mode", type=int, default=0, help="0 for graph mode, 1 for pynative mode") + parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"]) + parser.add_argument("--image_path", type=str, required=False, help="image path") parser.add_argument("--query", type=str, required=False, default="Provide the ocr results of this image.") parser.add_argument("--config_path", type=str, required=False, default="../../../configs/llm/vary/vary_toy.yaml") parser.add_argument("--chat_mode", type=str2bool, required=False, default=False) + parser.add_argument("--precision_mode", type=str, required=False, default="allow_fp32_to_fp16") args = parser.parse_args() return args @@ -42,49 +35,44 @@ def parse_args(): class LLMGenerator(object): def __init__(self, args): config_path = args.config_path - config = LLMConfig(config_path) ms.set_context( - mode=ms.GRAPH_MODE, - device_target="Ascend", + mode=args.mode, + device_target=args.device_target, enable_graph_kernel=False, graph_kernel_flags="--disable_expand_ops=Softmax,Dropout --enable_parallel_fusion=true " "--reduce_fuse_depth=8 --enable_auto_tensor_inplace=true", - ascend_config={"precision_mode": "must_keep_origin_dtype"}, + ascend_config={"precision_mode": args.precision_mode}, max_call_depth=10000, max_device_memory="58GB", save_graphs=False, save_graphs_path="./graph", device_id=os.environ.get("DEVICE_ID", 0), ) - self.tokenizer = QwenTokenizer(**config.processor.tokenizer) - self.model = VaryQwenForCausalLM.from_pretrained(config_path) + self.model = build_llm_model(config_path) - self.image_dir = args.image_dir + self.image_path = args.image_path self.query = args.query self.seq_length = self.model.seq_length self.chat_mode = args.chat_mode - def _call_one(self, query=None, image=None, image_high=None): - response = self.model.chat(tokenizer=self.tokenizer, query=query, image=image, image_high=image_high) + def _call_one(self, query=None, image_path=None): + response = self.model.chat(query=query, image_path=image_path) print(">" * 100) print(response) print("<" * 100) return response - def __call__(self, query=None, image_dir=None): + def __call__(self, query=None, image_path=None): self.model.reset() is_first_iteration = True if query is None: query = self.query - if image_dir is None: - image_dir = self.image_dir - image = load_image(image_dir) - image_high = image_processor_high(image) - image = image_processor(image) + if image_path is None: + image_path = self.image_path while True: try: if is_first_iteration: - self._call_one(query=query, image=image, image_high=image_high) + self._call_one(query=query, image_path=image_path) if not self.chat_mode: break is_first_iteration = False @@ -93,7 +81,7 @@ def __call__(self, query=None, image_dir=None): query = input() if query == "exit": break - self._call_one(query=query, image=None, image_high=None) + self._call_one(query=query, image_path=image_path) except ValueError as e: if "check your inputs and set max_length larger than your inputs length." in e.args[0]: logging.warning("The input is too long. The conversation is closed.")