Skip to content

Commit

Permalink
Add wait and fixes types (#410)
Browse files Browse the repository at this point in the history
* Add wait and fixes types

* Fix types

* Fix imports

* Fix type
  • Loading branch information
billytrend-cohere authored Mar 19, 2024
1 parent 77e0087 commit d476799
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 58 deletions.
11 changes: 9 additions & 2 deletions src/cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from .base_client import BaseCohere, AsyncBaseCohere
from .environment import ClientEnvironment
from .utils import wait, async_wait


# Use NoReturn as Never type for compatibility
Never = typing.NoReturn
Expand All @@ -25,6 +27,7 @@ def throw_if_stream_is_true(*args, **kwargs) -> None:
"Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
)


def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
"""
This method is moved. Please update usage.
Expand Down Expand Up @@ -56,7 +59,7 @@ def fn(*args, **kwargs):
class Client(BaseCohere):
def __init__(
self,
api_key: typing.Union[str, typing.Callable[[], str]],
api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
*,
base_url: typing.Optional[str] = None,
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
Expand All @@ -76,6 +79,8 @@ def __init__(

validate_args(self, "chat", throw_if_stream_is_true)

wait = wait

"""
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
Expand Down Expand Up @@ -125,7 +130,7 @@ def __init__(
class AsyncClient(AsyncBaseCohere):
def __init__(
self,
api_key: typing.Union[str, typing.Callable[[], str]],
api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
*,
base_url: typing.Optional[str] = None,
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
Expand All @@ -145,6 +150,8 @@ def __init__(

validate_args(self, "chat", throw_if_stream_is_true)

wait = async_wait

"""
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
Expand Down
152 changes: 152 additions & 0 deletions src/cohere/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import asyncio
import time
import typing
from typing import Optional

from .types import EmbedJob, CreateEmbedJobResponse
from .datasets import DatasetsCreateResponse, DatasetsGetResponse


def get_terminal_states():
return get_success_states() | get_failed_states()


def get_success_states():
return {"complete", "validated"}


def get_failed_states():
return {"unknown", "failed", "skipped", "cancelled", "failed"}


def get_id(
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]):
return getattr(awaitable, "job_id", None) or getattr(awaitable, "id", None) or getattr(
getattr(awaitable, "dataset", None), "id", None)


def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]):
return getattr(awaitable, "status", None) or getattr(getattr(awaitable, "dataset", None), "validation_status", None)


def get_job(cohere: typing.Any,
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \
typing.Union[
EmbedJob, DatasetsGetResponse]:
if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse":
return cohere.embed_jobs.get(id=get_id(awaitable))
elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse":
return cohere.datasets.get(id=get_id(awaitable))
else:
raise ValueError(f"Unexpected awaitable type {awaitable}")


async def async_get_job(cohere: typing.Any, awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse]) -> \
typing.Union[
EmbedJob, DatasetsGetResponse]:
if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse":
return await cohere.embed_jobs.get(id=get_id(awaitable))
elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse":
return await cohere.datasets.get(id=get_id(awaitable))
else:
raise ValueError(f"Unexpected awaitable type {awaitable}")


def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]:
if isinstance(job, EmbedJob):
return f"Embed job {job.job_id} failed with status {job.status}"
elif isinstance(job, DatasetsGetResponse):
return f"Dataset creation {job.dataset.validation_status} failed with status {job.dataset.validation_status}"
return None


@typing.overload
def wait(
cohere: typing.Any,
awaitable: CreateEmbedJobResponse,
timeout: Optional[float] = None,
interval: float = 10,
) -> EmbedJob:
...


@typing.overload
def wait(
cohere: typing.Any,
awaitable: DatasetsCreateResponse,
timeout: Optional[float] = None,
interval: float = 10,
) -> DatasetsGetResponse:
...


def wait(
cohere: typing.Any,
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
timeout: Optional[float] = None,
interval: float = 2,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
start_time = time.time()
terminal_states = get_terminal_states()
failed_states = get_failed_states()

job = get_job(cohere, awaitable)
while get_validation_status(job) not in terminal_states:
if timeout is not None and time.time() - start_time > timeout:
raise TimeoutError(f"wait timed out after {timeout} seconds")

time.sleep(interval)
print("...")

job = get_job(cohere, awaitable)

if get_validation_status(job) in failed_states:
raise Exception(get_failure_reason(job))

return job


@typing.overload
async def async_wait(
cohere: typing.Any,
awaitable: CreateEmbedJobResponse,
timeout: Optional[float] = None,
interval: float = 10,
) -> EmbedJob:
...


@typing.overload
async def async_wait(
cohere: typing.Any,
awaitable: DatasetsCreateResponse,
timeout: Optional[float] = None,
interval: float = 10,
) -> DatasetsGetResponse:
...


async def async_wait(
cohere: typing.Any,
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
timeout: Optional[float] = None,
interval: float = 10,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
start_time = time.time()
terminal_states = get_terminal_states()
failed_states = get_failed_states()

job = await async_get_job(cohere, awaitable)
while get_validation_status(job) not in terminal_states:
if timeout is not None and time.time() - start_time > timeout:
raise TimeoutError(f"wait timed out after {timeout} seconds")

await asyncio.sleep(interval)
print("...")

job = await async_get_job(cohere, awaitable)

if get_validation_status(job) in failed_states:
raise Exception(get_failure_reason(job))

return job
Loading

0 comments on commit d476799

Please sign in to comment.