diff --git a/spacy_llm/models/rest/mistral/__init__.py b/spacy_llm/models/rest/mistral/__init__.py new file mode 100644 index 00000000..bb45f7fb --- /dev/null +++ b/spacy_llm/models/rest/mistral/__init__.py @@ -0,0 +1,4 @@ +from .model import AzureMistral +from .registry import azure_mistral + +__all__ = ["AzureMistral", "azure_mistral"] diff --git a/spacy_llm/models/rest/mistral/model.py b/spacy_llm/models/rest/mistral/model.py new file mode 100644 index 00000000..83918171 --- /dev/null +++ b/spacy_llm/models/rest/mistral/model.py @@ -0,0 +1,80 @@ +import warnings +import os +from typing import Iterable, Optional, Any, Dict +from ..base import REST + +from mistralai.client import MistralClient +from mistralai.models.chat_completion import ChatMessage + + +class AzureMistral(REST): + def __init__( + self, + name: str, + endpoint: str, + config: Dict[Any, Any], + strict: bool, + max_tries: int, + interval: float, + max_request_time: float, + context_length: Optional[int], + ): + super().__init__( + name, + endpoint, + config, + strict, + max_tries, + interval, + max_request_time, + context_length, + ) + + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: + all_resps = [] + api_key = self._credentials.get("api-key") + for prompts_doc in prompts: + doc_resps = [] + for prompt in prompts_doc: + client = MistralClient(endpoint=self._endpoint, api_key=api_key) + + chat_response = client.chat( + model=self._name, + messages=[ + ChatMessage( + role="user", + content=prompt, + ) + ], + ) + doc_resps.append(chat_response.choices[0].message.content) + all_resps.append(doc_resps) + return all_resps + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + "azureai": 8192, + } + + def _verify_auth(self) -> None: + try: + self([["test"]]) + except ValueError as err: + raise err + + @property + def credentials(self) -> Dict[str, str]: + # Fetch and check the key + api_key = os.getenv("MISTRAL_API_KEY") + if api_key is None: + warnings.warn( + "Could not find the API key to access the Mistral AI API. Ensure you have an API key " + "set up (see " + ", then make it available as an environment variable 'MISTRAL_API_KEY'." + ) + + # Check the access and get a list of available models to verify the model argument (if not None) + # Even if the model is None, this call is used as a healthcheck to verify access. + assert api_key is not None + return {"api-key": api_key} diff --git a/spacy_llm/models/rest/mistral/registry.py b/spacy_llm/models/rest/mistral/registry.py new file mode 100644 index 00000000..8408a6ed --- /dev/null +++ b/spacy_llm/models/rest/mistral/registry.py @@ -0,0 +1,28 @@ +from ....registry import registry +from typing import Optional, Dict, Any +from .model import AzureMistral +from confection import SimpleFrozenDict + + +@registry.llm_models("AzureMistral.v1") +def azure_mistral( + name, + endpoint, + config: Dict[Any, Any] = SimpleFrozenDict(temperature=0.0), + strict: bool = AzureMistral.DEFAULT_STRICT, + max_tries: int = AzureMistral.DEFAULT_MAX_TRIES, + interval: float = AzureMistral.DEFAULT_INTERVAL, + max_request_time: float = AzureMistral.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +): + + return AzureMistral( + name=name, + endpoint=endpoint, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + )