Support Phi-2 #410

Draft
wants to merge 7 commits into base: main
2 changes: 2 additions & 0 deletions spacy_llm/models/hf/__init__.py
@@ -4,6 +4,7 @@
from .llama2 import llama2_hf
from .mistral import mistral_hf
from .openllama import openllama_hf
from .phi2 import phi2_hf
from .stablelm import stablelm_hf

__all__ = [
@@ -13,5 +14,6 @@
    "llama2_hf",
    "mistral_hf",
    "openllama_hf",
    "phi2_hf",
    "stablelm_hf",
]
8 changes: 8 additions & 0 deletions spacy_llm/models/hf/base.py
@@ -69,6 +69,14 @@ def __init__(
f"Double-check you specified a valid dtype."
) from ex

        # Convert boolean values passed as the strings "True"/"False" to real booleans.
        for key, value in self._config_init.items():
            if value in ("True", "False"):
                self._config_init[key] = False if value == "False" else True
        for key, value in self._config_run.items():
            if value in ("True", "False"):
                self._config_run[key] = False if value == "False" else True

        # Init HF model.
        HuggingFace.check_installation()
        self._check_model()
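Values written Python-style as True/False in a .cfg file presumably arrive from the config parser as the strings "True"/"False" rather than as booleans, so the added loop coerces them before they reach Hugging Face. A standalone sketch of the same coercion, using hypothetical config values:

# Hypothetical config as it might arrive from a .cfg file: Python-style booleans
# are parsed as the strings "True"/"False".
config_init = {"trust_remote_code": "True", "low_cpu_mem_usage": "False", "torch_dtype": "auto"}
for key, value in config_init.items():
    if value in ("True", "False"):
        config_init[key] = value == "True"
assert config_init["trust_remote_code"] is True
assert config_init["low_cpu_mem_usage"] is False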
2 changes: 1 addition & 1 deletion spacy_llm/models/hf/mistral.py
@@ -65,7 +65,7 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:

            tokenized_input_ids = [
                self._tokenizer(
-                   prompt if not self._is_instruct else f"<s>[INST] {prompt} [/INST]",
+                   prompt if not self._is_instruct else f"[INST] {prompt} [/INST]",
                    return_tensors="pt",
                ).input_ids
                for prompt in prompts_for_doc
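The literal "<s>" prefix is dropped, presumably because the tokenizer already prepends the BOS token when encoding, so spelling it out in the prompt text would duplicate it. A small sketch illustrating that behavior (assumes the mistralai/Mistral-7B-Instruct-v0.1 tokenizer can be downloaded):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
ids = tok("[INST] Hello [/INST]", return_tensors="pt").input_ids
# The first id is already the BOS token, so no literal "<s>" is needed in the text.
assert ids[0, 0].item() == tok.bos_token_id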
115 changes: 115 additions & 0 deletions spacy_llm/models/hf/phi2.py
@@ -0,0 +1,115 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from confection import SimpleFrozenDict

from ...compat import Literal, transformers
from ...registry.util import registry
from .base import HuggingFace


class Phi2(HuggingFace):
    MODEL_NAMES = Literal["phi-2"]  # noqa: F722

    def __init__(
        self,
        name: str,
        config_init: Optional[Dict[str, Any]],
        config_run: Optional[Dict[str, Any]],
        context_length: Optional[int],
    ):
        self._tokenizer: Optional["transformers.AutoTokenizer"] = None
        super().__init__(
            name=name,
            config_init=config_init,
            config_run=config_run,
            context_length=context_length,
        )

    def init_model(self) -> "transformers.AutoModelForCausalLM":
        """Sets up HF model and needed utilities.
        RETURNS (Any): HF model.
        """
        # Initialize tokenizer and model.
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
            self._name, trust_remote_code=True
        )
        init_cfg = self._config_init
        device: Optional[str] = None
        if "device" in init_cfg:
            device = init_cfg.pop("device")

        model = transformers.AutoModelForCausalLM.from_pretrained(
            self._name, **init_cfg
        )
        if device:
            model.to(device)

        return model

    def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:  # type: ignore[override]
        assert callable(self._tokenizer)
        responses: List[List[str]] = []

        for prompts_for_doc in prompts:
            tokenized_input_ids = [
                self._tokenizer(
                    prompt, return_tensors="pt", return_attention_mask=False
                ).input_ids
                for prompt in prompts_for_doc
            ]
            tokenized_input_ids = [
                tii.to(self._model.device) for tii in tokenized_input_ids
            ]

            assert hasattr(self._model, "generate")
            responses.append(
                [
                    self._tokenizer.decode(
                        self._model.generate(input_ids=tii, **self._config_run)[
                            :, tii.shape[1] :
                        ][0],
                    )
                    for tii in tokenized_input_ids
                ]
            )

        return responses

    @property
    def hf_account(self) -> str:
        return "microsoft"

    @staticmethod
    def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # See https://huggingface.co/microsoft/phi-2#sample-code for recommended setting combinations.
        default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
        return (
            {
                **default_cfg_init,
                "torch_dtype": "auto",
                "device_map": "cuda",
                "trust_remote_code": True,
            },
            {
                **default_cfg_run,
                "max_new_tokens": 200,
            },
        )


@registry.llm_models("spacy.Phi-2.v1")
def phi2_hf(
    name: Phi2.MODEL_NAMES,
    config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
    config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
    """Generates Phi-2 instance that can execute a set of prompts and return the raw responses.
    name (Literal): Name of the Phi-2 model. Has to be one of Phi2.get_model_names().
    config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
    config_run (Optional[Dict[str, Any]]): HF config for running the model.
    RETURNS (Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]): Phi-2 instance that can execute a set of
        prompts and return the raw responses.
    """
    return Phi2(
        name=name, config_init=config_init, config_run=config_run, context_length=2048
    )
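In __call__, the slice [:, tii.shape[1]:] removes the echoed prompt tokens from generate()'s output so that only newly generated tokens are decoded; the tests below rely on the response not starting with the prompt. A toy sketch of that slicing, with made-up token ids:

import torch

prompt_len = 4  # number of tokens in the prompt
# Pretend generate() returned the prompt tokens followed by three new tokens.
generated = torch.tensor([[101, 102, 103, 104, 7, 8, 9]])
new_tokens = generated[:, prompt_len:]  # tensor([[7, 8, 9]])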
85 changes: 85 additions & 0 deletions spacy_llm/tests/models/test_phi2.py
@@ -0,0 +1,85 @@
import copy

import pytest
import spacy
from confection import Config # type: ignore[import]
from thinc.compat import has_torch_cuda_gpu

from ...compat import torch

_PIPE_CFG = {
"model": {
"@llm_models": "spacy.Phi-2.v1",
"name": "phi-2",
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
"save_io": True,
}

_NLP_CONFIG = """
[nlp]
lang = "en"
pipeline = ["llm"]
batch_size = 128

[components]

[components.llm]
factory = "llm"
save_io = True

[components.llm.task]
@llm_tasks = "spacy.NoOp.v1"

[components.llm.model]
@llm_models = "spacy.Phi-2.v1"
name = "phi-2"
"""


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init():
"""Test initialization and simple run."""
nlp = spacy.blank("en")
nlp.add_pipe("llm", config=_PIPE_CFG)
doc = nlp("This is a test.")
torch.cuda.empty_cache()
assert not doc.user_data["llm_io"]["llm"]["response"][0].startswith(
doc.user_data["llm_io"]["llm"]["prompt"][0]
)


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init_with_set_config():
"""Test initialization and simple run with changed config."""
nlp = spacy.blank("en")
cfg = copy.deepcopy(_PIPE_CFG)
cfg["model"]["config_run"] = {"max_new_tokens": 32}
nlp.add_pipe("llm", config=cfg)
doc = nlp("This is a test.")
torch.cuda.empty_cache()
assert not doc.user_data["llm_io"]["llm"]["response"][0].startswith(
doc.user_data["llm_io"]["llm"]["prompt"][0]
)


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init_from_config():
    orig_config = Config().from_str(_NLP_CONFIG)
    nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True)
    assert nlp.pipe_names == ["llm"]
    torch.cuda.empty_cache()


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_invalid_model():
    orig_config = Config().from_str(_NLP_CONFIG)
    config = copy.deepcopy(orig_config)
    config["components"]["llm"]["model"]["name"] = "anything-else"
    with pytest.raises(ValueError, match="unexpected value; permitted"):
        spacy.util.load_model_from_config(config, auto_fill=True)
    torch.cuda.empty_cache()