Add provider-specific registry functions.
rmitsch committed Jan 24, 2024
1 parent ad30add commit cd082ed
Showing 6 changed files with 210 additions and 5 deletions.
3 changes: 1 addition & 2 deletions spacy_llm/models/hf/mistral.py
@@ -99,8 +99,7 @@ def mistral_hf(
name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
- RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
-     the raw responses.
+ RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
"""
return Mistral(
name=name, config_init=config_init, config_run=config_run, context_length=8000
47 changes: 47 additions & 0 deletions spacy_llm/models/hf/registry.py
@@ -0,0 +1,47 @@
from typing import Any, Callable, Dict, Iterable, Optional

from confection import SimpleFrozenDict

from ...registry import registry
from .dolly import Dolly
from .falcon import Falcon
from .llama2 import Llama2
from .mistral import Mistral
from .openllama import OpenLLaMA
from .stablelm import StableLM


@registry.llm_models("spacy.HuggingFace.v1")
def huggingface_v1(
name: str,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns HuggingFace model instance.
name (str): Name of model to use.
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]): Model instance that can execute a set of
prompts and return the raw responses.
"""
model_context_lengths = {
Dolly: 2048,
Falcon: 2048,
Llama2: 4096,
Mistral: 8000,
OpenLLaMA: 2048,
StableLM: 4096,
}

for model_cls, context_length in model_context_lengths.items():
if name in getattr(model_cls, "MODEL_NAMES", {}):
return model_cls(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

raise ValueError(
f"Name {name} could not be associated with any of the supported models. Please check https://spacy.io/api/large-language-models#models-hf to ensure the specified model name is correct."
)
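
The new catch-all entry can be referenced from a pipeline config like any other model. A minimal usage sketch, assuming the Dolly 3B checkpoint as the model name (any name listed in a supported model's MODEL_NAMES should resolve the same way; the task and labels are illustrative):

import spacy

nlp = spacy.blank("en")
# Resolves "dolly-v2-3b" to the Dolly class via the mapping above and
# downloads the HF weights on first use.
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.NER.v3", "labels": ["PERSON", "ORG"]},
        "model": {"@llm_models": "spacy.HuggingFace.v1", "name": "dolly-v2-3b"},
    },
)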
37 changes: 37 additions & 0 deletions spacy_llm/models/rest/anthropic/registry.py
@@ -7,6 +7,43 @@
from .model import Anthropic, Endpoints


@registry.llm_models("spacy.Anthropic.v1")
def anthropic_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Anthropic.DEFAULT_STRICT,
max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
interval: float = Anthropic.DEFAULT_INTERVAL,
max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Anthropic:
"""Returns Anthropic model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement base-2 exponential backoff at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
RETURNS (Anthropic): Instance of Anthropic model.
"""
return Anthropic(
name=name,
endpoint=Endpoints.COMPLETIONS.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Claude-2.v2")
def anthropic_claude_2_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
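All of these REST registry functions expose the same retry knobs (max_tries, interval, max_request_time). As a rough sketch of the schedule the docstrings describe — not spacy-llm's actual implementation — base-2 exponential backoff means the wait doubles after each failed attempt:

import time
from typing import Callable, TypeVar

T = TypeVar("T")


def with_backoff(
    request: Callable[[], T], max_tries: int, interval: float, max_request_time: float
) -> T:
    # Illustrative helper: wait interval, 2 * interval, 4 * interval, ...
    # between attempts, giving up after max_tries attempts or once
    # max_request_time seconds have elapsed.
    start = time.time()
    for attempt in range(max_tries):
        try:
            return request()
        except Exception:
            if attempt == max_tries - 1 or time.time() - start >= max_request_time:
                raise
            time.sleep(interval * 2**attempt)
    raise AssertionError("unreachable")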
39 changes: 38 additions & 1 deletion spacy_llm/models/rest/cohere/registry.py
@@ -7,6 +7,43 @@
from .model import Cohere, Endpoints


@registry.llm_models("spacy.Cohere.v1")
def cohere_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Cohere.DEFAULT_STRICT,
max_tries: int = Cohere.DEFAULT_MAX_TRIES,
interval: float = Cohere.DEFAULT_INTERVAL,
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Cohere:
"""Returns Cohere model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement base-2 exponential backoff at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
RETURNS (Cohere): Instance of Cohere model.
"""
return Cohere(
name=name,
endpoint=Endpoints.COMPLETION.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Command.v2")
def cohere_command_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
@@ -56,7 +93,7 @@ def cohere_command(
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Cohere instance for 'command' model using REST to prompt API.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
43 changes: 43 additions & 0 deletions spacy_llm/models/rest/openai/registry.py
@@ -8,6 +8,49 @@

_DEFAULT_TEMPERATURE = 0.0


@registry.llm_models("spacy.OpenAI.v")
def openai_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
strict: bool = OpenAI.DEFAULT_STRICT,
max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
interval: float = OpenAI.DEFAULT_INTERVAL,
max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
endpoint: Optional[str] = None,
context_length: Optional[int] = None,
) -> OpenAI:
"""Returns OpenAI model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
name (str): Model name to use. Can be any model name supported by the OpenAI API.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement base-2 exponential backoff at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
RETURNS (OpenAI): OpenAI model instance.
DOCS: https://spacy.io/api/large-language-models#models
"""
return OpenAI(
name=name,
endpoint=endpoint or Endpoints.CHAT.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


"""
Parameter explanations:
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
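A usage sketch for the new entry, resolving it through the registry directly (the model name is an assumption, and OPENAI_API_KEY must be set in the environment for the instance to authenticate):

from spacy_llm.registry import registry

openai_factory = registry.llm_models.get("spacy.OpenAI.v1")
# Defaults to the chat completions endpoint; name can be any model
# the OpenAI API accepts.
model = openai_factory(name="gpt-4")
# An OpenAI-compatible proxy could be targeted by overriding endpoint=,
# e.g. openai_factory(name="gpt-4", endpoint="https://<host>/v1/chat/completions").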
46 changes: 44 additions & 2 deletions spacy_llm/models/rest/palm/registry.py
@@ -7,6 +7,48 @@
from .model import Endpoints, PaLM


@registry.llm_models("spacy.Google.v1")
def google_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
strict: bool = PaLM.DEFAULT_STRICT,
max_tries: int = PaLM.DEFAULT_MAX_TRIES,
interval: float = PaLM.DEFAULT_INTERVAL,
max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
endpoint: Optional[str] = None,
) -> PaLM:
"""Returns Google model instance using REST to prompt API.
name (str): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement base-2 exponential backoff at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
RETURNS (PaLM): PaLM model instance.
"""
default_endpoint = (
Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
)
return PaLM(
name=name,
endpoint=endpoint or default_endpoint,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.PaLM.v2")
def palm_bison_v2(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
@@ -18,7 +60,7 @@ def palm_bison_v2(
context_length: Optional[int] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
@@ -57,7 +99,7 @@ def palm_bison(
endpoint: Optional[str] = None,
) -> PaLM:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
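A usage sketch for the new Google entry (the model name is illustrative, and a Google API key must be configured in the environment). Note how the default endpoint is picked from the model name:

from spacy_llm.registry import registry

google_factory = registry.llm_models.get("spacy.Google.v1")
# "text-bison-001" picks the TEXT endpoint by default; any other name
# falls back to the MSG endpoint unless endpoint= is passed explicitly.
palm = google_factory(name="text-bison-001")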
