From 2ebc62b7bd6d477ab371a12800032cbc10ec82f4 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 5 Jan 2024 11:36:28 +0800 Subject: [PATCH] fix(model): Fix device MPS load llama error --- dbgpt/model/llm/monkey_patch.py | 1 + dbgpt/model/loader.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/dbgpt/model/llm/monkey_patch.py b/dbgpt/model/llm/monkey_patch.py index f3656a159..f9c2b3119 100644 --- a/dbgpt/model/llm/monkey_patch.py +++ b/dbgpt/model/llm/monkey_patch.py @@ -34,6 +34,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() diff --git a/dbgpt/model/loader.py b/dbgpt/model/loader.py index 4ed42d630..cfd6e0c3c 100644 --- a/dbgpt/model/loader.py +++ b/dbgpt/model/loader.py @@ -166,11 +166,22 @@ def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParamete elif device == "mps": kwargs = {"torch_dtype": torch.float16} - from dbgpt.model.llm.monkey_patch import ( - replace_llama_attn_with_non_inplace_operations, - ) - replace_llama_attn_with_non_inplace_operations() + import transformers + + version = tuple(int(v) for v in transformers.__version__.split(".")) + if version < (4, 35, 0): + from dbgpt.model.llm.monkey_patch import ( + replace_llama_attn_with_non_inplace_operations, + ) + + # NOTE: Recent versions of the transformers library appear to fix the mps + # issue, but they also changed things in ways that break our original + # patch, so we only apply the patch for older versions. + + # Avoid bugs in mps backend by not using in-place operations. + replace_llama_attn_with_non_inplace_operations() + else: raise ValueError(f"Invalid device: {device}")