fix(model): Fix device MPS load llama error
fangyinc committed Jan 5, 2024
1 parent 757b61d commit 2ebc62b
Showing 2 changed files with 16 additions and 4 deletions.
1 change: 1 addition & 0 deletions dbgpt/model/llm/monkey_patch.py
@@ -34,6 +34,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()

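Context for the one-line change above: newer transformers releases in the 4.3x series appear to pass a padding_mask keyword when calling the Llama attention forward, so a monkey-patched replacement that lacks the parameter fails with a TypeError on every call. A minimal sketch of the requirement, with hypothetical names, not the repository's actual patch:

    # Sketch: a replacement forward must accept every keyword its caller
    # passes, even ones it never uses.
    from typing import Optional, Tuple

    import torch

    def patched_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        # Accepted (and ignored here) so newer transformers callers do not crash.
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # The real patch recomputes attention without in-place tensor ops,
        # which older MPS builds handled incorrectly.
        raise NotImplementedError
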
19 changes: 15 additions & 4 deletions dbgpt/model/loader.py
@@ -166,11 +166,22 @@ def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParamete

     elif device == "mps":
         kwargs = {"torch_dtype": torch.float16}
-        from dbgpt.model.llm.monkey_patch import (
-            replace_llama_attn_with_non_inplace_operations,
-        )
 
-        replace_llama_attn_with_non_inplace_operations()
+        import transformers
+
+        version = tuple(int(v) for v in transformers.__version__.split("."))
+        if version < (4, 35, 0):
+            from dbgpt.model.llm.monkey_patch import (
+                replace_llama_attn_with_non_inplace_operations,
+            )
+
+            # NOTE: Recent transformers library seems to fix the mps issue, also
+            # it has made some changes causing compatibility issues with our
+            # original patch. So we only apply the patch for older versions.
+
+            # Avoid bugs in mps backend by not using in-place operations.
+            replace_llama_attn_with_non_inplace_operations()
     else:
         raise ValueError(f"Invalid device: {device}")
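
A caveat on the version gate above: tuple(int(v) for v in transformers.__version__.split(".")) raises ValueError on pre-release version strings such as "4.36.0.dev0". A more tolerant sketch of the same check, assuming packaging is available (transformers already depends on it); apply_mps_llama_patch is a hypothetical wrapper around the commit's own import-and-patch block:

    # Sketch: the same < 4.35.0 gate, robust to pre-release version strings.
    import transformers
    from packaging.version import Version

    def apply_mps_llama_patch() -> None:
        # Mirrors the commit: import lazily and patch only when needed.
        from dbgpt.model.llm.monkey_patch import (
            replace_llama_attn_with_non_inplace_operations,
        )

        replace_llama_attn_with_non_inplace_operations()

    if Version(transformers.__version__) < Version("4.35.0"):
        apply_mps_llama_patch()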
