Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mistral model api support #456

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions spacy_llm/models/rest/mistral/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .model import AzureMistral
from .registry import azure_mistral

__all__ = ["AzureMistral", "azure_mistral"]
80 changes: 80 additions & 0 deletions spacy_llm/models/rest/mistral/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import warnings
import os
from typing import Iterable, Optional, Any, Dict
from ..base import REST

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage


class AzureMistral(REST):
def __init__(
self,
name: str,
endpoint: str,
config: Dict[Any, Any],
strict: bool,
max_tries: int,
interval: float,
max_request_time: float,
context_length: Optional[int],
):
super().__init__(
name,
endpoint,
config,
strict,
max_tries,
interval,
max_request_time,
context_length,
)

def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:
all_resps = []
api_key = self._credentials.get("api-key")
for prompts_doc in prompts:
doc_resps = []
for prompt in prompts_doc:
client = MistralClient(endpoint=self._endpoint, api_key=api_key)

chat_response = client.chat(
model=self._name,
messages=[
ChatMessage(
role="user",
content=prompt,
)
],
)
doc_resps.append(chat_response.choices[0].message.content)
all_resps.append(doc_resps)
return all_resps

@staticmethod
def _get_context_lengths() -> Dict[str, int]:
return {
"azureai": 8192,
}

def _verify_auth(self) -> None:
try:
self([["test"]])
except ValueError as err:
raise err

@property
def credentials(self) -> Dict[str, str]:
# Fetch and check the key
api_key = os.getenv("MISTRAL_API_KEY")
if api_key is None:
warnings.warn(
"Could not find the API key to access the Mistral AI API. Ensure you have an API key "
"set up (see "
", then make it available as an environment variable 'MISTRAL_API_KEY'."
)

# Check the access and get a list of available models to verify the model argument (if not None)
# Even if the model is None, this call is used as a healthcheck to verify access.
assert api_key is not None
return {"api-key": api_key}
28 changes: 28 additions & 0 deletions spacy_llm/models/rest/mistral/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from ....registry import registry
from typing import Optional, Dict, Any
from .model import AzureMistral
from confection import SimpleFrozenDict


@registry.llm_models("AzureMistral.v1")
def azure_mistral(
name,
endpoint,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0.0),
strict: bool = AzureMistral.DEFAULT_STRICT,
max_tries: int = AzureMistral.DEFAULT_MAX_TRIES,
interval: float = AzureMistral.DEFAULT_INTERVAL,
max_request_time: float = AzureMistral.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
):

return AzureMistral(
name=name,
endpoint=endpoint,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)