From 05557598d3dba816ff18afd0be4047c479ac65d7 Mon Sep 17 00:00:00 2001 From: Shubham Sureka Date: Fri, 29 Mar 2024 13:21:39 +0530 Subject: [PATCH 1/2] mistral model api support --- spacy_llm/models/rest/mistral/__init__.py | 4 ++ spacy_llm/models/rest/mistral/model.py | 80 +++++++++++++++++++++++ spacy_llm/models/rest/mistral/registry.py | 28 ++++++++ 3 files changed, 112 insertions(+) create mode 100644 spacy_llm/models/rest/mistral/__init__.py create mode 100644 spacy_llm/models/rest/mistral/model.py create mode 100644 spacy_llm/models/rest/mistral/registry.py diff --git a/spacy_llm/models/rest/mistral/__init__.py b/spacy_llm/models/rest/mistral/__init__.py new file mode 100644 index 00000000..bb45f7fb --- /dev/null +++ b/spacy_llm/models/rest/mistral/__init__.py @@ -0,0 +1,4 @@ +from .model import AzureMistral +from .registry import azure_mistral + +__all__ = ["AzureMistral", "azure_mistral"] diff --git a/spacy_llm/models/rest/mistral/model.py b/spacy_llm/models/rest/mistral/model.py new file mode 100644 index 00000000..3bef2bab --- /dev/null +++ b/spacy_llm/models/rest/mistral/model.py @@ -0,0 +1,80 @@ +import warnings +import os +from typing import Iterable, Optional, Any, Dict +from ..base import REST + + +class AzureMistral(REST): + def __init__( + self, + name: str, + endpoint: str, + config: Dict[Any, Any], + strict: bool, + max_tries: int, + interval: float, + max_request_time: float, + context_length: Optional[int], + ): + super().__init__( + name, + endpoint, + config, + strict, + max_tries, + interval, + max_request_time, + context_length, + ) + + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: + all_resps = [] + for prompts_doc in prompts: + doc_resps = [] + for prompt in prompts_doc: + from mistralai.client import MistralClient + from mistralai.models.chat_completion import ChatMessage + + api_key = self._credentials.get("api-key") + client = MistralClient(endpoint=self._endpoint, api_key=api_key) + + chat_response = client.chat( + model=self._name, + messages=[ + ChatMessage( + role="user", + content=prompt, + ) + ], + ) + doc_resps.append(chat_response.choices[0].message.content) + all_resps.append(doc_resps) + return all_resps + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + "azureai": 8192, + } + + def _verify_auth(self) -> None: + try: + self([["test"]]) + except ValueError as err: + raise err + + @property + def credentials(self) -> Dict[str, str]: + # Fetch and check the key + api_key = os.getenv("MISTRAL_API_KEY") + if api_key is None: + warnings.warn( + "Could not find the API key to access the Mistral AI API. Ensure you have an API key " + "set up (see " + ", then make it available as an environment variable 'MISTRAL_API_KEY'." + ) + + # Check the access and get a list of available models to verify the model argument (if not None) + # Even if the model is None, this call is used as a healthcheck to verify access. + assert api_key is not None + return {"api-key": api_key} diff --git a/spacy_llm/models/rest/mistral/registry.py b/spacy_llm/models/rest/mistral/registry.py new file mode 100644 index 00000000..8408a6ed --- /dev/null +++ b/spacy_llm/models/rest/mistral/registry.py @@ -0,0 +1,28 @@ +from ....registry import registry +from typing import Optional, Dict, Any +from .model import AzureMistral +from confection import SimpleFrozenDict + + +@registry.llm_models("AzureMistral.v1") +def azure_mistral( + name, + endpoint, + config: Dict[Any, Any] = SimpleFrozenDict(temperature=0.0), + strict: bool = AzureMistral.DEFAULT_STRICT, + max_tries: int = AzureMistral.DEFAULT_MAX_TRIES, + interval: float = AzureMistral.DEFAULT_INTERVAL, + max_request_time: float = AzureMistral.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +): + + return AzureMistral( + name=name, + endpoint=endpoint, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) From 6ed4ca4032a4a1dfc7e19882bd79c3f7d7b3ed37 Mon Sep 17 00:00:00 2001 From: Shubham Sureka Date: Fri, 29 Mar 2024 13:24:53 +0530 Subject: [PATCH 2/2] code-refactor --- spacy_llm/models/rest/mistral/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spacy_llm/models/rest/mistral/model.py b/spacy_llm/models/rest/mistral/model.py index 3bef2bab..83918171 100644 --- a/spacy_llm/models/rest/mistral/model.py +++ b/spacy_llm/models/rest/mistral/model.py @@ -3,6 +3,9 @@ from typing import Iterable, Optional, Any, Dict from ..base import REST +from mistralai.client import MistralClient +from mistralai.models.chat_completion import ChatMessage + class AzureMistral(REST): def __init__( @@ -29,13 +32,10 @@ def __init__( def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: all_resps = [] + api_key = self._credentials.get("api-key") for prompts_doc in prompts: doc_resps = [] for prompt in prompts_doc: - from mistralai.client import MistralClient - from mistralai.models.chat_completion import ChatMessage - - api_key = self._credentials.get("api-key") client = MistralClient(endpoint=self._endpoint, api_key=api_key) chat_response = client.chat(