
Commit

fix: fixed outstanding issue with the exception tests; refactored call_upstreams
adubovik committed Nov 26, 2024
1 parent 6e579b9 commit cb64f48
Showing 12 changed files with 168 additions and 139 deletions.
112 changes: 58 additions & 54 deletions aidial_interceptors_sdk/chat_completion/adapter.py
@@ -6,10 +6,14 @@
from aidial_sdk.chat_completion import Request as DialRequest
from aidial_sdk.chat_completion import Response as DialResponse
from aidial_sdk.chat_completion.chunks import DefaultChunk
from aidial_sdk.exceptions import HTTPException as DialException
from openai import AsyncStream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
)
from aidial_interceptors_sdk.chat_completion.base import (
ChatCompletionInterceptor,
)
@@ -21,8 +25,8 @@
from aidial_interceptors_sdk.utils._reflection import call_with_extra_body
from aidial_interceptors_sdk.utils.streaming import (
block_response_to_streaming_chunk,
handle_streaming_errors,
map_stream,
materialize_streaming_errors,
singleton_stream,
)

@@ -63,68 +67,68 @@ async def chat_completion(
interceptor.traverse_request
)(request_body)

async def call_upstream(
request: dict, call_context: Any | None
) -> AsyncIterator[dict]:
upstream_response = cast(
AsyncStream[ChatCompletionChunk] | ChatCompletion,
await call_with_extra_body(
dial_client.client.chat.completions.create, request
),
)

if isinstance(upstream_response, ChatCompletion):
resp = upstream_response.to_dict()
if _debug():
_log.debug(
f"upstream response[{call_context}]: {json.dumps(resp)}"
)

# Block mode:
# Removing the default fields which are generated by
# DIAL SDK automatically.
# It also means that these fields aren't proxied from the upstream.
# They are recreated on each interceptor call.
# If the fields aren't removed, then they will be merged
# recursively with the one generated by SDK and we will end up with
# "object": "chat.completionchat.completionchat.completion"
for key in DefaultChunk.__annotations__.keys():
resp.pop(key, None)

chunk = block_response_to_streaming_chunk(resp)
stream = singleton_stream(chunk)
else:
# Streaming mode:
# No need to remove default fields, because
# they will be automatically overridden by the default fields
# generated by DIAL SDK, when each chunk is merged naively with
# a default chunk.

def on_upstream_chunk(chunk: ChatCompletionChunk) -> dict:
d = chunk.to_dict()
if _debug():
_log.debug(
f"upstream chunk[{call_context}]: {json.dumps(d)}"
)
return d

stream = map_stream(on_upstream_chunk, upstream_response)

return handle_streaming_errors(stream)

try:
await interceptor.on_stream_start()

async for chunk in await interceptor.call_upstreams(
def call_upstream(context: Any | None, request: dict):
return call_single_upstream(dial_client, context, request)

async for value in await interceptor.call_upstreams(
request_body, call_upstream
):
if "error" in chunk.chunk:
await interceptor.on_stream_error(chunk)
if isinstance(value, AnnotatedException):
await interceptor.on_stream_error(value)
else:
await interceptor.traverse_response_chunk(chunk)
await interceptor.traverse_response_chunk(value)

await interceptor.on_stream_end()
except EarlyStreamExit:
pass

return Impl()


async def call_single_upstream(
dial_client: DialClient, context: Any | None, request: dict
) -> AsyncIterator[dict | DialException]:
response = cast(
AsyncStream[ChatCompletionChunk] | ChatCompletion,
await call_with_extra_body(
dial_client.client.chat.completions.create, request
),
)

if isinstance(response, ChatCompletion):
resp = response.to_dict()
if _debug():
_log.debug(f"upstream response[{context}]: {json.dumps(resp)}")

# Block mode:
# Removing the default fields which are generated by
# DIAL SDK automatically.
# It also means that these fields aren't proxied from the upstream.
# They are recreated on each interceptor call.
# If the fields aren't removed, then they will be merged
# recursively with the one generated by SDK and we will end up with
# "object": "chat.completionchat.completionchat.completion"
for key in DefaultChunk.__annotations__.keys():
resp.pop(key, None)

chunk = block_response_to_streaming_chunk(resp)
stream = singleton_stream(chunk)
else:
# Streaming mode:
# No need to remove default fields, because
# they will be automatically overridden by the default fields
# generated by DIAL SDK, when each chunk is merged naively with
# a default chunk.

def on_upstream_chunk(chunk: ChatCompletionChunk) -> dict:
d = chunk.to_dict()
if _debug():
_log.debug(f"upstream chunk[{context}]: {json.dumps(d)}")
return d

stream = map_stream(on_upstream_chunk, response)

return materialize_streaming_errors(stream)
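
The block-mode comment above becomes concrete once you see what a naive recursive merge does to string fields. A minimal sketch of the effect (illustration only, not the DIAL SDK's actual merge code):

# Illustration only: a naive recursive merge that concatenates string values
# reproduces the duplication described in the block-mode comment.
def naive_merge(target: dict, source: dict) -> dict:
    result = dict(target)
    for key, value in source.items():
        if isinstance(result.get(key), dict) and isinstance(value, dict):
            result[key] = naive_merge(result[key], value)
        elif isinstance(result.get(key), str) and isinstance(value, str):
            result[key] = result[key] + value  # strings are appended, not replaced
        else:
            result[key] = value
    return result


sdk_default = {"object": "chat.completion"}
upstream_block_response = {"object": "chat.completion"}
print(naive_merge(sdk_default, upstream_block_response)["object"])
# -> "chat.completionchat.completion"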
8 changes: 0 additions & 8 deletions aidial_interceptors_sdk/chat_completion/annotated_chunk.py

This file was deleted.

23 changes: 23 additions & 0 deletions aidial_interceptors_sdk/chat_completion/annotated_value.py
@@ -0,0 +1,23 @@
from abc import ABC
from typing import Any

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.pydantic_v1 import BaseModel


class AnnotatedValueBase(BaseModel, ABC):
class Config:
arbitrary_types_allowed = True

annotation: Any | None = None


class AnnotatedChunk(AnnotatedValueBase):
chunk: dict


class AnnotatedException(AnnotatedValueBase):
error: DialException


AnnotatedValue = AnnotatedChunk | AnnotatedException
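
A minimal sketch of how a consumer dispatches on this union, mirroring the isinstance check in adapter.py above (the consume helper and its values argument are hypothetical):

from typing import AsyncIterator

from aidial_interceptors_sdk.chat_completion.annotated_value import (
    AnnotatedChunk,
    AnnotatedException,
    AnnotatedValue,
)


async def consume(values: AsyncIterator[AnnotatedValue]) -> None:
    # Each value carries an annotation identifying the upstream call it came from.
    async for value in values:
        if isinstance(value, AnnotatedChunk):
            print("chunk from upstream", value.annotation, value.chunk)
        elif isinstance(value, AnnotatedException):
            print("error from upstream", value.annotation, value.error)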
27 changes: 15 additions & 12 deletions aidial_interceptors_sdk/chat_completion/base.py
@@ -1,7 +1,10 @@
from typing import Any, AsyncIterator, Callable, Coroutine

from aidial_interceptors_sdk.chat_completion.annotated_chunk import (
AnnotatedChunk,
from aidial_sdk.exceptions import HTTPException as DialException

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
AnnotatedValue,
)
from aidial_interceptors_sdk.chat_completion.request_handler import (
RequestHandler,
@@ -10,6 +13,7 @@
ResponseHandler,
)
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.utils.streaming import annotate_stream


class ChatCompletionInterceptor(RequestHandler, ResponseHandler):
@@ -19,15 +23,14 @@ async def call_upstreams(
self,
request: dict,
call_upstream: Callable[
[dict, Any | None], Coroutine[Any, Any, AsyncIterator[dict]]
[Any | None, dict],
Coroutine[Any, Any, AsyncIterator[dict | DialException]],
],
) -> AsyncIterator[AnnotatedChunk]:
async def iterator():
call_context = None
async for chunk in await call_upstream(request, call_context):
yield AnnotatedChunk(chunk=chunk, annotation=call_context)

return iterator()
) -> AsyncIterator[AnnotatedValue]:
annotation = None
return annotate_stream(
annotation, await call_upstream(annotation, request)
)

async def on_stream_start(self) -> None:
# TODO: it's probably worth to put all the chunks
@@ -37,8 +40,8 @@ async def on_stream_start(self) -> None:
# its "assistant" role is reported.
pass

async def on_stream_error(self, error: AnnotatedChunk) -> None:
self.send_chunk(error.chunk)
async def on_stream_error(self, error: AnnotatedException) -> None:
raise error.error

async def on_stream_end(self) -> None:
# TODO: it's probably worth to withhold the last chunk generated by
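
The new default on_stream_error aborts the stream by raising the upstream error. A hypothetical subclass (not part of this commit) could instead forward the error downstream as an error chunk, roughly restoring the pre-refactoring behaviour; the sketch assumes, as the removed streaming code suggests, that DialException.json_error() produces the {"error": ...} payload:

from typing_extensions import override

from aidial_interceptors_sdk.chat_completion.annotated_value import AnnotatedException
from aidial_interceptors_sdk.chat_completion.base import ChatCompletionInterceptor


class ForwardErrorsInterceptor(ChatCompletionInterceptor):
    """Hypothetical example: forward upstream errors as error chunks
    instead of raising and aborting the stream."""

    @override
    async def on_stream_error(self, error: AnnotatedException) -> None:
        # json_error() is what the pre-refactoring streaming helper yielded as
        # an error chunk; send_chunk is what the old default handler used.
        self.send_chunk(error.error.json_error())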
@@ -4,7 +4,7 @@
from aidial_sdk.chat_completion.chunks import BaseChunk
from aidial_sdk.pydantic_v1 import PrivateAttr

from aidial_interceptors_sdk.chat_completion.annotated_chunk import (
from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedChunk,
)
from aidial_interceptors_sdk.chat_completion.element_path import (
23 changes: 12 additions & 11 deletions aidial_interceptors_sdk/examples/chat_completion/replicator.py
@@ -17,17 +17,19 @@
FinishReason,
UsageChunk,
)
from aidial_sdk.exceptions import HTTPException as DialException
from typing_extensions import override

from aidial_interceptors_sdk.chat_completion.annotated_chunk import (
AnnotatedChunk,
from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedValue,
)
from aidial_interceptors_sdk.chat_completion.base import (
ChatCompletionInterceptor,
)
from aidial_interceptors_sdk.chat_completion.element_path import ElementPath
from aidial_interceptors_sdk.chat_completion.index_mapper import IndexMapper
from aidial_interceptors_sdk.utils.not_given import NotGiven
from aidial_interceptors_sdk.utils.streaming import annotate_stream


class ReplicatorInterceptor(ChatCompletionInterceptor):
@@ -74,19 +76,18 @@ async def call_upstreams(
self,
request: dict,
call_upstream: Callable[
[dict, Any | None], Coroutine[Any, Any, AsyncIterator[dict]]
[Any | None, dict],
Coroutine[Any, Any, AsyncIterator[dict | DialException]],
],
) -> AsyncIterator[AnnotatedChunk]:
) -> AsyncIterator[AnnotatedValue]:
request["n"] = 1

async def get_iterator(idx: int) -> AsyncIterator[AnnotatedChunk]:
call_context = idx
async for chunk in await call_upstream(request, call_context):
yield AnnotatedChunk(chunk=chunk, annotation=call_context)

iterators = [get_iterator(idx) for idx in range(self.n)]
streams = [
annotate_stream(idx, await call_upstream(idx, request))
for idx in range(self.n)
]
# TODO: create tasks
return _join_iterators(iterators)
return _join_iterators(streams)

@override
async def on_response_stage(
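
_join_iterators itself is outside this diff. A rough sketch only, assuming a simple sequential drain consistent with the "TODO: create tasks" note above (a concurrent version could merge the streams instead, e.g. with aiostream, which pyproject.toml below lists as an example extra):

from typing import AsyncIterator, List, TypeVar

_T = TypeVar("_T")


async def _join_iterators(iterators: List[AsyncIterator[_T]]) -> AsyncIterator[_T]:
    # Drain the replica streams one after another; a task-based version
    # would interleave chunks as they arrive.
    for iterator in iterators:
        async for item in iterator:
            yield item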
26 changes: 22 additions & 4 deletions aidial_interceptors_sdk/utils/streaming.py
@@ -1,6 +1,13 @@
import logging
from typing import Any, AsyncIterator, Callable, Optional, TypeVar

from aidial_sdk.exceptions import HTTPException as DialException

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedChunk,
AnnotatedException,
AnnotatedValue,
)
from aidial_interceptors_sdk.utils._exceptions import to_dial_exception

_log = logging.getLogger(__name__)
@@ -9,9 +16,9 @@
_V = TypeVar("_V")


async def handle_streaming_errors(
async def materialize_streaming_errors(
stream: AsyncIterator[dict],
) -> AsyncIterator[dict]:
) -> AsyncIterator[dict | DialException]:

try:
async for chunk in stream:
@@ -21,8 +28,19 @@ async def handle_streaming_errors(
f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
)

dial_exception = to_dial_exception(e)
yield dial_exception.json_error()
yield to_dial_exception(e)


def annotate_stream(
annotation: Any | None, stream: AsyncIterator[dict | DialException]
) -> AsyncIterator[AnnotatedValue]:
def _annotate(value: dict | DialException) -> AnnotatedValue:
if isinstance(value, dict):
return AnnotatedChunk(chunk=value, annotation=annotation)
else:
return AnnotatedException(error=value, annotation=annotation)

return map_stream(_annotate, stream)


# TODO: add to SDK as an inverse of cleanup_indices
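
map_stream is called in several places but its body is outside this diff. A minimal implementation consistent with how it is invoked here (map_stream(fn, stream)) might look like:

from typing import AsyncIterator, Callable, TypeVar

_A = TypeVar("_A")
_B = TypeVar("_B")


async def map_stream(
    fn: Callable[[_A], _B], stream: AsyncIterator[_A]
) -> AsyncIterator[_B]:
    # Apply fn to every element of the async stream, preserving order.
    async for item in stream:
        yield fn(item)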
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ httpx = ">=0.25.0,<1.0"
openai = ">=1.32.0,<2.0"
# FIXME: revert to a release version
# aidial-sdk = { version = "^0.15.0", extras = ["telemetry"] }
aidial-sdk = { git = "https://github.com/epam/ai-dial-sdk.git", branch = "feat/support-headers-in-dial-exception", extras = ["telemetry"] }
aidial-sdk = { git = "https://github.com/epam/ai-dial-sdk.git", branch = "development", extras = ["telemetry"] }

# Extras for examples
aiostream = { version = "^0.6.2", optional = true }