Revert "Add registry functions to instantiate models by provider (#428)"
This reverts commit ff07682.
svlandeg committed May 17, 2024
1 parent 5b49105 commit 15b761e
Showing 46 changed files with 97 additions and 363 deletions.
README.md (3 changes: 1 addition & 2 deletions)

@@ -119,8 +119,7 @@ factory = "llm"
 labels = ["COMPLIMENT", "INSULT"]

 [components.llm.model]
-@llm_models = "spacy.OpenAI.v1"
-name = "gpt-4"
+@llm_models = "spacy.GPT-4.v2"
 ```

 Now run:

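For context, the restored model block can also be assembled inline in Python. This is a minimal sketch assuming the `spacy.TextCat.v3` task and spaCy's `nlp.add_pipe("llm", config=...)` entry point; the task version, label handling and API key setup are assumptions, not part of this diff:

```python
import spacy

nlp = spacy.blank("en")
# Model block as restored by the revert: a model-specific registry entry
# ("spacy.GPT-4.v2") instead of the provider-level "spacy.OpenAI.v1" plus name.
nlp.add_pipe(
    "llm",
    config={
        "task": {
            "@llm_tasks": "spacy.TextCat.v3",  # assumed task version
            "labels": ["COMPLIMENT", "INSULT"],
        },
        "model": {"@llm_models": "spacy.GPT-4.v2"},
    },
)
doc = nlp("You look gorgeous!")  # requires OPENAI_API_KEY in the environment
print(doc.cats)
```
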
pyproject.toml (4 changes: 1 addition & 3 deletions)

@@ -27,9 +27,7 @@ filterwarnings = [
     "ignore:^.*The `construct` method is deprecated.*",
     "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
     "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
-    "ignore:^.*was deprecated in langchain-community.*",
-    "ignore:^.*was deprecated in LangChain 0.0.1.*",
-    "ignore:^.*the load_module() method is deprecated and slated for removal in Python 3.12.*"
+    "ignore:^.*was deprecated in langchain-community.*"
 ]
 markers = [
     "external: interacts with a (potentially cost-incurring) third-party API",

requirements-dev.txt (3 changes: 1 addition & 2 deletions)

@@ -13,8 +13,7 @@ langchain>=0.1,<0.2; python_version>="3.9"
 openai>=0.27,<=0.28.1; python_version>="3.9"

 # Necessary for running all local models on GPU.
-# TODO: transformers > 4.38 causes bug in model handling due to unknown factors. To be investigated.
-transformers[sentencepiece]>=4.0.0,<=4.38
+transformers[sentencepiece]>=4.0.0
 torch
 einops>=0.4

spacy_llm/models/hf/__init__.py (2 changes: 0 additions & 2 deletions)

@@ -4,14 +4,12 @@
 from .llama2 import llama2_hf
 from .mistral import mistral_hf
 from .openllama import openllama_hf
-from .registry import huggingface_v1
 from .stablelm import stablelm_hf

 __all__ = [
     "HuggingFace",
     "dolly_hf",
     "falcon_hf",
-    "huggingface_v1",
     "llama2_hf",
     "mistral_hf",
     "openllama_hf",

spacy_llm/models/hf/mistral.py (3 changes: 2 additions & 1 deletion)

@@ -99,7 +99,8 @@ def mistral_hf(
     name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
     config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
     config_run (Optional[Dict[str, Any]]): HF config for running the model.
-    RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
+    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
+        the raw responses.
     """
     return Mistral(
         name=name, config_init=config_init, config_run=config_run, context_length=8000

spacy_llm/models/hf/registry.py (51 changes: 0 additions & 51 deletions)

This file was deleted.

spacy_llm/models/langchain/model.py (1 change: 0 additions & 1 deletion)

@@ -99,7 +99,6 @@ def query_langchain(
     prompts (Iterable[Iterable[Any]]): Prompts to execute.
     RETURNS (Iterable[Iterable[Any]]): LLM responses.
     """
-    assert callable(model)
    return [
        [model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
    ]

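The surviving comprehension expects one list of prompts per doc and queries the LangChain model once per prompt. A self-contained sketch of that nesting, using a stand-in callable rather than a real LangChain model (illustrative only):

```python
from typing import Any, Callable, Iterable, List


def query_stub(
    model: Callable[[Any], Any], prompts: Iterable[Iterable[Any]]
) -> List[List[Any]]:
    # Same shape as query_langchain: outer iterable = docs, inner iterable = prompts per doc.
    return [[model(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts]


responses = query_stub(lambda p: f"echo: {p}", [["prompt 1", "prompt 2"], ["prompt 3"]])
print(responses)  # [['echo: prompt 1', 'echo: prompt 2'], ['echo: prompt 3']]
```
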
spacy_llm/models/rest/anthropic/registry.py (37 changes: 0 additions & 37 deletions)

@@ -7,43 +7,6 @@
 from .model import Anthropic, Endpoints


-@registry.llm_models("spacy.Anthropic.v1")
-def anthropic_v1(
-    name: str,
-    config: Dict[Any, Any] = SimpleFrozenDict(),
-    strict: bool = Anthropic.DEFAULT_STRICT,
-    max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
-    interval: float = Anthropic.DEFAULT_INTERVAL,
-    max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
-    context_length: Optional[int] = None,
-) -> Anthropic:
-    """Returns Anthropic model instance using REST to prompt API.
-    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
-    name (str): Name of model to use.
-    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
-        or other response object that does not conform to the expectation of how a well-formed response object from
-        this API should look like). If False, the API error responses are returned by __call__(), but no error will
-        be raised.
-    max_tries (int): Max. number of tries for API request.
-    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
-        at each retry.
-    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
-    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
-        natively provided by spacy-llm.
-    RETURNS (Anthropic): Instance of Anthropic model.
-    """
-    return Anthropic(
-        name=name,
-        endpoint=Endpoints.COMPLETIONS.value,
-        config=config,
-        strict=strict,
-        max_tries=max_tries,
-        interval=interval,
-        max_request_time=max_request_time,
-        context_length=context_length,
-    )
-
-
 @registry.llm_models("spacy.Claude-2.v2")
 def anthropic_claude_2_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),

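Both registration styles appear in this file: the provider-level `spacy.Anthropic.v1` entry removed here, and model-specific entries such as `spacy.Claude-2.v2` that remain. As model config blocks, the two styles differ roughly as follows (a sketch; the `name` value is illustrative):

```python
# Removed by this revert: one provider-level entry, model chosen via "name".
model_config_removed = {
    "@llm_models": "spacy.Anthropic.v1",
    "name": "claude-2",  # illustrative model name
}

# Still available after the revert: a versioned, model-specific entry.
model_config_remaining = {
    "@llm_models": "spacy.Claude-2.v2",
}
```
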
spacy_llm/models/rest/cohere/registry.py (39 changes: 1 addition & 38 deletions)

@@ -7,43 +7,6 @@
 from .model import Cohere, Endpoints


-@registry.llm_models("spacy.Cohere.v1")
-def cohere_v1(
-    name: str,
-    config: Dict[Any, Any] = SimpleFrozenDict(),
-    strict: bool = Cohere.DEFAULT_STRICT,
-    max_tries: int = Cohere.DEFAULT_MAX_TRIES,
-    interval: float = Cohere.DEFAULT_INTERVAL,
-    max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
-    context_length: Optional[int] = None,
-) -> Cohere:
-    """Returns Cohere model instance using REST to prompt API.
-    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
-    name (str): Name of model to use.
-    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
-        or other response object that does not conform to the expectation of how a well-formed response object from
-        this API should look like). If False, the API error responses are returned by __call__(), but no error will
-        be raised.
-    max_tries (int): Max. number of tries for API request.
-    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
-        at each retry.
-    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
-    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
-        natively provided by spacy-llm.
-    RETURNS (Cohere): Instance of Cohere model.
-    """
-    return Cohere(
-        name=name,
-        endpoint=Endpoints.COMPLETION.value,
-        config=config,
-        strict=strict,
-        max_tries=max_tries,
-        interval=interval,
-        max_request_time=max_request_time,
-        context_length=context_length,
-    )
-
-
 @registry.llm_models("spacy.Command.v2")
 def cohere_command_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),

@@ -93,7 +56,7 @@ def cohere_command(
     max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
 ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
     """Returns Cohere instance for 'command' model using REST to prompt API.
-    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
+    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from

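The docstrings above describe the retry policy as a base-2 exponential backoff over `interval`, bounded by `max_tries` and `max_request_time`. A minimal standalone sketch of such a policy (the actual implementation lives in the shared REST base class and is not shown in this diff):

```python
import time


def call_with_backoff(request, max_tries=5, interval=1.0, max_request_time=30.0):
    """Retry request() with base-2 exponential backoff: interval, 2 * interval, 4 * interval, ..."""
    start = time.time()
    for attempt in range(max_tries):
        try:
            return request()
        except Exception:
            # Give up once the retry budget or the overall time budget is exhausted.
            if attempt == max_tries - 1 or time.time() - start > max_request_time:
                raise
            time.sleep(interval * 2**attempt)
```
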
spacy_llm/models/rest/openai/registry.py (41 changes: 0 additions & 41 deletions)

@@ -8,47 +8,6 @@

 _DEFAULT_TEMPERATURE = 0.0

-
-@registry.llm_models("spacy.OpenAI.v1")
-def openai_v1(
-    name: str,
-    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
-    strict: bool = OpenAI.DEFAULT_STRICT,
-    max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
-    interval: float = OpenAI.DEFAULT_INTERVAL,
-    max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
-    endpoint: Optional[str] = None,
-    context_length: Optional[int] = None,
-) -> OpenAI:
-    """Returns OpenAI model instance using REST to prompt API.
-    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
-    name (str): Model name to use. Can be any model name supported by the OpenAI API.
-    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
-        or other response object that does not conform to the expectation of how a well-formed response object from
-        this API should look like). If False, the API error responses are returned by __call__(), but no error will
-        be raised.
-    max_tries (int): Max. number of tries for API request.
-    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
-        at each retry.
-    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
-    endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
-    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
-        natively provided by spacy-llm.
-    RETURNS (OpenAI): OpenAI model instance.
-    """
-    return OpenAI(
-        name=name,
-        endpoint=endpoint or Endpoints.CHAT.value,
-        config=config,
-        strict=strict,
-        max_tries=max_tries,
-        interval=interval,
-        max_request_time=max_request_time,
-        context_length=context_length,
-    )
-
-
 """
 Parameter explanations:
 strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON

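The removed function was exposed through the `@registry.llm_models(...)` decorator, so configs could refer to it by string name. A simplified sketch of that mechanism using `catalogue`, the library underlying spaCy's registries (the registry namespace and factory body here are assumptions, not the actual spacy-llm objects):

```python
import catalogue

# Stand-in registry; spacy-llm uses spaCy's own registry.llm_models instead.
llm_models = catalogue.create("sketch", "llm_models", entry_points=False)


@llm_models.register("spacy.OpenAI.v1")
def openai_v1_sketch(name: str) -> str:
    # A real factory would build and return an OpenAI REST wrapper here.
    return f"<OpenAI model wrapper for {name!r}>"


factory = llm_models.get("spacy.OpenAI.v1")  # resolved from the string in the config
print(factory(name="gpt-4"))
```
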
spacy_llm/models/rest/palm/registry.py (46 changes: 2 additions & 44 deletions)

@@ -7,48 +7,6 @@
 from .model import Endpoints, PaLM


-@registry.llm_models("spacy.Google.v1")
-def google_v1(
-    name: str,
-    config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
-    strict: bool = PaLM.DEFAULT_STRICT,
-    max_tries: int = PaLM.DEFAULT_MAX_TRIES,
-    interval: float = PaLM.DEFAULT_INTERVAL,
-    max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
-    context_length: Optional[int] = None,
-    endpoint: Optional[str] = None,
-) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
-    """Returns Google model instance using REST to prompt API.
-    name (str): Name of model to use.
-    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
-    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
-        or other response object that does not conform to the expectation of how a well-formed response object from
-        this API should look like). If False, the API error responses are returned by __call__(), but no error will
-        be raised.
-    max_tries (int): Max. number of tries for API request.
-    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
-        at each retry.
-    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
-    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
-        natively provided by spacy-llm.
-    endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
-    RETURNS (PaLM): PaLM model instance.
-    """
-    default_endpoint = (
-        Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
-    )
-    return PaLM(
-        name=name,
-        endpoint=endpoint or default_endpoint,
-        config=config,
-        strict=strict,
-        max_tries=max_tries,
-        interval=interval,
-        max_request_time=max_request_time,
-        context_length=None,
-    )
-
-
 @registry.llm_models("spacy.PaLM.v2")
 def palm_bison_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),

@@ -60,7 +18,7 @@ def palm_bison_v2(
     context_length: Optional[int] = None,
 ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
     """Returns Google instance for PaLM Bison model using REST to prompt API.
-    name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
+    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from

@@ -99,7 +57,7 @@ def palm_bison(
     endpoint: Optional[str] = None,
 ) -> PaLM:
     """Returns Google instance for PaLM Bison model using REST to prompt API.
-    name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
+    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from

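The removed `google_v1` helper picked its default endpoint from the model name: the text-completion endpoint for `text-bison-001`, the message endpoint otherwise, with an explicit `endpoint` argument taking precedence. A standalone sketch of that selection logic (the endpoint URLs are placeholders, not the real API paths):

```python
from typing import Optional

# Placeholder values; the real ones live in spacy_llm's Endpoints enum for PaLM.
TEXT_ENDPOINT = "https://example.invalid/models/{model}:generateText"
MSG_ENDPOINT = "https://example.invalid/models/{model}:generateMessage"


def pick_endpoint(name: str, endpoint: Optional[str] = None) -> str:
    # Mirrors the removed logic: explicit endpoint wins, otherwise choose by model name.
    default_endpoint = TEXT_ENDPOINT if name in {"text-bison-001"} else MSG_ENDPOINT
    return endpoint or default_endpoint


print(pick_endpoint("text-bison-001"))  # text endpoint
print(pick_endpoint("chat-bison-001"))  # message endpoint
```
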
spacy_llm/pipeline/llm.py (7 changes: 3 additions & 4 deletions)

@@ -24,7 +24,7 @@
 logger.addHandler(logging.NullHandler())

 DEFAULT_MODEL_CONFIG = {
-    "@llm_models": "spacy.GPT-3-5.v3",
+    "@llm_models": "spacy.GPT-3-5.v2",
     "strict": True,
 }
 DEFAULT_CACHE_CONFIG = {

@@ -238,7 +238,6 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
             else self._task.generate_prompts(noncached_doc_batch),
             n_iters + 1,
         )
-
         responses_iters = tee(
             self._model(
                 # Ensure that model receives Iterable[Iterable[Any]]. If task doesn't shard, its prompt is wrapped

@@ -252,7 +251,7 @@
         )

         for prompt_data, response, doc in zip(
-            prompts_iters[1], list(responses_iters[0]), noncached_doc_batch
+            prompts_iters[1], responses_iters[0], noncached_doc_batch
         ):
             logger.debug(
                 "Generated prompt for doc: %s\n%s",

@@ -267,7 +266,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
                     elem[1] if support_sharding else noncached_doc_batch[i]
                     for i, elem in enumerate(prompts_iters[2])
                 ),
-                list(responses_iters[1]),
+                responses_iters[1],
             )
         )

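The reverted lines stop wrapping the tee'd response iterators in `list()`. `itertools.tee` already lets each branch be consumed independently, so materialization is not required, as this minimal sketch shows (values are illustrative):

```python
from itertools import tee

responses = iter(["res-1", "res-2", "res-3"])
branch_a, branch_b = tee(responses, 2)

# Each branch can be iterated once, independently; tee buffers items internally
# as the branches are consumed, so no list() materialization is needed.
print(list(zip(["prompt-1", "prompt-2", "prompt-3"], branch_a)))
print([r.upper() for r in branch_b])
```
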
spacy_llm/tests/models/test_cohere.py (2 changes: 1 addition & 1 deletion)

@@ -84,7 +84,7 @@ def test_cohere_api_response_when_error():
 def test_cohere_error_unsupported_model():
     """Ensure graceful handling of error when model is not supported"""
     incorrect_model = "x-gpt-3.5-turbo"
-    with pytest.raises(ValueError, match="Request to Cohere API failed"):
+    with pytest.raises(ValueError, match="model not found"):
         Cohere(
             name=incorrect_model,
             config={},
