From cd082edca8dd1cfc04de889fe4673660870592bc Mon Sep 17 00:00:00 2001
From: Raphael Mitsch
Date: Wed, 24 Jan 2024 17:08:19 +0100
Subject: [PATCH] Add provider-specific registry functions.

---
 spacy_llm/models/hf/mistral.py              |  3 +-
 spacy_llm/models/hf/registry.py             | 47 +++++++++++++++++++++
 spacy_llm/models/rest/anthropic/registry.py | 37 ++++++++++++++++
 spacy_llm/models/rest/cohere/registry.py    | 39 ++++++++++++++++-
 spacy_llm/models/rest/openai/registry.py    | 43 +++++++++++++++++++
 spacy_llm/models/rest/palm/registry.py      | 46 +++++++++++++++++++-
 6 files changed, 210 insertions(+), 5 deletions(-)
 create mode 100644 spacy_llm/models/hf/registry.py

diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py
index c80d636e..9e7b06c5 100644
--- a/spacy_llm/models/hf/mistral.py
+++ b/spacy_llm/models/hf/mistral.py
@@ -99,8 +99,7 @@ def mistral_hf(
     name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
     config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
     config_run (Optional[Dict[str, Any]]): HF config for running the model.
-    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
-        the raw responses.
+    RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
     """
     return Mistral(
         name=name, config_init=config_init, config_run=config_run, context_length=8000

diff --git a/spacy_llm/models/hf/registry.py b/spacy_llm/models/hf/registry.py
new file mode 100644
index 00000000..247ae1f7
--- /dev/null
+++ b/spacy_llm/models/hf/registry.py
@@ -0,0 +1,47 @@
+from typing import Any, Callable, Dict, Iterable, Optional
+
+from confection import SimpleFrozenDict
+
+from ...registry import registry
+from .dolly import Dolly
+from .falcon import Falcon
+from .llama2 import Llama2
+from .mistral import Mistral
+from .openllama import OpenLLaMA
+from .stablelm import StableLM
+
+
+@registry.llm_models("spacy.HuggingFace.v1")
+def huggingface_v1(
+    name: str,
+    config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+    config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
+    """Returns HuggingFace model instance.
+    name (str): Name of model to use.
+    config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
+    config_run (Optional[Dict[str, Any]]): HF config for running the model.
+    RETURNS (Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]): Model instance that can execute a set of
+        prompts and return the raw responses.
+    """
+    model_context_lengths = {
+        Dolly: 2048,
+        Falcon: 2048,
+        Llama2: 4096,
+        Mistral: 8000,
+        OpenLLaMA: 2048,
+        StableLM: 4096,
+    }
+
+    for model_cls, context_length in model_context_lengths.items():
+        if name in getattr(model_cls, "MODEL_NAMES", {}):
+            return model_cls(
+                name=name,
+                config_init=config_init,
+                config_run=config_run,
+                context_length=context_length,
+            )
+
+    raise ValueError(
+        f"Name {name} could not be associated with any of the supported models. Please check https://spacy.io/api/large-language-models#models-hf to ensure the specified model name is correct."
+    )
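
For illustration, a minimal usage sketch of the new catch-all entry point (a sketch assuming this patch is applied; the model name is an example only, and instantiation will download the HF weights):

    from spacy_llm.registry import registry

    # Resolve the catch-all Hugging Face factory; the model name alone picks
    # the backing class, so "dolly-v2-3b" should dispatch to Dolly with
    # context_length=2048 per the mapping above.
    huggingface = registry.llm_models.get("spacy.HuggingFace.v1")
    model = huggingface(name="dolly-v2-3b")
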
diff --git a/spacy_llm/models/rest/anthropic/registry.py b/spacy_llm/models/rest/anthropic/registry.py
index dc44eb7e..9719af18 100644
--- a/spacy_llm/models/rest/anthropic/registry.py
+++ b/spacy_llm/models/rest/anthropic/registry.py
@@ -7,6 +7,43 @@
 from .model import Anthropic, Endpoints
 
 
+@registry.llm_models("spacy.Anthropic.v1")
+def anthropic_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(),
+    strict: bool = Anthropic.DEFAULT_STRICT,
+    max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
+    interval: float = Anthropic.DEFAULT_INTERVAL,
+    max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
+    context_length: Optional[int] = None,
+) -> Anthropic:
+    """Returns Anthropic model instance using REST to prompt API.
+    name (str): Name of model to use.
+    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
+        or other response object that does not conform to the expectation of what a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries. We implement a base 2 exponential backoff at each
+        retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context
+        length natively provided by spacy-llm.
+    RETURNS (Anthropic): Instance of Anthropic model.
+    """
+    return Anthropic(
+        name=name,
+        endpoint=Endpoints.COMPLETIONS.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
 
 @registry.llm_models("spacy.Claude-2.v2")
 def anthropic_claude_2_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),
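
The REST-backed factories in this patch all share the same retry knobs; a hedged sketch of overriding them on the new Anthropic entry point (the model name is an example, and an ANTHROPIC_API_KEY is assumed to be available when the instance is created):

    from spacy_llm.registry import registry

    # Up to 5 tries per request, base-2 exponential backoff starting at a
    # 1 s interval, and a 30 s cap before a request raises an exception.
    anthropic = registry.llm_models.get("spacy.Anthropic.v1")
    model = anthropic(name="claude-2", max_tries=5, interval=1.0, max_request_time=30.0)
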
diff --git a/spacy_llm/models/rest/cohere/registry.py b/spacy_llm/models/rest/cohere/registry.py
index 79c711e1..8deb979d 100644
--- a/spacy_llm/models/rest/cohere/registry.py
+++ b/spacy_llm/models/rest/cohere/registry.py
@@ -7,6 +7,43 @@
 from .model import Cohere, Endpoints
 
 
+@registry.llm_models("spacy.Cohere.v1")
+def cohere_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(),
+    strict: bool = Cohere.DEFAULT_STRICT,
+    max_tries: int = Cohere.DEFAULT_MAX_TRIES,
+    interval: float = Cohere.DEFAULT_INTERVAL,
+    max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
+    context_length: Optional[int] = None,
+) -> Cohere:
+    """Returns Cohere model instance using REST to prompt API.
+    name (str): Name of model to use.
+    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
+        or other response object that does not conform to the expectation of what a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries. We implement a base 2 exponential backoff at each
+        retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context
+        length natively provided by spacy-llm.
+    RETURNS (Cohere): Instance of Cohere model.
+    """
+    return Cohere(
+        name=name,
+        endpoint=Endpoints.COMPLETION.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
 
 @registry.llm_models("spacy.Command.v2")
 def cohere_command_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),
@@ -56,7 +93,7 @@ def cohere_command(
     max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
 ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
     """Returns Cohere instance for 'command' model using REST to prompt API.
-    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
+    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from
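
Per the Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]] signatures above, the returned objects are invoked on nested prompt batches (outer iterable per doc, inner iterable per shard) and yield responses with the same nesting. A sketch under those assumptions, with a valid Cohere API key configured:

    from spacy_llm.registry import registry

    cohere = registry.llm_models.get("spacy.Cohere.v1")
    model = cohere(name="command")
    # One doc with a single shard; the raw responses mirror the nesting.
    responses = model([["What is the capital of France?"]])
    print([list(r) for r in responses])
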
diff --git a/spacy_llm/models/rest/openai/registry.py b/spacy_llm/models/rest/openai/registry.py
index 3c3793ff..0e7a675d 100644
--- a/spacy_llm/models/rest/openai/registry.py
+++ b/spacy_llm/models/rest/openai/registry.py
@@ -8,6 +8,49 @@
 _DEFAULT_TEMPERATURE = 0.0
 
 
+@registry.llm_models("spacy.OpenAI.v1")
+def openai_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
+    strict: bool = OpenAI.DEFAULT_STRICT,
+    max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
+    interval: float = OpenAI.DEFAULT_INTERVAL,
+    max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
+    endpoint: Optional[str] = None,
+    context_length: Optional[int] = None,
+) -> OpenAI:
+    """Returns OpenAI model instance using REST to prompt API.
+
+    name (str): Model name to use. Can be any model name supported by the OpenAI API.
+    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
+        or other response object that does not conform to the expectation of what a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries. We implement a base 2 exponential backoff at each
+        retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context
+        length natively provided by spacy-llm.
+    RETURNS (OpenAI): OpenAI model instance.
+
+    DOCS: https://spacy.io/api/large-language-models#models
+    """
+    return OpenAI(
+        name=name,
+        endpoint=endpoint or Endpoints.CHAT.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
+
 """ Parameter explanations:
 strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
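
In pipeline code, the single spacy.OpenAI.v1 entry point replaces per-model registry names such as spacy.GPT-3-5.v2; a minimal sketch (task, labels, and model name are illustrative, and OPENAI_API_KEY is assumed to be set):

    import spacy

    nlp = spacy.blank("en")
    nlp.add_pipe(
        "llm",
        config={
            "task": {"@llm_tasks": "spacy.NER.v3", "labels": ["PERSON", "ORG"]},
            # One registry entry for all OpenAI models; the name picks the model.
            "model": {"@llm_models": "spacy.OpenAI.v1", "name": "gpt-3.5-turbo"},
        },
    )
    doc = nlp("Jack Dorsey founded Block in San Francisco.")
    print([(ent.text, ent.label_) for ent in doc.ents])
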
diff --git a/spacy_llm/models/rest/palm/registry.py b/spacy_llm/models/rest/palm/registry.py
index d7bae629..506e6d4b 100644
--- a/spacy_llm/models/rest/palm/registry.py
+++ b/spacy_llm/models/rest/palm/registry.py
@@ -7,6 +7,48 @@
 from .model import Endpoints, PaLM
 
 
+@registry.llm_models("spacy.Google.v1")
+def google_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
+    strict: bool = PaLM.DEFAULT_STRICT,
+    max_tries: int = PaLM.DEFAULT_MAX_TRIES,
+    interval: float = PaLM.DEFAULT_INTERVAL,
+    max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
+    context_length: Optional[int] = None,
+    endpoint: Optional[str] = None,
+) -> PaLM:
+    """Returns Google model instance using REST to prompt API.
+    name (str): Name of model to use.
+    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
+    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
+        or other response object that does not conform to the expectation of what a well-formed response object from
+        this API should look like). If False, the API error responses are returned by __call__(), but no error will
+        be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) for API retries. We implement a base 2 exponential backoff at each
+        retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context
+        length natively provided by spacy-llm.
+    endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
+    RETURNS (PaLM): PaLM model instance.
+    """
+    default_endpoint = (
+        Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
+    )
+    return PaLM(
+        name=name,
+        endpoint=endpoint or default_endpoint,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
 
 @registry.llm_models("spacy.PaLM.v2")
 def palm_bison_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
@@ -18,7 +60,7 @@ def palm_bison_v2(
     context_length: Optional[int] = None,
 ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
     """Returns Google instance for PaLM Bison model using REST to prompt API.
-    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
+    name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from
@@ -57,7 +99,7 @@ def palm_bison(
     endpoint: Optional[str] = None,
 ) -> PaLM:
     """Returns Google instance for PaLM Bison model using REST to prompt API.
-    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
+    name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from