feat(dsp): reformatted the integration #17

Closed
dsp/modules/lm.py (2 additions, 1 deletion)

@@ -100,9 +100,10 @@ def inspect_history(self, n: int = 1, skip: int = 0):
                 "premai",
                 "you.com",
                 "tensorrt_llm",
+                "unify",
             ):
                 text = choices
-            elif provider == "openai" or provider == "ollama" or provider == "unify":
+            elif provider == "openai" or provider == "ollama":
                 text = " " + self._get_choice_text(choices[0]).strip()
             elif provider == "groq":
                 text = " " + choices
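Moving `unify` into the first group changes how `inspect_history` renders a logged Unify completion: it is used as-is rather than passed through `_get_choice_text`. A minimal sketch of what the two branches do; the entry shapes below are stand-ins for illustration, not dsp internals:

```python
# First group ("premai", "unify", ...): the completion is kept as a raw string.
choices = "The answer is 42."
text = choices

# "openai"/"ollama": completions are OpenAI-style choice objects instead.
choices = [{"message": {"content": "The answer is 42."}}]
text = " " + choices[0]["message"]["content"].strip()
```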
dsp/modules/unify.py (48 additions, 59 deletions)

@@ -1,38 +1,32 @@
 import logging
-from typing import Any, Literal, Optional
+from typing import Any, Optional

 from unify.clients import Unify as UnifyClient

 from dsp.modules.lm import LM


-class Unify(LM):
+class Unify(LM, UnifyClient):
     """A class to interact with the Unify AI API."""

     def __init__(
         self,
-        endpoint="router@q:1|c:4.65e-03|t:2.08e-05|i:2.07e-03",
-        # model: Optional[str] = None,
-        # provider: Optional[str] = None,
-        model_type: Literal["chat", "text"] = "chat",
+        endpoint: str = "router@q:1|c:4.65e-03|t:2.08e-05|i:2.07e-03",
+        model: Optional[str] = None,
+        provider: Optional[str] = None,
+        api_key=None,
         stream: Optional[bool] = False,
-        base_url="https://api.unify.ai/v0",
         system_prompt: Optional[str] = None,
+        base_url: str = "https://api.unify.ai/v0",
         n: int = 1,
-        api_key=None,
         **kwargs,
     ):
-        self.api_key = api_key
-        # self.model = model
-        # self.provider = provider
         self.endpoint = endpoint
         self.base_url = base_url
         self.stream = stream
-        self.client = UnifyClient(api_key=self.api_key, endpoint=self.endpoint)
-
-        super().__init__(model=self.endpoint)
-
+        LM.__init__(self, model)
+        UnifyClient.__init__(self, endpoint=endpoint, model=model, provider=provider, api_key=api_key)
+        # super().__init__(model)
         self.system_prompt = system_prompt
-        self.model_type = model_type
         self.kwargs = {
             "temperature": 0.0,
             "max_tokens": 150,
@@ -44,6 +38,24 @@ def __init__(
         }
         self.kwargs["endpoint"] = endpoint
         self.history: list[dict[str, Any]] = []
+        self._dspy_integration_provider = "unify"
+
+    @property
+    def provider(self) -> Optional[str]:
+        return self._dspy_integration_provider
+
+    @provider.setter
+    def provider(self, value: str) -> None:
+        self._dspy_integration_provider = value
+
+    @property
+    def model_provider(self) -> Optional[str]:
+        return UnifyClient.provider(self)
+
+    @model_provider.setter
+    def model_provider(self, value: str) -> None:
+        if value != "default":
+            self.set_provider(value)

     def basic_request(self, prompt: str, **kwargs) -> Any:
         """Basic request to the Unify's API."""
@@ -52,66 +64,43 @@ def basic_request(self, prompt: str, **kwargs) -> Any:
             "endpoint": self.endpoint,
             "stream": self.stream,
         }
-        if self.model_type == "chat":
-            messages = [{"role": "user", "content": prompt}]
-            settings_dict["messages"] = messages
-            if self.system_prompt:
-                settings_dict["messages"].insert(0, {"role": "system", "content": self.system_prompt})
-        else:
-            settings_dict["prompt"] = prompt
+        messages = [{"role": "user", "content": prompt}]
+        settings_dict["messages"] = messages
+        if self.system_prompt:
+            settings_dict["messages"].insert(0, {"role": "system", "content": self.system_prompt})

         logging.debug(f"Settings Dict: {settings_dict}")

-        if "messages" in settings_dict:
-            response = self.client.generate(
-                messages=settings_dict["messages"],
-                stream=settings_dict["stream"],
-                temperature=kwargs["temperature"],
-                max_tokens=kwargs["max_tokens"],
-            )
-        else:
-            response = self.client.generate(
-                user_prompt=settings_dict["prompt"],
-                stream=settings_dict["stream"],
-                temperature=kwargs["temperature"],
-                max_tokens=kwargs["max_tokens"],
-            )
+        response = self.generate(
+            messages=settings_dict["messages"],
+            stream=settings_dict["stream"],
+            temperature=kwargs["temperature"],
+            max_tokens=kwargs["max_tokens"],
+        )

         response = {"choices": [{"message": {"content": response}}]}  # response with choices

-        if not response:
-            logging.error("Unexpected response format, not response")
-        elif "choices" not in response:
-            logging.error(f"no choices in response: {response}")
-
         if isinstance(response, dict) and "choices" in response:
             self.history.append({"prompt": prompt, "response": response})
         else:
             raise ValueError("Unexpected response format")
         return response

     def request(self, prompt: str, **kwargs) -> Any:
         """Handles retrieval of model completions whilst handling rate limiting and caching."""
-        if "model_type" in kwargs:
-            del kwargs["model_type"]
         return self.basic_request(prompt, **kwargs)

     def __call__(
         self,
-        prompt: str,
+        prompt: Optional[str],
         only_completed: bool = True,
         return_sorted: bool = False,
         **kwargs,
     ) -> list[dict[str, Any]]:
         """Request completions from the Unify API."""
         assert only_completed, "for now"
         assert return_sorted is False, "for now"

-        n = kwargs.pop("n", 1)
-        completions = []
-
-        for _ in range(n):
-            response = self.request(prompt, **kwargs)
-
-            if isinstance(response, dict) and "choices" in response:
-                completions.append(response["choices"][0]["message"]["content"])
-            else:
-                raise ValueError("Unexpected response format")
-
-        return completions
+        n: int = kwargs.get("n") or 1
+        skip: int = kwargs.get("skip") or 0
+        self.request(prompt, **kwargs)
+        return self.inspect_history(n=n, skip=skip)
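Taken together, the reworked call path is `__call__` → `request` → `basic_request` → `UnifyClient.generate`; each completion is wrapped in an OpenAI-style `choices` dict, appended to `self.history`, and read back through `inspect_history(n=n, skip=skip)`. A hedged usage sketch (the endpoint string comes from `mwe_unify.py` below; the exact return shape depends on dsp's `inspect_history`):

```python
import os

import dsp

# Assumes a UNIFY_KEY variable in the environment, as in mwe_unify.py below.
lm = dsp.Unify(
    endpoint="gpt-3.5-turbo@openai",
    api_key=os.environ["UNIFY_KEY"],
)

# Issues one request, logs it to lm.history, and returns the logged
# completion via inspect_history.
result = lm("What is 7 times 6?")
```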
mwe_unify.py (11 additions, 3 deletions)

@@ -1,14 +1,22 @@
 import logging
+import os
+from dotenv import load_dotenv

 import dsp
 import dspy
 from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
 from dspy.evaluate import Evaluate
 from dspy.teleprompt import BootstrapFewShot

-endpoint = dsp.Unify(
+
+load_dotenv()
+unify_api_key = os.getenv("UNIFY_KEY")
+
+lm = dsp.Unify(
     endpoint="gpt-3.5-turbo@openai",
     max_tokens=150,
-    api_key="QOZDhc54GhdcuUGXkPrDrjxoaySXOOPvq38rUUa+Mpk=",
-    model_type="text",
+    api_key=unify_api_key,
 )

 dspy.settings.configure(lm=endpoint)
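The hardcoded key is replaced by a `UNIFY_KEY` variable loaded with `python-dotenv`, so the example expects a local `.env` file containing a line of the form `UNIFY_KEY=<your key>`. Note that `dspy.settings.configure(lm=endpoint)` still references the old variable name; with the rename above it would need to pass `lm=lm`.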