Skip to content

Commit

Permalink
feat: introduced Annotation/RequestDict type aliases; migrated from C…
Browse files Browse the repository at this point in the history
…oroutine to Awaitable type
  • Loading branch information
adubovik committed Nov 26, 2024
1 parent cb64f48 commit cb3b7b6
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 38 deletions.
8 changes: 5 additions & 3 deletions aidial_interceptors_sdk/chat_completion/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
Annotation,
)
from aidial_interceptors_sdk.chat_completion.base import (
ChatCompletionInterceptor,
RequestDict,
)
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.error import EarlyStreamExit
Expand Down Expand Up @@ -70,7 +72,7 @@ async def chat_completion(
try:
await interceptor.on_stream_start()

def call_upstream(context: Any | None, request: dict):
def call_upstream(context: Annotation, request: dict):
return call_single_upstream(dial_client, context, request)

async for value in await interceptor.call_upstreams(
Expand All @@ -89,7 +91,7 @@ def call_upstream(context: Any | None, request: dict):


async def call_single_upstream(
dial_client: DialClient, context: Any | None, request: dict
dial_client: DialClient, context: Annotation, request: RequestDict
) -> AsyncIterator[dict | DialException]:
response = cast(
AsyncStream[ChatCompletionChunk] | ChatCompletion,
Expand All @@ -103,7 +105,7 @@ async def call_single_upstream(
if _debug():
_log.debug(f"upstream response[{context}]: {json.dumps(resp)}")

# Block mode:
# Non-streaming mode:
# Removing the default fields which are generated by
# DIAL SDK automatically.
# It also means that these fields aren't proxied from the upstream.
Expand Down
4 changes: 3 additions & 1 deletion aidial_interceptors_sdk/chat_completion/annotated_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.pydantic_v1 import BaseModel

Annotation = Any | None


class AnnotatedValueBase(BaseModel, ABC):
class Config:
arbitrary_types_allowed = True

annotation: Any | None = None
annotation: Annotation = None


class AnnotatedChunk(AnnotatedValueBase):
Expand Down
11 changes: 7 additions & 4 deletions aidial_interceptors_sdk/chat_completion/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, AsyncIterator, Callable, Coroutine
from typing import AsyncIterator, Awaitable, Callable

from aidial_sdk.exceptions import HTTPException as DialException

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
AnnotatedValue,
Annotation,
)
from aidial_interceptors_sdk.chat_completion.request_handler import (
RequestHandler,
Expand All @@ -15,16 +16,18 @@
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.utils.streaming import annotate_stream

RequestDict = dict


class ChatCompletionInterceptor(RequestHandler, ResponseHandler):
dial_client: DialClient

async def call_upstreams(
self,
request: dict,
request: RequestDict,
call_upstream: Callable[
[Any | None, dict],
Coroutine[Any, Any, AsyncIterator[dict | DialException]],
[Annotation, RequestDict],
Awaitable[AsyncIterator[dict | DialException]],
],
) -> AsyncIterator[AnnotatedValue]:
annotation = None
Expand Down
26 changes: 13 additions & 13 deletions aidial_interceptors_sdk/chat_completion/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Coroutine, List, TypeVar, overload
from typing import Awaitable, Callable, List, TypeVar, overload

from aidial_interceptors_sdk.utils.not_given import NOT_GIVEN, NotGiven

Expand All @@ -13,7 +13,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> dict: ...

Expand All @@ -25,7 +25,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> NotGiven: ...

Expand All @@ -37,7 +37,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> None: ...

Expand All @@ -48,7 +48,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> dict | NotGiven | None:
if d is None or isinstance(d, NotGiven):
Expand All @@ -71,7 +71,7 @@ async def traverse_required_dict_value(
path: P,
d: None,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> None: ...


Expand All @@ -80,7 +80,7 @@ async def traverse_required_dict_value(
path: P,
d: NotGiven,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> NotGiven: ...


Expand All @@ -89,15 +89,15 @@ async def traverse_required_dict_value(
path: P,
d: dict,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> dict: ...


async def traverse_required_dict_value(
path: P,
d: dict | NotGiven | None,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> dict | NotGiven | None:
if d is None or isinstance(d, NotGiven):
return d
Expand All @@ -115,30 +115,30 @@ async def traverse_required_dict_value(
async def traverse_list(
create_elem_path: Callable[[int], P],
lst: NotGiven,
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> NotGiven: ...


@overload
async def traverse_list(
create_elem_path: Callable[[int], P],
lst: None,
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> None: ...


@overload
async def traverse_list(
create_elem_path: Callable[[int], P],
lst: List[T],
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> List[T]: ...


async def traverse_list(
create_elem_path: Callable[[int], P],
lst: List[T] | NotGiven | None,
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> List[T] | NotGiven | None:
if lst is None or isinstance(lst, NotGiven):
return lst
Expand Down
11 changes: 6 additions & 5 deletions aidial_interceptors_sdk/examples/chat_completion/replicator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
List,
Tuple,
Expand All @@ -22,9 +21,11 @@

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedValue,
Annotation,
)
from aidial_interceptors_sdk.chat_completion.base import (
ChatCompletionInterceptor,
RequestDict,
)
from aidial_interceptors_sdk.chat_completion.element_path import ElementPath
from aidial_interceptors_sdk.chat_completion.index_mapper import IndexMapper
Expand Down Expand Up @@ -74,10 +75,10 @@ def _get_content_stage(self, path: ElementPath) -> Stage:
@override
async def call_upstreams(
self,
request: dict,
request: RequestDict,
call_upstream: Callable[
[Any | None, dict],
Coroutine[Any, Any, AsyncIterator[dict | DialException]],
[Annotation, RequestDict],
Awaitable[AsyncIterator[dict | DialException]],
],
) -> AsyncIterator[AnnotatedValue]:
request["n"] = 1
Expand Down
16 changes: 8 additions & 8 deletions aidial_interceptors_sdk/utils/_debug.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Callable, Coroutine, TypeVar
from typing import Awaitable, Callable, TypeVar

_log = logging.getLogger(__name__)

Expand All @@ -9,23 +9,23 @@ def _debug():
return _log.isEnabledFor(logging.DEBUG)


A = TypeVar("A")
B = TypeVar("B")
_A = TypeVar("_A")
_B = TypeVar("_B")


def debug_logging(
title: str,
) -> Callable[
[Callable[[A], Coroutine[None, None, B]]],
Callable[[A], Coroutine[None, None, B]],
[Callable[[_A], Awaitable[_B]]],
Callable[[_A], Awaitable[_B]],
]:
def decorator(
fn: Callable[[A], Coroutine[None, None, B]]
) -> Callable[[A], Coroutine[None, None, B]]:
fn: Callable[[_A], Awaitable[_B]]
) -> Callable[[_A], Awaitable[_B]]:
if not _debug():
return fn

async def _fn(a: A) -> B:
async def _fn(a: _A) -> _B:
_log.debug(f"{title} old: {json.dumps(a)}")
b = await fn(a)
_log.debug(f"{title} new: {json.dumps(b)}")
Expand Down
6 changes: 3 additions & 3 deletions aidial_interceptors_sdk/utils/_reflection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import inspect
from typing import Any, Callable, Coroutine, TypeVar
from typing import Any, Awaitable, Callable, TypeVar

from aidial_sdk.exceptions import InvalidRequestError

T = TypeVar("T")


async def call_with_extra_body(
func: Callable[..., Coroutine[Any, Any, T]], arg: dict
func: Callable[..., Awaitable[T]], arg: dict
) -> T:
if _has_kwargs_argument(func):
return await func(**arg)
Expand All @@ -32,7 +32,7 @@ async def call_with_extra_body(
return await func(**arg)


def _has_kwargs_argument(func: Callable[..., Coroutine[Any, Any, Any]]) -> bool:
def _has_kwargs_argument(func: Callable[..., Awaitable[Any]]) -> bool:
"""
Determines if the given function accepts a variable keyword argument (**kwargs).
"""
Expand Down
3 changes: 2 additions & 1 deletion aidial_interceptors_sdk/utils/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AnnotatedChunk,
AnnotatedException,
AnnotatedValue,
Annotation,
)
from aidial_interceptors_sdk.utils._exceptions import to_dial_exception

Expand All @@ -32,7 +33,7 @@ async def materialize_streaming_errors(


def annotate_stream(
annotation: Any | None, stream: AsyncIterator[dict | DialException]
annotation: Annotation, stream: AsyncIterator[dict | DialException]
) -> AsyncIterator[AnnotatedValue]:
def _annotate(value: dict | DialException) -> AnnotatedValue:
if isinstance(value, dict):
Expand Down

0 comments on commit cb3b7b6

Please sign in to comment.