diff --git a/dsp/modules/unify.py b/dsp/modules/unify.py index 6963cd7c4..760efedbc 100644 --- a/dsp/modules/unify.py +++ b/dsp/modules/unify.py @@ -12,8 +12,6 @@ class Unify(LM): def __init__( self, endpoint="router@q:1|c:4.65e-03|t:2.08e-05|i:2.07e-03", - # model: Optional[str] = None, - # provider: Optional[str] = None, model_type: Literal["chat", "text"] = "chat", stream: Optional[bool] = False, base_url="https://api.unify.ai/v0", @@ -22,9 +20,10 @@ def __init__( api_key=None, **kwargs, ): + """ + Initializes the Unify client with the specified parameters. + """ self.api_key = api_key - # self.model = model - # self.provider = provider self.endpoint = endpoint self.stream = stream self.client = UnifyClient(api_key=self.api_key, endpoint=self.endpoint) @@ -46,7 +45,11 @@ def __init__( self.history: list[dict[str, Any]] = [] def basic_request(self, prompt: str, **kwargs) -> Any: - """Basic request to the Unify's API.""" + """ + Sends a basic request to the Unify API. + Returns: + Any: The response from the API. + """ kwargs = {**self.kwargs, **kwargs} settings_dict = { "endpoint": self.endpoint, @@ -80,14 +83,19 @@ def basic_request(self, prompt: str, **kwargs) -> Any: response = {"choices": [{"message": {"content": response}}]} # response with choices if not response: - logging.error("Unexpected response format, not response") + logging.error("Unexpected response format, no response") elif "choices" not in response: logging.error(f"no choices in response: {response}") return response def request(self, prompt: str, **kwargs) -> Any: - """Handles retreival of model completions whilst handling rate limiting and caching.""" + """ + Handles retrieval of model completions while managing rate limiting and caching. + + Returns: + Any: The response from the API. + """ if "model_type" in kwargs: del kwargs["model_type"] return self.basic_request(prompt, **kwargs) @@ -99,7 +107,12 @@ def __call__( return_sorted: bool = False, **kwargs, ) -> list[dict[str, Any]]: - """Request completions from the Unify API.""" + """ + Requests completions from the Unify API. + + Returns: + list[dict[str, Any]]: A list of completions from the API. + """ assert only_completed, "for now" assert return_sorted is False, "for now" @@ -115,3 +128,4 @@ def __call__( raise ValueError("Unexpected response format") return completions +