From cb3b7b6c7594a4b75924732146f59630bbfe91c3 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Tue, 26 Nov 2024 16:10:52 +0000 Subject: [PATCH] feat: introduced Annotation/RequestDict type aliases; migrated from Coroutine to Awaitable type --- .../chat_completion/adapter.py | 8 +++--- .../chat_completion/annotated_value.py | 4 ++- .../chat_completion/base.py | 11 +++++--- .../chat_completion/helpers.py | 26 +++++++++---------- .../examples/chat_completion/replicator.py | 11 ++++---- aidial_interceptors_sdk/utils/_debug.py | 16 ++++++------ aidial_interceptors_sdk/utils/_reflection.py | 6 ++--- aidial_interceptors_sdk/utils/streaming.py | 3 ++- 8 files changed, 47 insertions(+), 38 deletions(-) diff --git a/aidial_interceptors_sdk/chat_completion/adapter.py b/aidial_interceptors_sdk/chat_completion/adapter.py index 6d8c451..e95b6e7 100644 --- a/aidial_interceptors_sdk/chat_completion/adapter.py +++ b/aidial_interceptors_sdk/chat_completion/adapter.py @@ -13,9 +13,11 @@ from aidial_interceptors_sdk.chat_completion.annotated_value import ( AnnotatedException, + Annotation, ) from aidial_interceptors_sdk.chat_completion.base import ( ChatCompletionInterceptor, + RequestDict, ) from aidial_interceptors_sdk.dial_client import DialClient from aidial_interceptors_sdk.error import EarlyStreamExit @@ -70,7 +72,7 @@ async def chat_completion( try: await interceptor.on_stream_start() - def call_upstream(context: Any | None, request: dict): + def call_upstream(context: Annotation, request: dict): return call_single_upstream(dial_client, context, request) async for value in await interceptor.call_upstreams( @@ -89,7 +91,7 @@ def call_upstream(context: Any | None, request: dict): async def call_single_upstream( - dial_client: DialClient, context: Any | None, request: dict + dial_client: DialClient, context: Annotation, request: RequestDict ) -> AsyncIterator[dict | DialException]: response = cast( AsyncStream[ChatCompletionChunk] | ChatCompletion, @@ -103,7 +105,7 @@ async def call_single_upstream( if _debug(): _log.debug(f"upstream response[{context}]: {json.dumps(resp)}") - # Block mode: + # Non-streaming mode: # Removing the default fields which are generated by # DIAL SDK automatically. # It also means that these fields aren't proxied from the upstream. diff --git a/aidial_interceptors_sdk/chat_completion/annotated_value.py b/aidial_interceptors_sdk/chat_completion/annotated_value.py index 91c0280..dc2ff82 100644 --- a/aidial_interceptors_sdk/chat_completion/annotated_value.py +++ b/aidial_interceptors_sdk/chat_completion/annotated_value.py @@ -4,12 +4,14 @@ from aidial_sdk.exceptions import HTTPException as DialException from aidial_sdk.pydantic_v1 import BaseModel +Annotation = Any | None + class AnnotatedValueBase(BaseModel, ABC): class Config: arbitrary_types_allowed = True - annotation: Any | None = None + annotation: Annotation = None class AnnotatedChunk(AnnotatedValueBase): diff --git a/aidial_interceptors_sdk/chat_completion/base.py b/aidial_interceptors_sdk/chat_completion/base.py index 464be54..d1c0a5c 100644 --- a/aidial_interceptors_sdk/chat_completion/base.py +++ b/aidial_interceptors_sdk/chat_completion/base.py @@ -1,10 +1,11 @@ -from typing import Any, AsyncIterator, Callable, Coroutine +from typing import AsyncIterator, Awaitable, Callable from aidial_sdk.exceptions import HTTPException as DialException from aidial_interceptors_sdk.chat_completion.annotated_value import ( AnnotatedException, AnnotatedValue, + Annotation, ) from aidial_interceptors_sdk.chat_completion.request_handler import ( RequestHandler, @@ -15,16 +16,18 @@ from aidial_interceptors_sdk.dial_client import DialClient from aidial_interceptors_sdk.utils.streaming import annotate_stream +RequestDict = dict + class ChatCompletionInterceptor(RequestHandler, ResponseHandler): dial_client: DialClient async def call_upstreams( self, - request: dict, + request: RequestDict, call_upstream: Callable[ - [Any | None, dict], - Coroutine[Any, Any, AsyncIterator[dict | DialException]], + [Annotation, RequestDict], + Awaitable[AsyncIterator[dict | DialException]], ], ) -> AsyncIterator[AnnotatedValue]: annotation = None diff --git a/aidial_interceptors_sdk/chat_completion/helpers.py b/aidial_interceptors_sdk/chat_completion/helpers.py index ecb85eb..54cd99a 100644 --- a/aidial_interceptors_sdk/chat_completion/helpers.py +++ b/aidial_interceptors_sdk/chat_completion/helpers.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Coroutine, List, TypeVar, overload +from typing import Awaitable, Callable, List, TypeVar, overload from aidial_interceptors_sdk.utils.not_given import NOT_GIVEN, NotGiven @@ -13,7 +13,7 @@ async def traverse_dict_value( key: str, on_value: Callable[ [P, T | NotGiven | None], - Coroutine[Any, Any, T | NotGiven | None], + Awaitable[T | NotGiven | None], ], ) -> dict: ... @@ -25,7 +25,7 @@ async def traverse_dict_value( key: str, on_value: Callable[ [P, T | NotGiven | None], - Coroutine[Any, Any, T | NotGiven | None], + Awaitable[T | NotGiven | None], ], ) -> NotGiven: ... @@ -37,7 +37,7 @@ async def traverse_dict_value( key: str, on_value: Callable[ [P, T | NotGiven | None], - Coroutine[Any, Any, T | NotGiven | None], + Awaitable[T | NotGiven | None], ], ) -> None: ... @@ -48,7 +48,7 @@ async def traverse_dict_value( key: str, on_value: Callable[ [P, T | NotGiven | None], - Coroutine[Any, Any, T | NotGiven | None], + Awaitable[T | NotGiven | None], ], ) -> dict | NotGiven | None: if d is None or isinstance(d, NotGiven): @@ -71,7 +71,7 @@ async def traverse_required_dict_value( path: P, d: None, key: str, - on_value: Callable[[P, T], Coroutine[Any, Any, T]], + on_value: Callable[[P, T], Awaitable[T]], ) -> None: ... @@ -80,7 +80,7 @@ async def traverse_required_dict_value( path: P, d: NotGiven, key: str, - on_value: Callable[[P, T], Coroutine[Any, Any, T]], + on_value: Callable[[P, T], Awaitable[T]], ) -> NotGiven: ... @@ -89,7 +89,7 @@ async def traverse_required_dict_value( path: P, d: dict, key: str, - on_value: Callable[[P, T], Coroutine[Any, Any, T]], + on_value: Callable[[P, T], Awaitable[T]], ) -> dict: ... @@ -97,7 +97,7 @@ async def traverse_required_dict_value( path: P, d: dict | NotGiven | None, key: str, - on_value: Callable[[P, T], Coroutine[Any, Any, T]], + on_value: Callable[[P, T], Awaitable[T]], ) -> dict | NotGiven | None: if d is None or isinstance(d, NotGiven): return d @@ -115,7 +115,7 @@ async def traverse_required_dict_value( async def traverse_list( create_elem_path: Callable[[int], P], lst: NotGiven, - on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]], + on_elem: Callable[[P, T], Awaitable[List[T] | T]], ) -> NotGiven: ... @@ -123,7 +123,7 @@ async def traverse_list( async def traverse_list( create_elem_path: Callable[[int], P], lst: None, - on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]], + on_elem: Callable[[P, T], Awaitable[List[T] | T]], ) -> None: ... @@ -131,14 +131,14 @@ async def traverse_list( async def traverse_list( create_elem_path: Callable[[int], P], lst: List[T], - on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]], + on_elem: Callable[[P, T], Awaitable[List[T] | T]], ) -> List[T]: ... async def traverse_list( create_elem_path: Callable[[int], P], lst: List[T] | NotGiven | None, - on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]], + on_elem: Callable[[P, T], Awaitable[List[T] | T]], ) -> List[T] | NotGiven | None: if lst is None or isinstance(lst, NotGiven): return lst diff --git a/aidial_interceptors_sdk/examples/chat_completion/replicator.py b/aidial_interceptors_sdk/examples/chat_completion/replicator.py index ef57c46..84a48a4 100644 --- a/aidial_interceptors_sdk/examples/chat_completion/replicator.py +++ b/aidial_interceptors_sdk/examples/chat_completion/replicator.py @@ -1,8 +1,7 @@ from typing import ( - Any, AsyncIterator, + Awaitable, Callable, - Coroutine, Dict, List, Tuple, @@ -22,9 +21,11 @@ from aidial_interceptors_sdk.chat_completion.annotated_value import ( AnnotatedValue, + Annotation, ) from aidial_interceptors_sdk.chat_completion.base import ( ChatCompletionInterceptor, + RequestDict, ) from aidial_interceptors_sdk.chat_completion.element_path import ElementPath from aidial_interceptors_sdk.chat_completion.index_mapper import IndexMapper @@ -74,10 +75,10 @@ def _get_content_stage(self, path: ElementPath) -> Stage: @override async def call_upstreams( self, - request: dict, + request: RequestDict, call_upstream: Callable[ - [Any | None, dict], - Coroutine[Any, Any, AsyncIterator[dict | DialException]], + [Annotation, RequestDict], + Awaitable[AsyncIterator[dict | DialException]], ], ) -> AsyncIterator[AnnotatedValue]: request["n"] = 1 diff --git a/aidial_interceptors_sdk/utils/_debug.py b/aidial_interceptors_sdk/utils/_debug.py index a36f4a2..997e18e 100644 --- a/aidial_interceptors_sdk/utils/_debug.py +++ b/aidial_interceptors_sdk/utils/_debug.py @@ -1,6 +1,6 @@ import json import logging -from typing import Callable, Coroutine, TypeVar +from typing import Awaitable, Callable, TypeVar _log = logging.getLogger(__name__) @@ -9,23 +9,23 @@ def _debug(): return _log.isEnabledFor(logging.DEBUG) -A = TypeVar("A") -B = TypeVar("B") +_A = TypeVar("_A") +_B = TypeVar("_B") def debug_logging( title: str, ) -> Callable[ - [Callable[[A], Coroutine[None, None, B]]], - Callable[[A], Coroutine[None, None, B]], + [Callable[[_A], Awaitable[_B]]], + Callable[[_A], Awaitable[_B]], ]: def decorator( - fn: Callable[[A], Coroutine[None, None, B]] - ) -> Callable[[A], Coroutine[None, None, B]]: + fn: Callable[[_A], Awaitable[_B]] + ) -> Callable[[_A], Awaitable[_B]]: if not _debug(): return fn - async def _fn(a: A) -> B: + async def _fn(a: _A) -> _B: _log.debug(f"{title} old: {json.dumps(a)}") b = await fn(a) _log.debug(f"{title} new: {json.dumps(b)}") diff --git a/aidial_interceptors_sdk/utils/_reflection.py b/aidial_interceptors_sdk/utils/_reflection.py index 841cb40..cdf63c3 100644 --- a/aidial_interceptors_sdk/utils/_reflection.py +++ b/aidial_interceptors_sdk/utils/_reflection.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Callable, Coroutine, TypeVar +from typing import Any, Awaitable, Callable, TypeVar from aidial_sdk.exceptions import InvalidRequestError @@ -7,7 +7,7 @@ async def call_with_extra_body( - func: Callable[..., Coroutine[Any, Any, T]], arg: dict + func: Callable[..., Awaitable[T]], arg: dict ) -> T: if _has_kwargs_argument(func): return await func(**arg) @@ -32,7 +32,7 @@ async def call_with_extra_body( return await func(**arg) -def _has_kwargs_argument(func: Callable[..., Coroutine[Any, Any, Any]]) -> bool: +def _has_kwargs_argument(func: Callable[..., Awaitable[Any]]) -> bool: """ Determines if the given function accepts a variable keyword argument (**kwargs). """ diff --git a/aidial_interceptors_sdk/utils/streaming.py b/aidial_interceptors_sdk/utils/streaming.py index da1e8c3..b1d2f16 100644 --- a/aidial_interceptors_sdk/utils/streaming.py +++ b/aidial_interceptors_sdk/utils/streaming.py @@ -7,6 +7,7 @@ AnnotatedChunk, AnnotatedException, AnnotatedValue, + Annotation, ) from aidial_interceptors_sdk.utils._exceptions import to_dial_exception @@ -32,7 +33,7 @@ async def materialize_streaming_errors( def annotate_stream( - annotation: Any | None, stream: AsyncIterator[dict | DialException] + annotation: Annotation, stream: AsyncIterator[dict | DialException] ) -> AsyncIterator[AnnotatedValue]: def _annotate(value: dict | DialException) -> AnnotatedValue: if isinstance(value, dict):