diff --git a/configs/llm/monkey/monkey.yaml b/configs/llm/monkey/monkey.yaml
new file mode 100644
index 000000000..385aa5305
--- /dev/null
+++ b/configs/llm/monkey/monkey.yaml
@@ -0,0 +1,57 @@
+model:
+ name: MonkeyQwenForCausalLM
+ batch_size: 1
+ seq_length: 2048
+ hidden_size: 4096
+ num_layers: 32
+ num_heads: 32
+ vocab_size: 151936
+ intermediate_size: 11008
+ rms_norm_eps: 1.0e-6
+ emb_dropout_prob: 0.0
+ eos_token_id: 151643
+ pad_token_id: 151643
+ compute_dtype: "float16"
+ layernorm_compute_type: "float32"
+ softmax_compute_type: "float16"
+ rotary_dtype: "float16"
+ param_init_type: "float16"
+ ln_param_init_type: "float32"
+ use_past: True
+ use_flash_attention: False
+ use_past_shard: False
+ offset: 0
+ checkpoint_name_or_path: "/path/to/monkey.ckpt"
+ repetition_penalty: 1.5
+ max_decode_length: 2048
+ top_k: 0
+ top_p: 0.8
+ do_sample: False
+ max_new_tokens: 250
+ temperature: 0.7
+ num_beams: 1
+ length_penalty: 1
+ num_patches: 1280
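+ # 1280 = 5 x 256 visual tokens: four 448x448 windows plus one global view, each resampled to 256 tokens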
+
+ # configuration items copied from Qwen
+ rotary_pct: 1.0
+ rotary_emb_base: 10000
+ kv_channels: 128
+
+ visual:
+ heads: 16
+ image_size: 896
+ image_start_id: 151857
+ layers: 48
+ mlp_ratio: 4.9231
+ output_dim: 4096
+ patch_size: 14
+ width: 1664
+ lora_repeat_num: 4
+ positional_embedding_size: 1024
+ model_type: open_clip
+
+processor:
+ tokenizer:
+ vocab_file: "/path/to/qwen.tiktoken"
+ pad_token: "<|endoftext|>"
diff --git a/configs/llm/vary/vary_toy.yaml b/configs/llm/vary/vary_toy.yaml
index 1b6b59b9b..999f35930 100644
--- a/configs/llm/vary/vary_toy.yaml
+++ b/configs/llm/vary/vary_toy.yaml
@@ -30,6 +30,7 @@ model:
max_new_tokens: 1024
temperature: 1.0
num_beams: 1
+ num_patches: 256
# configuration items copied from Qwen
rotary_pct: 1.0
@@ -37,8 +38,6 @@ model:
kv_channels: 128
processor:
- return_tensors: ms
tokenizer:
vocab_file: "/path/to/qwen.tiktoken"
pad_token: "<|endoftext|>"
- type: QwenProcessor
diff --git a/mindocr/data/transforms/llm_transform.py b/mindocr/data/transforms/llm_transform.py
index 92bfaa35d..d467824f6 100644
--- a/mindocr/data/transforms/llm_transform.py
+++ b/mindocr/data/transforms/llm_transform.py
@@ -28,6 +28,11 @@ def f(im):
return f
+def load_image(image_file):
+ image = Image.open(image_file).convert("RGB")
+ return image
+
+
image_processor_high = alb_wrapper(
alb.Compose(
[
@@ -270,7 +275,7 @@ def __init__(self, image_resolution=224):
self.batch_totensor = BatchToTensor()
self.batch_normalizer = BatchNormalize()
- def preprocess(self, images):
+ def __call__(self, images):
if not self._bhwc_check(images):
images = self.bchw2bhwc(images)
images = self.batch_pilizer(images)
@@ -299,4 +304,64 @@ def _bhwc_check(image_batch):
return False
-image_processor = VaryCLIPImageProcessor().preprocess
+class VarySAMImageProcessor:
+ def __init__(self):
+ self.image_processor_high = alb_wrapper(
+ alb.Compose(
+ [
+ alb.Resize(1024, 1024),
+ alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+ ]
+ )
+ )
+
+ def __call__(self, images):
+ images = self.image_processor_high(images)
+ return images
+
+
+class VaryImageProcessor:
+ def __init__(self):
+ self.sam_processor = VarySAMImageProcessor()
+ self.clip_processor = VaryCLIPImageProcessor()
+
+ def __call__(self, images):
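+ # Accepts an image path or a PIL image; returns (image_clip, image_sam) inputs for the CLIP and SAM vision towers.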
+ if isinstance(images, str):
+ images = load_image(images)
+ image_clip = self.clip_processor(images)
+ image_sam = self.sam_processor(images)
+ return image_clip, image_sam
+
+
+class MonkeyImageProcessor:
+ def __init__(self):
+ self.resize1 = vision.c_transforms.Resize((896, 896), Inter.PILCUBIC)
+ self.resize2 = vision.Resize((448, 448), Inter.BICUBIC)
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ std = (0.26862954, 0.26130258, 0.27577711)
+ self.normalize = vision.Normalize(mean=mean, std=std, is_hwc=False)
+
+ @staticmethod
+ def sliding_window(images, window_size=(448, 448), stride=448):
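+ # Split the NCHW image into a 2x2 grid of non-overlapping 448x448 windows.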
+ windows = []
+ for i in range(2):
+ for j in range(2):
+ window = images[
+ :, :, i * stride : i * stride + window_size[0], j * stride : j * stride + window_size[1]
+ ]
+ windows.append(window)
+ return windows
+
+ def __call__(self, images):
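+ # Resize to 896x896, scale to [0, 1] and normalize, then return four 448x448 local
+ # windows plus one 448x448 global view, both as NCHW arrays.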
+ if isinstance(images, str):
+ images = load_image(images)
+ images = self.resize1(images) # hwc -> hwc
+ images = images / 255.0 # hwc -> hwc
+ images_trans = images.transpose(2, 0, 1) # hwc -> chw
+ images_norm = self.normalize(images_trans) # chw -> chw
+ images_nchw = np.expand_dims(images_norm, 0) # chw -> nchw
+ windows = self.sliding_window(images_nchw, window_size=(448, 448), stride=448) # nchw -> List[nchw]
+ images_norm = images_norm.transpose(1, 2, 0)
+ images_448 = self.resize2(images_norm) # hwc -> hwc
+ images_448 = np.expand_dims(images_448.transpose(2, 0, 1), 0) # hwc -> nchw
+ return windows, images_448
diff --git a/mindocr/nlp/llm/__init__.py b/mindocr/nlp/llm/__init__.py
index ce614c47a..f0a006d5c 100644
--- a/mindocr/nlp/llm/__init__.py
+++ b/mindocr/nlp/llm/__init__.py
@@ -1,3 +1,4 @@
from ._registry import register_llm
from .builder import build_llm_model
+from .monkey_qwen_model import MonkeyQwenForCausalLM
from .vary_qwen_model import VaryQwenForCausalLM
diff --git a/mindocr/nlp/llm/base_llm_model.py b/mindocr/nlp/llm/base_llm_model.py
index d18279cdc..6a3ac658e 100644
--- a/mindocr/nlp/llm/base_llm_model.py
+++ b/mindocr/nlp/llm/base_llm_model.py
@@ -7,6 +7,7 @@
from mindocr.nlp.generation import GeneratorMixin
from mindocr.nlp.llm.builder import build_llm_model
from mindocr.nlp.llm.configs import BaseConfig, LLMConfig
+from mindocr.utils.conversation import Conversation
class BaseLLMModel(nn.Cell, GeneratorMixin):
@@ -24,6 +25,13 @@ class BaseLLMModel(nn.Cell, GeneratorMixin):
def __init__(self, config: BaseConfig, **kwargs):
super(BaseLLMModel, self).__init__(**kwargs)
self.config = config
+ self.conversation = Conversation()
+ self.image_path = None
+ self.IMAGE_START_TAG = "<img>"
+ self.IMAGE_END_TAG = "</img>"
+ self.IMAGE_PAD_TAG = "<imgpad>"
+ self.num_patches = config.num_patches
+ self.image_prefix = f"{self.IMAGE_START_TAG}{self.IMAGE_PAD_TAG * self.num_patches}{self.IMAGE_END_TAG}"
def load_checkpoint(self, config):
"""
@@ -42,11 +50,7 @@ def load_checkpoint(self, config):
if os.path.exists(checkpoint_name_or_path):
param = load_checkpoint(checkpoint_name_or_path)
else:
- raise ValueError(
- f"{checkpoint_name_or_path} is not a supported default model"
- f" or a valid path to checkpoint,"
- f" please select from {self._support_list}."
- )
+ raise ValueError(f"{checkpoint_name_or_path} is not a valid path to checkpoint.")
load_param_into_net(self, param)
@@ -81,3 +85,48 @@ def from_pretrained(cls, pretrained_model_name_or_dir: str, **kwargs):
config_args = cls._get_config_args(pretrained_model_name_or_dir, **kwargs)
model = build_llm_model(config_args.model)
return model
+
+ def reset(self):
+ self.conversation.messages = list()
+
+ def add_image_token_pad_in_query(self, query: str):
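+ # Prepend the image placeholder block (start tag, num_patches pad tags, end tag) so visual features can later be spliced into these positions.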
+ query = self.image_prefix + query
+ return query
+
+ def chat(
+ self,
+ query: str,
+ image_path: str = None,
+ ) -> str:
+ """
+ If `image_path` is provided, the conversation will be reset.
+ example:
+ inputs:
+ query: Provide the ocr results of this image.
+ image_path: xx/xxx.png.
+ outputs:
+ response: the modalities of irradiation could be modified...
+ history: [
+ ("user", "Provide the ocr results of this image."),
+ ("assistant", "the modalities of irradiation could be modified..."),
+ ]
+ """
+ if image_path is not None:
+ query = self.add_image_token_pad_in_query(query=query)
+ self.image_path = image_path
+ self.reset()
+
+ self.conversation.add_message(role="user", message=query)
+ prompt = self.conversation.get_prompt()
+
+ inputs = self.tokenizer([prompt], max_length=self.seq_length)
+ input_ids = inputs["input_ids"]
+ outputs = self.generate(input_ids=input_ids, image_path=self.image_path)
+ outputs = self.tokenizer.decode(outputs, skip_special_tokens=False)
+ response = outputs[0][len(prompt) :]
+
+ for special_token in self.tokenizer.special_tokens:
+ response = response.replace(special_token, "")
+ self.conversation.add_message(role="assistant", message=response)
+
+ return response
diff --git a/mindocr/nlp/llm/builder.py b/mindocr/nlp/llm/builder.py
index 490311b09..033a8ef78 100644
--- a/mindocr/nlp/llm/builder.py
+++ b/mindocr/nlp/llm/builder.py
@@ -1,4 +1,5 @@
from ._registry import is_llm, is_llm_class, list_llms, llm_class_entrypoint, llm_entrypoint
+from .configs import LLMConfig
__all__ = ["build_llm_model"]
@@ -11,8 +12,16 @@ def build_llm_model(config):
>>> llm_model = build_llm_model(dict(name='VaryQwenForCausalLM'))
>>> print(llm_model)
"""
+ if isinstance(config, dict):
+ config = LLMConfig(**config)
+ elif isinstance(config, str):
+ config = LLMConfig(config)
+ else:
+ raise TypeError(f"config must be str or dict, but got {type(config)}")
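+ # A full YAML config nests settings under "model" and "processor"; flatten them into a single LLMConfig.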
+ if "model" in config:
+ config = LLMConfig(**config["model"], **config["processor"])
if "name" not in config:
- raise ValueError("name must in `config`.")
+ raise ValueError("`name` must be in `config`.")
name = config["name"]
if is_llm(name):
create_fn = llm_entrypoint(name)
diff --git a/mindocr/nlp/llm/configs.py b/mindocr/nlp/llm/configs.py
index 76dba33cd..33daf7f4c 100644
--- a/mindocr/nlp/llm/configs.py
+++ b/mindocr/nlp/llm/configs.py
@@ -222,9 +222,7 @@ def __init__(
self.num_heads = num_heads
self.max_position_embedding = max_position_embedding if max_position_embedding else seq_length
self.intermediate_size = intermediate_size
- self.multiple_of = multiple_of
self.n_kv_heads = n_kv_heads
- self.ffn_dim_multiplier = ffn_dim_multiplier
self.rms_norm_eps = rms_norm_eps
self.qkv_concat = qkv_concat
self.param_init_type = convert_mstype(param_init_type)
@@ -266,6 +264,11 @@ def __init__(self, **kwargs):
super(VaryConfig, self).__init__(**kwargs)
+class MonkeyConfig(QwenConfig):
+ def __init__(self, **kwargs):
+ super(MonkeyConfig, self).__init__(**kwargs)
+
+
class SAMConfig(BaseConfig):
def __init__(
self,
diff --git a/mindocr/nlp/llm/convert_weight.py b/mindocr/nlp/llm/convert_weight.py
index eabe1d223..5b1c11f13 100644
--- a/mindocr/nlp/llm/convert_weight.py
+++ b/mindocr/nlp/llm/convert_weight.py
@@ -51,6 +51,12 @@ def _name_replace(name: str):
name = name.replace("layer_norm2.bias", "ln_2.beta")
name = name.replace("self_attn", "attn")
name = name.replace("post_layernorm", "vision_model.post_layernorm")
+ if "visual" in name:
+ name = name.replace("attention_norm", "ln_1")
+ name = name.replace("ffn_norm", "ln_2")
+ if "ln_" in name:
+ name = name.replace("weight", "gamma").replace("bias", "beta")
+ name = name.replace("feed_forward.w2", "mlp.c_proj")
# sam
name = name.replace("norm1.weight", "norm1.gamma")
@@ -82,6 +88,9 @@ def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16):
state_dict = torch.load(torch_ckpt_path, map_location="cpu")
ckpt_weights = []
for k, v in state_dict.items():
+ if "lora_scale" in k:
+ continue
+
value = pt2ms(v, dtype)
msname = _name_replace(k)
@@ -111,4 +120,4 @@ def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16):
args = parser.parse_args()
- convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float16)
+ convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float32)
diff --git a/mindocr/nlp/llm/monkey_qwen_model.py b/mindocr/nlp/llm/monkey_qwen_model.py
new file mode 100644
index 000000000..46cf936be
--- /dev/null
+++ b/mindocr/nlp/llm/monkey_qwen_model.py
@@ -0,0 +1,227 @@
+import mindspore as ms
+from mindspore import ops
+
+from mindocr.data.transforms.llm_transform import MonkeyImageProcessor
+from mindocr.nlp.llm import register_llm
+from mindocr.nlp.llm.configs import MonkeyConfig
+from mindocr.nlp.llm.qwen_model import QwenForCausalLM, QwenModel
+from mindocr.nlp.llm.vary_clip_model import VisionTransformer
+from mindocr.nlp.utils.layers import Linear
+from mindocr.utils.conversation import Conversation
+
+
+class MonkeyModel(QwenModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.image_start_token_pos = 0
+ self.num_patches = 1280
+
+ self.visual = VisionTransformer(
+ input_resolution=config.visual.get("image_size", 896), # image_size in transformers
+ patch_size=config.visual.get("patch_size", 14), # patch_size in transformers
+ width=config.visual.get("width", 1664), # hidden_size
+ layers=config.visual.get("layers", 48), # num_hidden_layers
+ heads=config.visual.get("heads", 16), # num_attention_heads
+ output_dim=config.visual.get("output_dim", 4096), # projection_dim in transformers
+ vision_select_layer=-2,
+ param_init_type=config.param_init_type,
+ ln_param_init_type=config.ln_param_init_type,
+ positional_embedding_size=config.visual.get("positional_embedding_size", 1024),
+ mlp_ratio=config.visual.get("mlp_ratio", 4.9231),
+ model_type=config.visual.get("model_type", "open_clip"),
+ compute_dtype=config.compute_dtype,
+ layernorm_compute_type=config.layernorm_compute_type,
+ )
+
+ def construct(
+ self,
+ input_ids,
+ init_reset=True,
+ batch_valid_length=None,
+ batch_index=None,
+ zactivate_len=None,
+ windows=None,
+ image=None,
+ ):
+ """construct"""
+ if input_ids is not None:
+ input_shape = input_ids.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ # 1. wte
+ bs, seq_len = self.shape(input_ids)
+ hidden_states = self.wte(input_ids)
+
+ # 2. drop
+ hidden_states = self.drop(hidden_states)
+
+ # image embedding
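+ # Each 448x448 window is encoded by the shared ViT with its own LoRA branch (selected by lora_idx),
+ # while the resized global image uses the base weights; local and global features are then concatenated.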
+ if seq_len > 1 and image is not None and windows is not None:
+ patch_list = []
+ lora_idx = 0
+ for image_patch in windows:
+ patch = self.visual(image_patch, idx=lora_idx)
+ patch_list.append(patch)
+ lora_idx += 1
+ global_feat = self.visual(image)
+
+ local_feat = ops.cat(patch_list, axis=1)
+ image_features = ops.cat([local_feat, global_feat], axis=1)
+
+ if seq_len > 1 and image_features is not None:
+ new_input_embeds = []
+ num_patches = self.num_patches
+ image_start_token_pos = self.image_start_token_pos
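+ # Splice the visual features into the text embedding, replacing the num_patches
+ # image-pad placeholder tokens that follow the image-start token.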
+ for i in range(bs):
+ cur_input_embeds = hidden_states[i]
+ per_cur_image_features = image_features[i]
+ cur_input_embeds = ops.cat(
+ (
+ cur_input_embeds[: image_start_token_pos + 1],
+ per_cur_image_features,
+ cur_input_embeds[image_start_token_pos + num_patches + 1 :],
+ ),
+ axis=0,
+ )
+
+ new_input_embeds.append(cur_input_embeds)
+
+ hidden_states = ops.stack(new_input_embeds, axis=0)
+
+ # 3. rotary_emb
+ if not self.use_past:
+ freqs_cis = self.freqs_mgr()
+ mask = self.casual_mask(input_ids) # mask: [bs, seq, seq]
+ mask = self.casual_mask.post_process(mask)
+ kvcache_inputs = None
+ else:
+ if self.is_first_iteration:
+ freqs_cis = self.freqs_mgr(seq_len)
+ mask = self.casual_mask(input_ids) # mask: [bs, seq, seq]
+ else:
+ freqs_cis = self.freqs_mgr.increment(batch_valid_length, bs)
+ if self.is_dynamic and self.is_flexible_shape and not self.use_kvcache_op:
+ mask = self.casual_mask.increment_slice(
+ self.kvcache_preprocess.range,
+ self.kvcache_preprocess.max_cache_length // bs,
+ batch_valid_length,
+ zactivate_len,
+ )
+ else:
+ mask = self.casual_mask.increment(self.kvcache_preprocess.range, batch_valid_length, zactivate_len)
+ mask = self.casual_mask.post_process(mask)
+
+ kvcache_inputs = self.kvcache_preprocess(bs, batch_valid_length, batch_index, zactivate_len)
+
+ # 4. hidden_states
+ for i in range(self.num_hidden_layers):
+ hidden_states = self.layers[i](hidden_states, freqs_cis, mask, kvcache_inputs=kvcache_inputs)
+
+ # 5. ln_f
+ hidden_states = self.ln_f(hidden_states)
+ return hidden_states
+
+
+@register_llm
+class MonkeyQwenForCausalLM(QwenForCausalLM):
+ def __init__(self, config):
+ config = MonkeyConfig(**config)
+ super().__init__(config)
+ self.transformer = MonkeyModel(config)
+ self.lm_head = Linear(
+ config.hidden_size,
+ config.vocab_size,
+ has_bias=False,
+ param_init_type=config.param_init_type,
+ compute_dtype=config.compute_dtype,
+ )
+ self.conversation = Conversation(generate_mode=True)
+ self.image_processor = MonkeyImageProcessor()
+
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
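+ # Preprocess the image at generation time: four local 448x448 windows plus a global 448x448 view, cast to float16.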
+ image_path = kwargs.get("image_path")
+ if image_path is None:
+ windows, image = None, None
+ else:
+ windows, image = self.image_processor(image_path)
+ windows = ms.Tensor(windows, ms.float16)
+ image = ms.Tensor(image, ms.float16)
+ return {
+ "input_ids": ms.Tensor(input_ids, ms.int32),
+ "windows": windows,
+ "image": image,
+ }
+
+ def construct(
+ self,
+ input_ids,
+ labels=None,
+ input_position=None,
+ position_ids=None,
+ attention_mask=None,
+ input_embeds=None,
+ init_reset=True,
+ batch_valid_length=None,
+ batch_index=None,
+ zactivate_len=None,
+ windows=None,
+ image=None,
+ ):
+ """construct"""
+ bsz, seqlen = input_ids.shape
+ if self.use_past:
+ if not isinstance(batch_valid_length, ms.Tensor):
+ batch_valid_length = self.ones((bsz,), ms.int32)
+ if self.training:
+ tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
+ else:
+ tokens = input_ids
+
+ if batch_valid_length is not None:
+ batch_valid_length = self.reshape(batch_valid_length, (-1,))
+ if not self.is_first_iteration:
+ batch_valid_length = self.sub_batch_valid_len(batch_valid_length, 1)
+
+ output = self.transformer(
+ tokens,
+ init_reset=init_reset,
+ batch_valid_length=batch_valid_length,
+ batch_index=batch_index,
+ zactivate_len=zactivate_len,
+ windows=windows,
+ image=image,
+ )
+ pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
+ if pre_gather:
+ output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
+ logits = self.lm_head(output)
+
+ input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), ms.float32)
+ if labels is None:
+ labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
+ else:
+ if labels.ndim > 1:
+ if self.training:
+ labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
+ label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), ms.float32)
+ input_mask = self.mul(input_mask, label_mask)
+
+ if not self.training:
+ if not pre_gather:
+ logits = self.reshape(logits, (bsz, seqlen, -1))
+ logits = self.cast(logits, ms.float32)
+ # makes cast effective to avoid allgather issue in MindSpore 1.10
+ input_mask = self.add(input_mask, 1)
+ return logits, tokens, input_mask
+
+ if logits.ndim > 2:
+ logits = self.reshape(logits, (-1, logits.shape[-1]))
+ logits = self.cast(logits, ms.float32)
+ labels = self.reshape(labels, (-1,))
+ input_mask = self.reshape(input_mask, (-1,))
+ loss = self.loss(logits, labels, input_mask)
+ return loss
+
+ def add_image_token_pad_in_query(self, query):
+ query = self.image_prefix + " " + query + " "
+ return query
diff --git a/mindocr/nlp/llm/qwen_model.py b/mindocr/nlp/llm/qwen_model.py
index 1e1bee7b8..0cb98c4a0 100644
--- a/mindocr/nlp/llm/qwen_model.py
+++ b/mindocr/nlp/llm/qwen_model.py
@@ -6,18 +6,20 @@
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, nn, ops
-from mindspore._c_expression import MSContext
from mindspore.common.initializer import initializer
-from mindspore.nn.layer.flash_attention import FlashAttention
from mindocr.nlp.llm.base_llm_model import BaseLLMModel
from mindocr.nlp.llm.configs import QwenConfig
+from mindocr.nlp.llm.qwen_tokenizer import QwenTokenizer
+from mindocr.nlp.utils.flash_attention import FlashAttention
from mindocr.nlp.utils.kvcache_mgr import KVCacheMgr, KVCachePreprocess
from mindocr.nlp.utils.layers import Linear
from mindocr.nlp.utils.loss import CrossEntropyLoss
def is_910a():
+ from mindspore._c_expression import MSContext
+
device = MSContext.get_instance().get_ascend_soc_version()
return device in ["910a", "ascend910"]
@@ -143,99 +145,6 @@ def construct(self, x):
return ops.silu(x)
-class LlamaFeedForward(nn.Cell):
- r"""
- LLaMA FeedForward.
-
- .. math::
- (xW_1 * xW_3)W_2
-
- Inputs:
- - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
- Float tensor.
-
- Outputs:
- Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
- [batch * seq_length, hidden_size]`.
-
- Raises:
- ValueError: `hidden_dim` is not a multiple of the model parallel way.
- ValueError: `dim` is not a multiple of the model parallel way.
- """
-
- def __init__(
- self,
- dim,
- intermediate_size=None,
- hidden_dim=None,
- multiple_of=256,
- hidden_act=LlamaSiLU,
- ffn_dim_multiplier=None,
- compute_dtype=mstype.float16,
- param_init_type=mstype.float32,
- is_dynamic=False,
- ):
- super().__init__()
-
- if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)):
- raise TypeError(
- f"For FeedForward cell, the hidden_act should str type or nn.Cell type, but got {hidden_act}."
- )
-
- if intermediate_size is not None:
- hidden_dim = intermediate_size
- else:
- if ffn_dim_multiplier is not None:
- hidden_dim = int((ffn_dim_multiplier + 0.01) * hidden_dim)
- hidden_dim = int(2 * hidden_dim / 3)
- hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
- self.dtype = compute_dtype
- self.hidden_act = hidden_act
- self.dim = dim
- self.hidden_dim = hidden_dim
-
- self.mul = ops.Mul()
- self.cast = ops.Cast()
- self.w1 = Linear(
- in_channels=dim,
- out_channels=hidden_dim,
- activation=hidden_act,
- has_bias=False,
- compute_dtype=compute_dtype,
- param_init_type=param_init_type,
- skip_redistribution=is_dynamic,
- )
-
- self.w2 = Linear(
- in_channels=hidden_dim,
- out_channels=dim,
- has_bias=False,
- compute_dtype=compute_dtype,
- param_init_type=param_init_type,
- skip_redistribution=is_dynamic,
- )
-
- self.w3 = Linear(
- in_channels=dim,
- out_channels=hidden_dim,
- has_bias=False,
- compute_dtype=compute_dtype,
- param_init_type=param_init_type,
- skip_redistribution=is_dynamic,
- )
-
- def construct(self, x):
- """Forward process of the FeedForward"""
- x = self.cast(x, self.dtype)
- # [bs, seq, hidden_dim] or [bs * seq, hidden_dim]
- gate = self.w1(x) # dp,1 -> dp, mp
- hidden = self.w3(x) # dp,1 -> dp, mp
- hidden = self.mul(hidden, gate) # dp,mp -> dp, mp
- output = self.w2(hidden) # dp,mp -> dp, 1
- return output
-
-
class LlamaRotaryEmbedding(nn.Cell):
r"""
Rotary Position Embedding.
@@ -467,7 +376,16 @@ def __init__(
)
if self.use_flash_attention:
- self.flash_attention = FlashAttention(self.head_dim, n_heads, next_block_num=0, high_precision=True)
+ self.flash_attention = FlashAttention(
+ head_num=self.n_head,
+ pre_tokens=65536,
+ next_tokens=0,
+ input_layout="BNSD",
+ keep_prob=1.0,
+ scale_value=1.0 / (self.head_dim**0.5),
+ sparse_mode=0,
+ use_attention_mask=True,
+ )
if self.use_past:
self.kvcache_mgr = KVCacheMgr(
@@ -682,6 +600,7 @@ def __init__(self, config=None):
self.mul = ops.Mul()
self.sub_batch_valid_len = ops.Sub()
self.gather = ops.Gather(1)
+ self.tokenizer = QwenTokenizer(**config.tokenizer)
def construct(
self,
@@ -851,7 +770,7 @@ def construct(
# 2. drop
hidden_states = self.drop(hidden_states)
- # 2. rotary_emb
+ # 3. rotary_emb
bs, seq_len = self.shape(input_ids)
if not self.use_past:
freqs_cis = self.freqs_mgr()
@@ -897,8 +816,6 @@ def __init__(
n_heads: int = 8,
n_kv_heads: Optional[int] = None,
intermediate_size: Optional[int] = None,
- multiple_of: int = 256,
- ffn_dim_multiplier: Optional[int] = None,
norm_eps: float = 1e-5,
qkv_concat=False,
compute_dtype=mstype.float16,
@@ -964,16 +881,6 @@ def __init__(
use_rope_slice=use_rope_slice,
use_flash_attention=use_flash_attention,
)
- self.feed_forward = LlamaFeedForward(
- dim=self.hidden_size,
- intermediate_size=intermediate_size,
- hidden_dim=4 * self.hidden_size,
- multiple_of=multiple_of,
- ffn_dim_multiplier=ffn_dim_multiplier,
- compute_dtype=compute_dtype,
- param_init_type=param_init_type,
- is_dynamic=is_dynamic,
- )
self.feed_forward = QwenFeedForward(
dim=self.hidden_size,
intermediate_size=intermediate_size,
diff --git a/mindocr/nlp/llm/qwen_tokenizer.py b/mindocr/nlp/llm/qwen_tokenizer.py
index 8ad4224d2..abe66e65d 100644
--- a/mindocr/nlp/llm/qwen_tokenizer.py
+++ b/mindocr/nlp/llm/qwen_tokenizer.py
@@ -12,6 +12,16 @@
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
+REF_START_TAG = "["
+REF_END_TAG = "]"
+BOX_START_TAG = "<box>"
+BOX_END_TAG = "</box>"
+QUAD_START_TAG = "<quad>"
+QUAD_END_TAG = "</quad>"
+IMAGE_START_TAG = "<img>"
+IMAGE_END_TAG = "</img>"
+IMAGE_PAD_TAG = "<imgpad>"
+
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
(
@@ -20,7 +30,17 @@
IMEND,
)
+ EXTRAS
- + ("[", "]", "<box>", "</box>", "<quad>", "</quad>", "<img>", "</img>", "<imgpad>")
+ + (
+ REF_START_TAG,
+ REF_END_TAG,
+ BOX_START_TAG,
+ BOX_END_TAG,
+ QUAD_START_TAG,
+ QUAD_END_TAG,
+ IMAGE_START_TAG,
+ IMAGE_END_TAG,
+ IMAGE_PAD_TAG,
+ )
)
diff --git a/mindocr/nlp/llm/vary_clip_model.py b/mindocr/nlp/llm/vary_clip_model.py
index 5c3d22f88..5adeb1c3f 100644
--- a/mindocr/nlp/llm/vary_clip_model.py
+++ b/mindocr/nlp/llm/vary_clip_model.py
@@ -2,15 +2,157 @@
import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
+from mindspore.common.initializer import initializer
from mindocr.nlp.utils.layers import LayerNorm, Linear
+class LoraAdapter(nn.Cell):
+ def __init__(
+ self,
+ d_model,
+ out_feat,
+ r=16,
+ param_init_type=ms.float32,
+ compute_dtype=ms.float32,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.out_feat = out_feat
+
+ self.lora_a = Linear(
+ self.d_model,
+ r,
+ has_bias=False,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ )
+ self.lora_b = Linear(
+ r,
+ self.out_feat,
+ has_bias=False,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ )
+
+ def construct(self, x):
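+ # Low-rank update: project down to rank r, then back up to out_feat (no bias).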
+ down = self.lora_a(x)
+ up = self.lora_b(down)
+ output = up
+ return output
+
+
class QuickGELU(nn.Cell):
def construct(self, x: Tensor):
return x * ops.sigmoid(1.702 * x)
+class VisualAttention(nn.Cell):
+ """self-attention layer class.
+
+ Self-attention layer takes input with size [s, b, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, embed_dim, num_heads, lora_repeat_num=4, param_init_type=ms.float32, compute_dtype=ms.float32):
+ super(VisualAttention, self).__init__()
+ self.embed_dim = embed_dim
+
+ self.num_heads = num_heads
+
+ # Per attention head and per partition values.
+ self.hidden_size_per_attention_head = embed_dim // num_heads
+ self.num_attention_heads_per_partition = num_heads
+ self.hidden_size_per_partition = embed_dim
+
+ # Strided linear layer.
+ self.in_proj = Linear(embed_dim, 3 * embed_dim, param_init_type=param_init_type, compute_dtype=compute_dtype)
+ self.in_proj_lora = []
+ for _ in range(lora_repeat_num):
+ self.in_proj_lora.append(
+ LoraAdapter(
+ d_model=embed_dim,
+ out_feat=3 * embed_dim,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ )
+ )
+ self.in_proj_lora = nn.CellList(self.in_proj_lora)
+
+ self.out_proj = Linear(embed_dim, embed_dim, param_init_type=param_init_type, compute_dtype=compute_dtype)
+ self.out_proj_lora = []
+ for _ in range(lora_repeat_num):
+ self.out_proj_lora.append(
+ LoraAdapter(
+ d_model=embed_dim,
+ out_feat=embed_dim,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ )
+ )
+ self.out_proj_lora = nn.CellList(self.out_proj_lora)
+ self.norm_factor = self.hidden_size_per_attention_head**0.5
+
+ def construct(self, query, idx=None):
+ # query/key/value: [sq, b, h]
+ sq, b, _ = query.shape
+
+ sk = sq
+ mixed_x_layer = self.in_proj(query)
+ if idx is not None:
+ lora_res = self.in_proj_lora[idx](query)
+ mixed_x_layer += lora_res
+
+ # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+ new_tensor_shape = mixed_x_layer.shape[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, axis=3)
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(
+ sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ ).transpose(1, 0, 2)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(
+ sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ ).transpose(1, 0, 2)
+
+ q_scaled = query_layer / self.norm_factor
+ attention_probs = ops.BatchMatMul(transpose_b=True)(q_scaled, key_layer)
+ attention_probs = ops.softmax(attention_probs, axis=-1)
+
+ value_layer = value_layer.view(
+ sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ ).transpose(1, 0, 2)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = ops.bmm(attention_probs, value_layer)
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(
+ b, self.num_attention_heads_per_partition, sq, self.hidden_size_per_attention_head
+ )
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3)
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.shape[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ output = self.out_proj(context_layer)
+ if idx is not None:
+ lora_res = self.out_proj_lora[idx](context_layer)
+ output += lora_res
+
+ return output
+
+
class CLIPAttention(nn.Cell):
"""Multi-head attention module for CLIP"""
@@ -85,31 +227,78 @@ def __init__(
attn_mask: Tensor = None,
param_init_type=ms.float32,
ln_param_init_type=ms.float32,
+ compute_dtype=ms.float32,
+ layernorm_compute_type=ms.float32,
+ mlp_ratio=4.0,
+ lora_repeat_num=4,
+ model_type="clip",
):
super().__init__()
- self.attn = CLIPAttention(d_model, n_head, param_init_type=param_init_type)
- self.ln_1 = LayerNorm((d_model,), eps=1e-5, param_init_type=ln_param_init_type)
+ self.use_clip = model_type == "clip"
+ if self.use_clip:
+ self.attn = CLIPAttention(d_model, n_head, param_init_type=param_init_type)
+ else:
+ self.attn = VisualAttention(d_model, n_head, param_init_type=param_init_type, compute_dtype=compute_dtype)
+ self.ln_1 = LayerNorm((d_model,), eps=1e-6, param_init_type=ln_param_init_type)
+ mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.SequentialCell(
OrderedDict(
[
- ("c_fc", Linear(d_model, d_model * 4, param_init_type=param_init_type)),
- ("gelu", QuickGELU()),
- ("c_proj", Linear(d_model * 4, d_model, param_init_type=param_init_type)),
+ (
+ "c_fc",
+ Linear(
+ d_model,
+ mlp_width,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ ),
+ ),
+ ("gelu", QuickGELU() if self.use_clip else nn.GELU(approximate=False)),
+ (
+ "c_proj",
+ Linear(
+ mlp_width,
+ d_model,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ ),
+ ),
]
)
)
- self.ln_2 = LayerNorm((d_model,), eps=1e-5, param_init_type=ln_param_init_type)
+ self.ln_2 = LayerNorm((d_model,), eps=1e-6, param_init_type=ln_param_init_type)
self.attn_mask = Parameter(attn_mask) if attn_mask is not None else None
- def construct(self, x: Tensor):
+ self.mlp_lora = []
+ if not self.use_clip:
+ for _ in range(lora_repeat_num):
+ self.mlp_lora.append(
+ LoraAdapter(
+ d_model=d_model,
+ out_feat=d_model,
+ r=32,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ )
+ )
+ self.mlp_lora = nn.CellList(self.mlp_lora)
+ self.layernorm_compute_type = layernorm_compute_type
+
+ def construct(self, x: Tensor, idx=None):
residual0 = x
x_type = x.dtype
- x = self.ln_1(x.to(ms.float32)).to(x_type)
- x = residual0 + self.attn(x)
+ x = self.ln_1(x.to(self.layernorm_compute_type)).to(x_type)
+ if self.use_clip:
+ x = residual0 + self.attn(x)
+ else:
+ x = residual0 + self.attn(x, idx)
residual1 = x
- x = self.ln_2(x.to(ms.float32)).to(x_type)
+ x = self.ln_2(x.to(self.layernorm_compute_type)).to(x_type)
x = residual1 + self.mlp(x)
+
+ if not self.use_clip and idx is not None:
+ x += self.mlp_lora[idx](residual1)
return x
@@ -124,6 +313,10 @@ def __init__(
attn_mask: Tensor = None,
param_init_type=ms.float32,
ln_param_init_type=ms.float32,
+ compute_dtype=ms.float32,
+ layernorm_compute_type=ms.float32,
+ mlp_ratio=4.0,
+ model_type="clip",
):
super().__init__()
self.width = width
@@ -131,22 +324,105 @@ def __init__(
self.resblocks = nn.CellList(
[
ResidualAttentionBlock(
- width, heads, attn_mask, param_init_type=param_init_type, ln_param_init_type=ln_param_init_type
+ width,
+ heads,
+ attn_mask,
+ param_init_type=param_init_type,
+ ln_param_init_type=ln_param_init_type,
+ mlp_ratio=mlp_ratio,
+ model_type=model_type,
+ compute_dtype=compute_dtype,
+ layernorm_compute_type=layernorm_compute_type,
)
for _ in range(layers)
]
)
- def construct(self, x: Tensor):
+ def construct(self, x: Tensor, idx=None):
encoder_states = ()
hidden_state = x
- for block in self.resblocks:
+ hidden_state_list = list()
+ for i, block in enumerate(self.resblocks):
encoder_states += (hidden_state,)
- hidden_state = block(hidden_state)
+ dtype = hidden_state.dtype
+ if i > 20: # After the 20th layer, the error of FP16 becomes unacceptable.
+ hidden_state = hidden_state.to(ms.float32)
+ hidden_state = block(hidden_state, idx)
+ if i > 20:
+ hidden_state = hidden_state.to(dtype)
+ hidden_state_list.append(hidden_state)
encoder_states += (hidden_state,)
return encoder_states
+class Resampler(nn.Cell):
+ """
+ A 2D perceiver-resampler network with one cross attention layers by
+ (grid_size**2) learnable queries and 2d sincos pos_emb
+ Outputs:
+ A tensor with the shape of (grid_size**2, embed_dim)
+ """
+
+ def __init__(
+ self,
+ grid_size,
+ embed_dim,
+ num_heads,
+ positional_embedding_size=1024,
+ kv_dim=None,
+ param_init_type=ms.float32,
+ ln_param_init_type=ms.float32,
+ compute_dtype=ms.float32,
+ layernorm_compute_type=ms.float32,
+ ):
+ super().__init__()
+ self.num_queries = grid_size**2
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.layernorm_compute_type = layernorm_compute_type
+
+ self.pos_embed = Parameter(initializer("zeros", (positional_embedding_size, embed_dim), param_init_type))
+ self.pos_embed_unsqueeze = Parameter(initializer("zeros", (256, embed_dim), param_init_type))
+
+ self.query = Parameter(initializer("zeros", (self.num_queries, embed_dim), param_init_type))
+
+ if kv_dim is not None and kv_dim != embed_dim:
+ self.kv_proj = Linear(
+ kv_dim,
+ embed_dim,
+ has_bias=False,
+ param_init_type=param_init_type,
+ compute_dtype=compute_dtype,
+ )
+ else:
+ self.kv_proj = None
+
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads, dtype=param_init_type)
+ self.ln_q = LayerNorm(embed_dim, eps=1e-6, param_init_type=ln_param_init_type)
+ self.ln_kv = LayerNorm(embed_dim, eps=1e-6, param_init_type=ln_param_init_type)
+
+ def construct(self, x, attn_mask=None):
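+ # Cross-attend the learnable queries (with positional embeddings) to the projected and normalized visual tokens.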
+ pos_embed = self.pos_embed
+
+ if self.kv_proj is not None:
+ x = self.kv_proj(x)
+ x = self.ln_kv(x.to(self.layernorm_compute_type)).to(x.dtype).permute(1, 0, 2)
+
+ n = x.shape[1]
+ q = self.ln_q(self.query.to(self.layernorm_compute_type)).to(self.query.dtype)
+ out = self.attn(
+ self._repeat(q, n) + self.pos_embed_unsqueeze.unsqueeze(1),
+ x + pos_embed.unsqueeze(1),
+ x,
+ attn_mask=attn_mask,
+ )[0]
+ return out.permute(1, 0, 2)
+
+ @staticmethod
+ def _repeat(query, n: int):
+ return query.unsqueeze(1).tile((1, n, 1))
+
+
class VisionTransformer(nn.Cell):
"""CLIP module for Vary system"""
@@ -161,34 +437,101 @@ def __init__(
vision_select_layer: int,
param_init_type=ms.float32,
ln_param_init_type=ms.float32,
+ compute_dtype=ms.float32,
+ layernorm_compute_type=ms.float32,
+ positional_embedding_size=None,
+ mlp_ratio=4.0,
+ model_type="clip",
):
super().__init__()
+ assert model_type in ("clip", "open_clip")
+ self.use_clip = model_type == "clip"
self.input_resolution = input_resolution
self.output_dim = output_dim
self.vision_select_layer = vision_select_layer
- self.conv1 = nn.Conv2d(
- in_channels=3,
- out_channels=width,
- kernel_size=patch_size,
- stride=patch_size,
- has_bias=False,
- pad_mode="pad",
- weight_init="uniform",
- bias_init="uniform",
- dtype=param_init_type,
- )
+ self.layernorm_compute_type = layernorm_compute_type
scale = width**-0.5
- self.class_embedding = Parameter((scale * ops.randn(width)).astype(param_init_type))
+ if positional_embedding_size is None:
+ positional_embedding_size = (input_resolution // patch_size) ** 2 + 1
self.positional_embedding = Parameter(
- (scale * ops.randn(((input_resolution // patch_size) ** 2 + 1, width))).astype(param_init_type)
+ (scale * ops.randn((positional_embedding_size, width))).astype(param_init_type)
)
self.ln_pre = LayerNorm((width,), eps=1e-5, param_init_type=ln_param_init_type)
self.transformer = Transformer(
- width, layers, heads, param_init_type=param_init_type, ln_param_init_type=ln_param_init_type
+ width,
+ layers,
+ heads,
+ param_init_type=param_init_type,
+ ln_param_init_type=ln_param_init_type,
+ compute_dtype=compute_dtype,
+ layernorm_compute_type=layernorm_compute_type,
+ mlp_ratio=mlp_ratio,
+ model_type=model_type,
)
+ if self.use_clip:
+ self.class_embedding = Parameter((scale * ops.randn(width)).astype(param_init_type))
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ has_bias=False,
+ pad_mode="pad",
+ weight_init="uniform",
+ bias_init="uniform",
+ dtype=param_init_type,
+ )
+ else:
+ self.attn_pool = Resampler(
+ grid_size=16,
+ embed_dim=output_dim,
+ num_heads=output_dim // 128,
+ positional_embedding_size=positional_embedding_size,
+ kv_dim=width,
+ param_init_type=param_init_type,
+ ln_param_init_type=ln_param_init_type,
+ compute_dtype=compute_dtype,
+ layernorm_compute_type=layernorm_compute_type,
+ )
+ self.ln_post = LayerNorm((output_dim,), param_init_type=ln_param_init_type)
+ self.proj = Parameter(initializer("normal", (output_dim, output_dim), param_init_type))
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ has_bias=False,
+ dtype=param_init_type,
+ )
- def construct(self, x: Tensor):
+ def construct(self, x: Tensor, idx=None):
+ if self.use_clip:
+ x = self._clip_construct(x)
+ else:
+ x = self._open_clip_construct(x, idx)
+ return x
+
+ def _open_clip_construct(self, x, idx=None):
+ # to patches
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ x = x + self.positional_embedding
+
+ x = self.ln_pre(x.to(self.layernorm_compute_type)).to(x.dtype)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, idx=idx)[-1]
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.attn_pool(x)
+ x = self.ln_post(x.to(self.layernorm_compute_type)).to(x.dtype)
+ x = ops.matmul(x, self.proj)
+ return x
+
+ def _clip_construct(self, x: Tensor):
x_type = x.dtype
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape((x.shape[0], x.shape[1], -1)) # shape = [*, width, grid**2]
diff --git a/mindocr/nlp/llm/vary_qwen_model.py b/mindocr/nlp/llm/vary_qwen_model.py
index 8001bd036..96f4424d3 100644
--- a/mindocr/nlp/llm/vary_qwen_model.py
+++ b/mindocr/nlp/llm/vary_qwen_model.py
@@ -1,20 +1,20 @@
import mindspore as ms
from mindspore import ops
+from mindocr.data.transforms.llm_transform import VaryImageProcessor
from mindocr.nlp.llm import register_llm
from mindocr.nlp.llm.configs import SAMConfig, VaryConfig
from mindocr.nlp.llm.qwen_model import QwenForCausalLM, QwenModel
from mindocr.nlp.llm.vary_clip_model import build_model
from mindocr.nlp.llm.vary_sam_model import SAMEncoder
from mindocr.nlp.utils.layers import Linear
-from mindocr.utils.conversation import Conversation
class VaryQwenModel(QwenModel):
def __init__(self, config):
super(VaryQwenModel, self).__init__(config)
- config = SAMConfig(ln_param_init_type=config.ln_param_init_type)
- self.vision_tower_high = SAMEncoder(config)
+ sam_config = SAMConfig(ln_param_init_type=config.ln_param_init_type)
+ self.vision_tower_high = SAMEncoder(sam_config)
self.vision_tower_high.to_float(ms.float16)
self.vision_tower = build_model(
@@ -26,7 +26,7 @@ def __init__(self, config):
self.mm_projector_vary = Linear(1024, 1024, param_init_type=config.param_init_type)
self.image_start_token_pos = 22
- self.num_patches = 256
+ self.num_patches = config.num_patches
def construct(
self,
@@ -35,18 +35,18 @@ def construct(
batch_valid_length=None,
batch_index=None,
zactivate_len=None,
- image=None,
- image_high=None,
+ image_clip=None,
+ image_sam=None,
):
# 1. wte
bs, seq_len = self.shape(input_ids)
inputs_embeds = self.wte(input_ids)
- if seq_len > 1 and image is not None and image_high is not None:
- sam_out = self.vision_tower_high(image_high)
+ if seq_len > 1 and image_clip is not None and image_sam is not None:
+ sam_out = self.vision_tower_high(image_sam)
sam_out = self.mm_projector_vary(sam_out)
- clip_out = self.vision_tower(image)
+ clip_out = self.vision_tower(image_clip)
clip_out = self.mm_projector(clip_out)
image_features = ops.concat((clip_out, sam_out), -1)
@@ -116,18 +116,20 @@ def __init__(self, config):
config = VaryConfig(**config)
super(VaryQwenForCausalLM, self).__init__(config)
self.transformer = VaryQwenModel(config=config)
- self.conversation = None
-
- self.image_past = None
- self.image_high_past = None
+ self.image_processor = VaryImageProcessor()
def prepare_inputs_for_generation(self, input_ids, **kwargs):
- image = kwargs.get("image")
- image_high = kwargs.get("image_high")
+ image_path = kwargs.get("image_path")
+ if image_path is None:
+ image_clip, image_sam = None, None
+ else:
+ image_clip, image_sam = self.image_processor(image_path)
+ image_clip = ms.Tensor(image_clip, ms.float16)
+ image_sam = ms.Tensor(image_sam, ms.float16)
return {
"input_ids": ms.Tensor(input_ids, ms.int32),
- "image": ms.Tensor(image, ms.float16) if image is not None else None,
- "image_high": ms.Tensor(image_high, ms.float16) if image_high is not None else None,
+ "image_clip": image_clip,
+ "image_sam": image_sam,
}
def construct(
@@ -142,8 +144,8 @@ def construct(
batch_valid_length=None,
batch_index=None,
zactivate_len=None,
- image=None,
- image_high=None,
+ image_clip=None,
+ image_sam=None,
):
"""construct"""
bsz, seqlen = input_ids.shape
@@ -166,8 +168,8 @@ def construct(
batch_valid_length=batch_valid_length,
batch_index=batch_index,
zactivate_len=zactivate_len,
- image=image,
- image_high=image_high,
+ image_clip=image_clip,
+ image_sam=image_sam,
)
pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
if pre_gather:
@@ -199,55 +201,3 @@ def construct(
input_mask = self.reshape(input_mask, (-1,))
loss = self.loss(logits, labels, input_mask)
return loss
-
- def chat(
- self,
- tokenizer,
- query: str,
- image=None,
- image_high=None,
- ) -> str:
- """
- example:
- inputs:
- query: Provide the ocr results of this image.
- image: np.array.
- image_high: np.array.
- outputs:
- response: the modalities of irradiation could be modified...
- history: [
- ("user", "Provide the ocr results of this image."),
- ("assistant", "the modalities of irradiation could be modified..."),
- ]
-
- """
- if self.conversation is None:
- self.conversation = Conversation()
-
- if image is not None and image_high is not None:
- num_patch = 256
- im_start_token = "<img>"
- im_end_token = "</img>"
- im_patch_token = "<imgpad>"
- query = im_start_token + im_patch_token * num_patch + im_end_token + query
- self.image_past = image
- self.image_high_past = image_high
-
- self.conversation.add_message(role="user", message=query)
- prompt = self.conversation.get_prompt()
-
- inputs = tokenizer([prompt], max_length=self.seq_length)
- input_ids = inputs["input_ids"]
- outputs = self.generate(input_ids=input_ids, image=self.image_past, image_high=self.image_high_past)
- outputs = tokenizer.decode(outputs, skip_special_tokens=False)
- response = outputs[0][len(prompt) :]
-
- for special_token in tokenizer.special_tokens:
- response = response.replace(special_token, "")
- self.conversation.add_message(role="assistant", message=response)
-
- return response
-
- def reset(self):
- if self.conversation is not None:
- self.conversation.messages = list()
diff --git a/mindocr/nlp/utils/flash_attention.py b/mindocr/nlp/utils/flash_attention.py
new file mode 100644
index 000000000..c708fa2ee
--- /dev/null
+++ b/mindocr/nlp/utils/flash_attention.py
@@ -0,0 +1,205 @@
+"""Flash Attention Layer"""
+import mindspore.common.dtype as mstype
+from mindspore import ops
+from mindspore.common.tensor import Tensor
+from mindspore.nn.cell import Cell
+from mindspore.ops import functional as F
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
+
+
+class FlashAttention(Cell):
+ """Flash Attention Layer.
+
+ This function contains the flash attention primitives used in FlashAttention (see paper)
+ `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`
+
+ Specifically, it includes the following:
+
+ 1. An interface for calling flashattention operation.
+ 2. Two configuration parameters for enabling local block sparse of flashattention.
+
+ B -- Batch size
+ S1 -- Sequence length of query. The value ranges from 1 to 32768 and is a multiple of 16.
+ S2 -- Sequence length of key and value. The value ranges from 1 to 32768 and is a multiple of 16.
+ N1 -- Num heads of query
+ N2 -- Num heads of key and value, and N2 must be a factor of N1
+ D -- Head size. Support value: 64, 80, 96, 120, 128 and 256.
+ H1 -- Hidden size of query, which equals to N1 * D
+ H2 -- Hidden size of key and value, which equals to N2 * D
+ Args:
+ head_num (int): The head num of query.
+ keep_prob (float): The keep probability of dropout. Default: 1.0.
+ scale_value (float): The scale factor of score. Default: 1.0.
+ pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
+ When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
+ next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
+ When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
+ input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH" or "BNSD".
+ Default: "BSH".
+ sparse_mode (int): Indicates sparse mode. Default 0.
+ - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
+ and preTokens and nextTokens(internally assigned as INT_MAX) are ignored. If passed in, the full attn_mask
+ matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and nextTokens needs to
+ be calculated.
+ - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
+ - 2: Representing the leftUpCausal mode corresponds to the lower triangle scenario divided by the left
+ vertex, and the optimized attn_mask matrix (2048*2048) is required.
+ - 3: Representing the rightDownCausal model corresponds to the lower triangle scene divided by the lower
+ right vertex, and the optimized attn_mask matrix (2048*2048) is required.
+ - 4: Represents the band scenario, that is, only the part between preTokens and nextTokens is computed, and the
+ optimized attn_mask matrix (2048*2048) is required.
+ - 5: Represents the prefix scenario, that is, on the basis of rightDownCasual, a matrix with length S1 and
+ width N is added to the left side. The value of N is obtained by the new input prefix, and the N value of
+ each Batch axis is different. Not implemented yet.
+ - 6: Represents the global scenario, not implemented yet.
+ - 7: Represents the dilated scenario, not implemented yet.
+ - 8: Represents the block_local scenario, not implemented yet.
+ use_attention_mask (bool): The value is True if attention_mask is passed. Default: False.
+ use_alibi_mask (bool): The value is True if alibi_mask is passed. Default: False.
+ use_mqa (bool): Specifies whether using MQA. Default: False.
+ dp (int): Data parallel num.
+ mp (int): Model parallel num.
+ sp (int): Sequence parallel num.
+
+ Inputs:
+ - **query** (Tensor[float16, bfloat16]) - The query tensor.
+ Input tensor of shape :math:`(B, S1, H1)` or `(B, N1, S1, D)`.
+ - **key** (Tensor[float16, bfloat16]) - The key tensor.
+ Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+ - **value** (Tensor[float16, bfloat16]) - The value tensor.
+ Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+ - **attn_mask** (Union[Tensor[uint8], None]) - The attention mask tensor. For each element, 0 indicates
+ retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`, `(B, 1, S1, S2)`, `(S1, S2)`
+ or (2048, 2048).
+ - **alibi_mask** (Union[Tensor[float16, bfloat16], None]) - The position embedding code. If S is greater than
+ 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of the lower triangle for
+ memory optimization.
+ Input tensor of shape :math: `(B, N1, S1, S2)`, `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`
+ or (1024, 1024).
+ - **padding_mask** (None) - Reserved parameter. Not implemented yet.
+ - **prefix** (Union[Tensor[int64], None]) - N value of each Batch in the prefix sparse calculation scenario.
+ Not implemented yet. Input tensor of shape :math:`(B,)`.
+
+ Outputs:
+ - **attention_out** (Tensor[float16, bfloat16]) - The output of attention, its shape, and data type
+ are the same as the query.
+
+ Supported Platforms:
+ ``Atlas 800T A2``
+
+ Examples:
+ >>> import numpy as np
+ >>> import math
+ >>> from mindspore import dtype as mstype
+ >>> from mindspore import Tensor
+ >>> bsz, head_num, seq_len, head_size = 1, 16, 4096, 128
+ >>> hidden_size = head_num * head_size
+ >>> query = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
+ >>> key = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
+ >>> value = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
+ >>> attn_mask = Tensor(np.ones((bsz, 1, seq_len, seq_len)), mstype.uint8)
+ >>> model = FlashAttention(head_num,
+ keep_prob=1.0,
+ scale_value=1.0 / math.sqrt(head_size),
+ pre_tokens=2147483647,
+ next_tokens=2147483647,
+ input_layout="BSH",
+ sparse_mode=0,
+ use_attention_mask=True,
+ use_alibi_mask=False,
+ use_mqa=False,
+ dp=1,
+ mp=1,
+ sp=1
+ ... )
+ >>> output = model(query, key, value, attn_mask)
+ >>> print(output.shape)
+ (1, 4096, 2048)
+ """
+
+ def __init__(
+ self,
+ head_num,
+ keep_prob=1.0,
+ scale_value=1.0,
+ pre_tokens=2147483647,
+ next_tokens=2147483647,
+ input_layout="BSH",
+ sparse_mode=0,
+ use_attention_mask=True,
+ use_alibi_mask=False,
+ use_mqa=False,
+ dp=1,
+ mp=1,
+ sp=1,
+ ):
+ super(FlashAttention, self).__init__()
+ self.head_num = head_num
+ self.enable_dropout = keep_prob < 1.0
+ self.input_layout = input_layout
+ self.sparse_mode = sparse_mode
+ self.use_alibi_mask = use_alibi_mask
+ self.use_attention_mask = use_attention_mask
+ self.use_mqa = use_mqa
+ self.dp = dp
+ self.mp = mp
+ self.sp = sp
+
+ fa_strategies = self._generate_flash_attention_strategy(dp, mp, sp)
+ self.flash_attention = FlashAttentionScore(
+ head_num=head_num,
+ keep_prob=keep_prob,
+ scale_value=scale_value,
+ pre_tokens=pre_tokens,
+ next_tokens=next_tokens,
+ inner_precise=0,
+ input_layout=self.input_layout,
+ sparse_mode=self.sparse_mode,
+ ).shard(fa_strategies)
+ if self.use_alibi_mask:
+ self.alibi_rescale_factor = Tensor([1.0 / scale_value], dtype=mstype.float16)
+ self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, sp, 1), (1,)))
+ if self.enable_dropout:
+ self.keep_prob_tensor = Tensor(keep_prob, dtype=mstype.float16)
+ self.drop_gen_mask = ops.DropoutGenMask()
+
+ def _generate_flash_attention_strategy(self, dp, mp, sp):
+ """get FA generate strategies"""
+ kv_head_split_num = 1 if self.use_mqa else mp
+ if self.input_layout == "BSH":
+ fa_strategies = ((dp, sp, mp), (dp, 1, kv_head_split_num), (dp, 1, kv_head_split_num))
+ else:
+ fa_strategies = ((dp, mp, sp, 1), (dp, kv_head_split_num, 1, 1), (dp, kv_head_split_num, 1, 1))
+ if self.use_alibi_mask:
+ fa_strategies += ((dp, mp, sp, 1),)
+ if self.enable_dropout:
+ fa_strategies += ((dp, mp, sp, 1),)
+ if self.use_attention_mask:
+ if self.sparse_mode in [0, 1]:
+ fa_strategies += ((dp, 1, sp, 1),)
+ else:
+ raise RuntimeError(f"sparse_mode: {self.sparse_mode} is not supported currently")
+
+ return fa_strategies
+
+ def construct(self, query, key, value, attn_mask=None, alibi_mask=None, prefix=None, padding_mask=None):
+ """Forward process of the FlashAttention"""
+ if self.input_layout == "BSH":
+ bsz, q_seq_len, _ = query.shape
+ _, kv_seq_len, _ = key.shape
+ else:
+ bsz, _, q_seq_len, _ = query.shape
+ _, _, kv_seq_len, _ = key.shape
+ if self.enable_dropout:
+ drop_mask_bits = F.reshape(
+ self.drop_gen_mask((bsz, self.head_num, q_seq_len, kv_seq_len), self.keep_prob_tensor),
+ (bsz, self.head_num, q_seq_len, kv_seq_len // 8),
+ )
+ else:
+ drop_mask_bits = None
+ if self.use_alibi_mask:
+ alibi_mask = self.alibi_rescale_mul(alibi_mask, F.cast(self.alibi_rescale_factor, alibi_mask.dtype))
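+ # FlashAttentionScore returns four outputs; only the last one (the attention output) is used here.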
+ _, _, _, output = self.flash_attention(
+ query, key, value, alibi_mask, drop_mask_bits, padding_mask, attn_mask, prefix
+ )
+ return output
diff --git a/mindocr/utils/conversation.py b/mindocr/utils/conversation.py
index b0002db36..91d9859f8 100644
--- a/mindocr/utils/conversation.py
+++ b/mindocr/utils/conversation.py
@@ -12,6 +12,7 @@ def __init__(
roles: Tuple[str, str] = ("user", "assistant"),
messages: List[Tuple[str, str]] = None,
sep: str = "<|im_end|>",
+ generate_mode: bool = False,
):
self.system = (
"<|im_start|>{system}\n{message}{sep}\n".format(
@@ -25,11 +26,14 @@ def __init__(
self.roles = roles
self.messages = list() if messages is None else messages
self.sep = sep
+ self.generate_mode = generate_mode
def get_messages(self):
return self.messages
def get_prompt(self):
+ if self.generate_mode:
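+ # In generate mode the prompt is the latest raw message, without the chat template or system prompt.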
+ return self.messages[-1][1]
ret = self.system if self.system else ""
for role, message in self.messages:
ret += "<|im_start|>{role}\n{message}{sep}\n".format(role=role, message=message, sep=self.sep)
diff --git a/tools/infer/text/predict_llm.py b/tools/infer/text/predict_llm.py
index 9eb02fb20..d4b309625 100644
--- a/tools/infer/text/predict_llm.py
+++ b/tools/infer/text/predict_llm.py
@@ -2,22 +2,12 @@
import logging
import os
-from PIL import Image
-
import mindspore as ms
-from mindocr.data.transforms.llm_transform import image_processor, image_processor_high
-from mindocr.nlp.llm.configs import LLMConfig
-from mindocr.nlp.llm.qwen_tokenizer import QwenTokenizer
-from mindocr.nlp.llm.vary_qwen_model import VaryQwenForCausalLM
+from mindocr.nlp.llm import build_llm_model
from mindocr.utils.logger import set_logger
-def load_image(image_file):
- image = Image.open(image_file).convert("RGB")
- return image
-
-
def str2bool(v):
if isinstance(v, bool):
return v
@@ -31,10 +21,13 @@ def str2bool(v):
def parse_args():
parser = argparse.ArgumentParser(description="Inference Config Args")
- parser.add_argument("--image_dir", type=str, required=True, help="image path")
+ parser.add_argument("--mode", type=int, default=0, help="0 for graph mode, 1 for pynative mode")
+ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"])
+ parser.add_argument("--image_path", type=str, required=False, help="image path")
parser.add_argument("--query", type=str, required=False, default="Provide the ocr results of this image.")
parser.add_argument("--config_path", type=str, required=False, default="../../../configs/llm/vary/vary_toy.yaml")
parser.add_argument("--chat_mode", type=str2bool, required=False, default=False)
+ parser.add_argument("--precision_mode", type=str, required=False, default="allow_fp32_to_fp16")
args = parser.parse_args()
return args
@@ -42,49 +35,44 @@ def parse_args():
class LLMGenerator(object):
def __init__(self, args):
config_path = args.config_path
- config = LLMConfig(config_path)
ms.set_context(
- mode=ms.GRAPH_MODE,
- device_target="Ascend",
+ mode=args.mode,
+ device_target=args.device_target,
enable_graph_kernel=False,
graph_kernel_flags="--disable_expand_ops=Softmax,Dropout --enable_parallel_fusion=true "
"--reduce_fuse_depth=8 --enable_auto_tensor_inplace=true",
- ascend_config={"precision_mode": "must_keep_origin_dtype"},
+ ascend_config={"precision_mode": args.precision_mode},
max_call_depth=10000,
max_device_memory="58GB",
save_graphs=False,
save_graphs_path="./graph",
device_id=os.environ.get("DEVICE_ID", 0),
)
- self.tokenizer = QwenTokenizer(**config.processor.tokenizer)
- self.model = VaryQwenForCausalLM.from_pretrained(config_path)
+ self.model = build_llm_model(config_path)
- self.image_dir = args.image_dir
+ self.image_path = args.image_path
self.query = args.query
self.seq_length = self.model.seq_length
self.chat_mode = args.chat_mode
- def _call_one(self, query=None, image=None, image_high=None):
- response = self.model.chat(tokenizer=self.tokenizer, query=query, image=image, image_high=image_high)
+ def _call_one(self, query=None, image_path=None):
+ response = self.model.chat(query=query, image_path=image_path)
print(">" * 100)
print(response)
print("<" * 100)
return response
- def __call__(self, query=None, image_dir=None):
+ def __call__(self, query=None, image_path=None):
self.model.reset()
is_first_iteration = True
if query is None:
query = self.query
- if image_dir is None:
- image_dir = self.image_dir
- image = load_image(image_dir)
- image_high = image_processor_high(image)
- image = image_processor(image)
+ if image_path is None:
+ image_path = self.image_path
while True:
try:
if is_first_iteration:
- self._call_one(query=query, image=image, image_high=image_high)
+ self._call_one(query=query, image_path=image_path)
if not self.chat_mode:
break
is_first_iteration = False
@@ -93,7 +81,7 @@ def __call__(self, query=None, image_dir=None):
query = input()
if query == "exit":
break
- self._call_one(query=query, image=None, image_high=None)
+ self._call_one(query=query, image_path=image_path)
except ValueError as e:
if "check your inputs and set max_length larger than your inputs length." in e.args[0]:
logging.warning("The input is too long. The conversation is closed.")