From ee7dda2feb31ccadbb8b9ed4dd9c8ce40d695f7d Mon Sep 17 00:00:00 2001
From: Damian Kalinowski
Date: Fri, 9 Aug 2024 17:19:41 +0200
Subject: [PATCH] Cancel workload when OpenAI request client disconnects (both
 unary and streaming) (#2610)

* Patch tensorflow net_http to allow for installing client disconnection callbacks
* Use new genai, add tests, fix building without mediapipe, disconnect unary as well
* Tests

CVS-148134
Modifications to GenAI: openvinotoolkit/openvino.genai#732
---
 ci/cppclean.sh                                |   2 +-
 external/partial.patch                        | 139 +++++++++++++++---
 src/BUILD                                     |  14 ++
 src/client_connection.hpp                     |  27 ++++
 src/http_frontend/http_client_connection.hpp  |  46 ++++++
 src/http_rest_api_handler.cpp                 |   2 +
 src/llm/BUILD                                 |   1 +
 src/llm/http_llm_calculator.cc                |  25 +++-
 src/llm/http_payload.hpp                      |   4 +
 src/test/http_openai_handler_test.cpp         |   7 +
 src/test/llmnode_test.cpp                     |  99 +++++++++++++
 src/test/llmtemplate_test.cpp                 |  17 ++-
 ...penai_chat_completions_mock_calculator.cpp |   6 +
 src/test/test_utils.hpp                       |   2 +
 third_party/llm_engine/llm_engine.bzl         |   2 +-
 15 files changed, 362 insertions(+), 31 deletions(-)
 create mode 100644 src/client_connection.hpp
 create mode 100644 src/http_frontend/http_client_connection.hpp

diff --git a/ci/cppclean.sh b/ci/cppclean.sh
index 9193f52893..c11095978e 100755
--- a/ci/cppclean.sh
+++ b/ci/cppclean.sh
@@ -36,7 +36,7 @@ echo "Number of warnings in tests about not direct includes:" ${NO_WARNINGS_TEST
 echo "Number of warnings in tests about not used: " ${NO_WARNINGS_TEST_NOTUSED}
 
 errors=""
-if [ ${NO_WARNINGS_FORWARD} -gt 7 ]; then
+if [ ${NO_WARNINGS_FORWARD} -gt 8 ]; then
     errors+="Failed due to not using forward declarations where possible: ${NO_WARNINGS_FORWARD}"$'\n'
 fi
 if [ ${NO_WARNINGS_DIRECT} -gt 15 ]; then
diff --git a/external/partial.patch b/external/partial.patch
index 17b881c264..f2cebad733 100644
--- a/external/partial.patch
+++ b/external/partial.patch
@@ -1,32 +1,99 @@
 diff --git a/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc b/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc
-index c8d0501b..f24a21d5 100644
+index c8d0501b..3c0a3df6 100644
 --- a/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc
 +++ b/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc
-@@ -342,8 +342,17 @@ void EvHTTPRequest::PartialReplyWithStatus(HTTPStatusCode status) {
+@@ -17,6 +17,8 @@ limitations under the License.
+ 
+ #include "tensorflow_serving/util/net_http/server/internal/evhttp_request.h"
+ 
++#include <event2/bufferevent.h>
++
+ #include <zlib.h>
+ 
+ #include <cstddef>
+@@ -26,6 +28,8 @@ limitations under the License.
+ #include <string>
+ #include <vector>
+ 
++#include <mutex>
++
+ #include "absl/strings/match.h"
+ #include "absl/strings/string_view.h"
+ #include "absl/types/span.h"
+@@ -118,18 +122,47 @@ bool ParsedEvRequest::decode() {
+   return true;
+ }
+ 
++static void connection_close_callback(struct evhttp_connection *conn, void *arg) {
++  EvHTTPRequest* req = (EvHTTPRequest*)arg;
++  req->ExecuteDisconnectionCallback();
++}
++
+ EvHTTPRequest::EvHTTPRequest(std::unique_ptr<ParsedEvRequest> request,
+                              ServerSupport* server)
+     : server_(server),
+       parsed_request_(std::move(request)),
+-      output_buf(nullptr) {}
++      output_buf(nullptr) {
++
++  struct evhttp_connection *conn = evhttp_request_get_connection(parsed_request_->request);
++  evhttp_connection_set_closecb(conn, connection_close_callback, (void*)this);
++
++  struct bufferevent *bev = evhttp_connection_get_bufferevent(conn);
++  bufferevent_enable(bev, EV_READ);
++}
+ 
+ EvHTTPRequest::~EvHTTPRequest() {
++  struct evhttp_connection *conn = evhttp_request_get_connection(parsed_request_->request);
++  if (conn != NULL) {
++    evhttp_connection_set_closecb(conn, NULL, NULL);
++  }
++
+   if (output_buf != nullptr) {
+     evbuffer_free(output_buf);
+   }
+ }
+ 
++void EvHTTPRequest::RegisterDisconnectionCallback(std::function<void()> callback) {
++  std::unique_lock<std::mutex> lk(this->disconnection_mx_);
++  this->disconnected_callback_ = std::move(callback);
++}
++
++void EvHTTPRequest::ExecuteDisconnectionCallback() {
++  std::unique_lock<std::mutex> lk(this->disconnection_mx_);
++  this->is_disconnected_ = true;
++  if (this->disconnected_callback_)
++    this->disconnected_callback_();
++}
++
+ absl::string_view EvHTTPRequest::uri_path() const {
+   return parsed_request_->path_and_query;
+ }
+@@ -342,8 +375,14 @@ void EvHTTPRequest::PartialReplyWithStatus(HTTPStatusCode status) {
+   NET_LOG(FATAL, "PartialReplyWithStatus not implemented.");
+ }
 
 -void EvHTTPRequest::PartialReply() {
 -  NET_LOG(FATAL, "PartialReplyWithStatus not implemented.");
 +void EvHTTPRequest::PartialReply(std::string data) {
-+  // TODO: Possibly avoid copy of data
 +  bool result =
-+      server_->EventLoopSchedule([this, data]() { EvPartialSendReply(data); });
-+
++      server_->EventLoopSchedule([this, data = std::move(data)]() mutable { EvPartialSendReply(std::move(data)); });
++
 +  if (!result) {
 +    NET_LOG(ERROR, "Failed to EventLoopSchedule PartialReply()");
 +    Abort();
-+    // TODO(wenboz): should have a forced abort that doesn't write back anything
-+    // to the event-loop
 +  }
  }
 
 ServerRequestInterface::CallbackStatus
-@@ -371,6 +380,25 @@ void EvHTTPRequest::EvSendReply(HTTPStatusCode status) {
+@@ -371,6 +410,33 @@ void EvHTTPRequest::EvSendReply(HTTPStatusCode status) {
   delete this;
 }
 
+void EvHTTPRequest::EvPartialSendReply(std::string data) {
++  std::unique_lock<std::mutex> lk(this->disconnection_mx_);
++  if (this->is_disconnected_)
++    return;
+  if (!this->is_reply_started_) {
+    evhttp_send_reply_start(parsed_request_->request, HTTP_OK, "reply start");
+    this->is_reply_started_ = true;
@@ -36,11 +103,16 @@ index c8d0501b..f24a21d5 100644
+}
+
+void EvHTTPRequest::EvPartialReplyEnd() {
-+  if (!this->is_reply_started_) {
-+    // Start before we end can end the reply
-+    evhttp_send_reply_start(parsed_request_->request, HTTP_OK, "no messages");
++  std::unique_lock<std::mutex> lk(this->disconnection_mx_);
++  if (!this->is_disconnected_) {
++    if (!this->is_reply_started_) {
++      // Start before we end can end the reply
++      evhttp_send_reply_start(parsed_request_->request, HTTP_OK, "no messages");
++    }
++
++    evhttp_send_reply_end(parsed_request_->request);
+  }
-+  evhttp_send_reply_end(parsed_request_->request);
++
+  server_->DecOps();
+  delete this;
+}
@@ -48,7 +120,7 @@ index c8d0501b..f24a21d5 100644
 void EvHTTPRequest::Reply() { ReplyWithStatus(HTTPStatusCode::OK); }
 
 // Treats this as 500 for now and let libevent decide what to do
-@@ -381,6 +409,18 @@ void EvHTTPRequest::Abort() {
+@@ -381,6 +447,16 @@ void EvHTTPRequest::Abort() {
   delete this;
 }
 
+void EvHTTPRequest::PartialReplyEnd() {
+  bool result =
+      server_->EventLoopSchedule([this]() { EvPartialReplyEnd(); });
+
+  if (!result) {
+    NET_LOG(ERROR, "Failed to EventLoopSchedule PartialReplyEnd()");
+    Abort();
-+    // TODO(wenboz): should have a forced abort that doesn't write back anything
-+    // to the event-loop
+  }
+}
+
@@ -68,10 +138,18 @@ index c8d0501b..f24a21d5 100644
 } // namespace serving
 } // namespace tensorflow
 diff --git a/tensorflow_serving/util/net_http/server/internal/evhttp_request.h b/tensorflow_serving/util/net_http/server/internal/evhttp_request.h
-index 2f8e601d..ff51c570 100644
+index 2f8e601d..562b38b8 100644
 --- a/tensorflow_serving/util/net_http/server/internal/evhttp_request.h
 +++ b/tensorflow_serving/util/net_http/server/internal/evhttp_request.h
+@@ -21,6 +21,7 @@ limitations under the License.
+ #include <cstdint>
+ #include <memory>
+ #include <string>
++#include <mutex>
+ 
+ #include "tensorflow_serving/util/net_http/server/internal/server_support.h"
+ #include "tensorflow_serving/util/net_http/server/public/httpserver_interface.h"
+@@ -94,7 +95,7 @@ class EvHTTPRequest final : public ServerRequestInterface {
                            absl::string_view value) override;
 
   void PartialReplyWithStatus(HTTPStatusCode status) override;
-  void PartialReply() override;
+  void PartialReply(std::string data) override;
 
   CallbackStatus PartialReplyWithFlushCallback(
       std::function<void()> callback) override;
-@@ -104,6 +104,8 @@ class EvHTTPRequest final : public ServerRequestInterface {
+@@ -104,16 +105,26 @@ class EvHTTPRequest final : public ServerRequestInterface {
 
   void Abort() override;
 
+  void PartialReplyEnd() override;
+
   // Initializes the resource and returns false if any error.
  bool Initialize();
 
++  bool IsDisconnected() const override { return this->is_disconnected_; }
++  void RegisterDisconnectionCallback(std::function<void()> callback) override;
++  void ExecuteDisconnectionCallback();
++
   // Keeps a reference to the registered RequestHandlerOptions
   void SetHandlerOptions(const RequestHandlerOptions& handler_options) {
     this->handler_options_ = &handler_options;
   }
++  int64_t Id() { return (int64_t)((int*)this); }
++
  private:
   void EvSendReply(HTTPStatusCode status);
+  void EvPartialSendReply(std::string data);
+  void EvPartialReplyEnd();
 
   // Returns true if the data needs be uncompressed
   bool NeedUncompressGzipContent();
-@@ -133,6 +137,8 @@ class EvHTTPRequest final : public ServerRequestInterface {
+@@ -133,6 +144,12 @@ class EvHTTPRequest final : public ServerRequestInterface {
   std::unique_ptr<ParsedEvRequest> parsed_request_;
 
   evbuffer* output_buf;  // owned by this
+
+  bool is_reply_started_{false};
++  bool is_disconnected_{false};
++
++  std::function<void()> disconnected_callback_;
++  std::mutex disconnection_mx_;
 };
 
 } // namespace net_http
 diff --git a/tensorflow_serving/util/net_http/server/public/server_request_interface.h b/tensorflow_serving/util/net_http/server/public/server_request_interface.h
-index e5f4b05f..7077a6c1 100644
+index e5f4b05f..3205f0ab 100644
 --- a/tensorflow_serving/util/net_http/server/public/server_request_interface.h
 +++ b/tensorflow_serving/util/net_http/server/public/server_request_interface.h
 @@ -144,7 +144,7 @@ class ServerRequestInterface {
   virtual void PartialReplyWithStatus(HTTPStatusCode status) = 0;
 -  virtual void PartialReply() = 0;
 +  virtual void PartialReply(std::string data) = 0;
 
   // Similar to PartialReply() but with an on_flush callback which will be
   // invoked when the response data has been completely flushed by the
-@@ -182,6 +182,8 @@ class ServerRequestInterface {
+@@ -182,6 +182,12 @@ class ServerRequestInterface {
   // by the server runtime.
   virtual void Abort() = 0;
 
 +  virtual void PartialReplyEnd() = 0;
++
++  // Helpers for handling disconnection states
++  virtual bool IsDisconnected() const = 0;
++  virtual void RegisterDisconnectionCallback(std::function<void()> callback) = 0;
 +
  protected:
   ServerRequestInterface() = default;
diff --git a/src/BUILD b/src/BUILD
index 6242d7cd9a..0c97fd38e8 100644
--- a/src/BUILD
+++ b/src/BUILD
@@ -392,6 +392,7 @@ cc_library(
         "@mediapipe//mediapipe/calculators/geti/utils:emptylabel_calculators",
         "@mediapipe//mediapipe/calculators/geti/serialization:calculators",
         "//src/llm:httppayload",
+        "//src:libhttpclientconnection",
     ],
     "//:disable_mediapipe" : [],
     }),
@@ -445,6 +446,19 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "libhttpclientconnection",
+    hdrs = ["client_connection.hpp", "http_frontend/http_client_connection.hpp"],
+    deps = [
+        "@tensorflow_serving//tensorflow_serving/util/net_http/server/public:http_server",
+    ],
+    visibility = ["//visibility:public",],
+    local_defines = COMMON_LOCAL_DEFINES,
+    copts = COPTS_ADJUSTED,
+    linkopts = LINKOPTS_ADJUSTED,
+    alwayslink = 1,
+)
+
 cc_library(
     name = "libovmsprecision",
     hdrs = ["precision.hpp"],
diff --git a/src/client_connection.hpp b/src/client_connection.hpp
new file mode 100644
index 0000000000..2ac08099e6
--- /dev/null
+++ b/src/client_connection.hpp
@@ -0,0 +1,27 @@
+//*****************************************************************************
+// Copyright 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <functional>
+
+namespace ovms {
+
+class ClientConnection {
+public:
+    virtual bool isDisconnected() const = 0;
+    virtual void registerDisconnectionCallback(std::function<void()> fn) = 0;
+};
+
+}  // namespace ovms
diff --git a/src/http_frontend/http_client_connection.hpp b/src/http_frontend/http_client_connection.hpp
new file mode 100644
index 0000000000..b361f0f09d
--- /dev/null
+++ b/src/http_frontend/http_client_connection.hpp
@@ -0,0 +1,46 @@
+//*****************************************************************************
+// Copyright 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#pragma once
+
+#include <functional>
+
+#include "../client_connection.hpp"
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wall"
+#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
+#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h"
+#pragma GCC diagnostic pop
+
+namespace ovms {
+
+class HttpClientConnection : public ClientConnection {
+    tensorflow::serving::net_http::ServerRequestInterface* serverReaderWriter;
+
+public:
+    HttpClientConnection(tensorflow::serving::net_http::ServerRequestInterface* serverReaderWriter) :
+        serverReaderWriter(serverReaderWriter) {}
+
+    bool isDisconnected() const override {
+        return this->serverReaderWriter->IsDisconnected();
+    }
+
+    void registerDisconnectionCallback(std::function<void()> fn) override {
+        serverReaderWriter->RegisterDisconnectionCallback(std::move(fn));
+    }
+};
+
+}  // namespace ovms
diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp
index 2421522083..1a8220aeaf 100644
--- a/src/http_rest_api_handler.cpp
+++ b/src/http_rest_api_handler.cpp
@@ -63,6 +63,7 @@
 #include "timer.hpp"
 
 #if (MEDIAPIPE_DISABLE == 0)
+#include "http_frontend/http_client_connection.hpp"
 #include "http_frontend/http_graph_executor_impl.hpp"
 #include "mediapipe_internal/mediapipegraphexecutor.hpp"
 #endif
@@ -486,6 +487,7 @@ Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpReque
         request.body = request_body;
         request.parsedJson = &doc;
         request.uri = std::string(uri);
+        request.client = std::make_shared<HttpClientConnection>(serverReaderWriter);
     }
     if (streamFieldVal == false) {
         ServableMetricReporter* smr = nullptr;  // Unused
diff --git a/src/llm/BUILD b/src/llm/BUILD
index 563799e33a..f6f77e30d7 100644
--- a/src/llm/BUILD
+++ b/src/llm/BUILD
@@ -99,6 +99,7 @@ cc_library(
     hdrs = ["http_payload.hpp"],
     deps = [
         "@com_github_tencent_rapidjson//:rapidjson",
+        "//src:libhttpclientconnection"
     ],
visibility = ["//visibility:public"], copts = COPTS_ADJUSTED, diff --git a/src/llm/http_llm_calculator.cc b/src/llm/http_llm_calculator.cc index d8a1fff496..d31d949bba 100644 --- a/src/llm/http_llm_calculator.cc +++ b/src/llm/http_llm_calculator.cc @@ -387,6 +387,7 @@ class HttpLLMCalculator : public CalculatorBase { std::shared_ptr nodeResources; ov::genai::GenerationHandle generationHandle; std::shared_ptr request; + std::shared_ptr client; // TODO: To be moved to CB library std::shared_ptr streamer; @@ -447,6 +448,7 @@ class HttpLLMCalculator : public CalculatorBase { RET_CHECK(this->request == nullptr); RET_CHECK(this->generationHandle == nullptr); RET_CHECK(this->streamer == nullptr); + RET_CHECK(this->client == nullptr); // Register resource creation time this->created = std::chrono::system_clock::now(); @@ -463,6 +465,7 @@ class HttpLLMCalculator : public CalculatorBase { return absl::InvalidArgumentError("Wrong endpoint. Allowed endpoints: /v3/chat/completions, /v3/completions"); } this->request = std::make_shared(*payload.parsedJson, endpoint); + this->client = payload.client; // TODO: Support chat scenario once atobisze adds that to CB library auto status = this->request->parse(nodeResources->maxTokensLimit, nodeResources->bestOfLimit); @@ -493,10 +496,20 @@ class HttpLLMCalculator : public CalculatorBase { { OVMS_PROFILE_SCOPE("pipeline->add_request()"); + + // Check if client disconnected while waiting in HTTP requests queue + if (this->client->isDisconnected()) { + return absl::CancelledError(); + } + this->generationHandle = nodeResources->cbPipe->add_request( currentRequestId++, /*to be removed from API?*/ finalPrompt, this->request->createGenerationConfig()); + + this->client->registerDisconnectionCallback([genHandle = this->generationHandle]() { + genHandle->drop(); + }); } nodeResources->notifyExecutorThread(); this->streamer = std::make_shared( @@ -506,15 +519,21 @@ class HttpLLMCalculator : public CalculatorBase { RET_CHECK(this->generationHandle != nullptr); RET_CHECK(this->request != nullptr); RET_CHECK(this->streamer != nullptr); + RET_CHECK(this->client != nullptr); // Unary scenario if (!this->request->isStream()) { OVMS_PROFILE_SCOPE("Unary generation cycle"); + std::vector generationOutput = this->generationHandle->read_all(); + if (this->generationHandle->get_status() == ov::genai::GenerationStatus::DROPPED_BY_HANDLE) { + return absl::CancelledError(); + } RET_CHECK(generationOutput.size() >= 1); - std::sort(generationOutput.begin(), generationOutput.end(), [=](ov::genai::GenerationOutput& r1, ov::genai::GenerationOutput& r2) { + std::sort(generationOutput.begin(), generationOutput.end(), [](ov::genai::GenerationOutput& r1, ov::genai::GenerationOutput& r2) { return r1.score > r2.score; }); + // legacy if (generationOutput.size() == 1) { std::vector tokens = generationOutput[0].generated_token_ids; @@ -545,6 +564,10 @@ class HttpLLMCalculator : public CalculatorBase { // Streaming scenario // Each iteration is single execution of Process() method + if (this->generationHandle->get_status() == ov::genai::GenerationStatus::DROPPED_BY_HANDLE) { + return absl::CancelledError(); + } + if (this->generationHandle->get_status() == ov::genai::GenerationStatus::RUNNING || this->generationHandle->can_read()) { // Subsequent iteration OVMS_PROFILE_SCOPE("Generation of subsequent streaming response"); diff --git a/src/llm/http_payload.hpp b/src/llm/http_payload.hpp index 92ade71ab6..835083a774 100644 --- a/src/llm/http_payload.hpp +++ b/src/llm/http_payload.hpp @@ -15,12 
+15,15 @@ //***************************************************************************** #pragma once +#include #include #include #include #include +#include "../client_connection.hpp" + namespace ovms { struct HttpPayload { @@ -28,6 +31,7 @@ struct HttpPayload { std::vector> headers; std::string body; // always rapidjson::Document* parsedJson; // pre-parsed body = null + std::shared_ptr client; }; } // namespace ovms diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 77f43b7b9e..1650c21ff8 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -134,6 +134,7 @@ TEST_F(HttpOpenAIHandlerTest, Stream) { EXPECT_CALL(writer, PartialReplyEnd()).Times(1); EXPECT_CALL(writer, PartialReply(::testing::_)).Times(9); EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + EXPECT_CALL(writer, IsDisconnected()).Times(9); ASSERT_EQ( handler->dispatchToProcessor("/v3/test", requestBody, &response, comp, responseComponents, &writer), @@ -148,6 +149,7 @@ TEST_F(HttpOpenAIHandlerTest, BodyNotAJson) { EXPECT_CALL(writer, PartialReplyEnd()).Times(0); EXPECT_CALL(writer, PartialReply(::testing::_)).Times(0); EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + EXPECT_CALL(writer, IsDisconnected()).Times(0); auto status = handler->dispatchToProcessor("/v3/test", requestBody, &response, comp, responseComponents, &writer); ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); @@ -160,6 +162,7 @@ TEST_F(HttpOpenAIHandlerTest, JsonBodyValidButNotAnObject) { EXPECT_CALL(writer, PartialReplyEnd()).Times(0); EXPECT_CALL(writer, PartialReply(::testing::_)).Times(0); EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + EXPECT_CALL(writer, IsDisconnected()).Times(0); auto status = handler->dispatchToProcessor("/v3/test", requestBody, &response, comp, responseComponents, &writer); ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); @@ -177,6 +180,7 @@ TEST_F(HttpOpenAIHandlerTest, ModelFieldMissing) { EXPECT_CALL(writer, PartialReplyEnd()).Times(0); EXPECT_CALL(writer, PartialReply(::testing::_)).Times(0); EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + EXPECT_CALL(writer, IsDisconnected()).Times(0); auto status = handler->dispatchToProcessor("/v3/test", requestBody, &response, comp, responseComponents, &writer); ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); @@ -195,6 +199,7 @@ TEST_F(HttpOpenAIHandlerTest, ModelFieldNotAString) { EXPECT_CALL(writer, PartialReplyEnd()).Times(0); EXPECT_CALL(writer, PartialReply(::testing::_)).Times(0); EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + EXPECT_CALL(writer, IsDisconnected()).Times(0); auto status = handler->dispatchToProcessor("/v3/test", requestBody, &response, comp, responseComponents, &writer); ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); @@ -213,6 +218,7 @@ TEST_F(HttpOpenAIHandlerTest, StreamFieldNotABoolean) { EXPECT_CALL(writer, PartialReplyEnd()).Times(0); EXPECT_CALL(writer, PartialReply(::testing::_)).Times(0); EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + EXPECT_CALL(writer, IsDisconnected()).Times(0); auto status = handler->dispatchToProcessor("/v3/test", requestBody, &response, comp, responseComponents, &writer); ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); @@ -231,6 +237,7 @@ TEST_F(HttpOpenAIHandlerTest, GraphWithANameDoesNotExist) { EXPECT_CALL(writer, PartialReplyEnd()).Times(0); EXPECT_CALL(writer, PartialReply(::testing::_)).Times(0); EXPECT_CALL(writer, 
WriteResponseString(::testing::_)).Times(0); + EXPECT_CALL(writer, IsDisconnected()).Times(0); auto status = handler->dispatchToProcessor("/v3/test", requestBody, &response, comp, responseComponents, &writer); ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_DEFINITION_NAME_MISSING); diff --git a/src/test/llmnode_test.cpp b/src/test/llmnode_test.cpp index 4b5f35c41a..a336e56b3f 100644 --- a/src/test/llmnode_test.cpp +++ b/src/test/llmnode_test.cpp @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** +#include #include #include #include @@ -360,6 +361,7 @@ TEST_F(LLMFlowHttpTest, inferChatCompletionsStream) { // TODO: New output EXPECT_CALL(writer, PartialReplyEnd()).Times(1); // TODO: New output EXPECT_CALL(writer, PartialReply(::testing::_)).Times(3); // TODO: New output EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + // TODO: New output EXPECT_CALL(writer, IsDisconnected()).Times(6); // more than partial reply because of text streamer not always returning chunk of ready data ASSERT_EQ( handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, &writer), ovms::StatusCode::PARTIAL_END); @@ -380,12 +382,109 @@ TEST_F(LLMFlowHttpTest, inferCompletionsStream) { // TODO: New output EXPECT_CALL(writer, PartialReplyEnd()).Times(1); // TODO: New output EXPECT_CALL(writer, PartialReply(::testing::_)).Times(3); // TODO: New output EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + // TODO: New output EXPECT_CALL(writer, IsDisconnected()).Times(6); // more than partial reply because of text streamer not always returning chunk of ready data ASSERT_EQ( handler->dispatchToProcessor(endpointCompletions, requestBody, &response, comp, responseComponents, &writer), ovms::StatusCode::PARTIAL_END); ASSERT_EQ(response, ""); } +// /v3/chat/completions endpoint +// unary, gready search +// Correct payload, however disconnection immediately +TEST_F(LLMFlowHttpTest, inferChatCompletionsUnaryClientDisconnectedImmediately) { + std::string requestBody = R"( + { + "model": "llmDummyKFS", + "stream": false, + "seed" : 1, + "max_tokens": 5, + "messages": [ + { + "role": "user", + "content": "What is OpenVINO?" + } + ] + } + )"; + + EXPECT_CALL(writer, RegisterDisconnectionCallback(::testing::_)).WillOnce([](std::function fn) { + fn(); // disconnect immediately, even before read_all is called + }); + ASSERT_EQ( + handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, &writer), + ovms::StatusCode::MEDIAPIPE_EXECUTION_ERROR); +} + +// /v3/chat/completions endpoint +// streaming +// Correct payload, however disconnection immediately +TEST_F(LLMFlowHttpTest, inferChatCompletionsStreamClientDisconnectedImmediately) { + std::string requestBody = R"( + { + "model": "llmDummyKFS", + "stream": true, + "seed" : 1, + "max_tokens": 5, + "messages": [ + { + "role": "user", + "content": "What is OpenVINO?" + } + ] + } + )"; + + EXPECT_CALL(writer, IsDisconnected()) + .WillOnce(::testing::Return(true)); + + std::atomic i = 0; + EXPECT_CALL(writer, PartialReplyEnd()).Times(1); + EXPECT_CALL(writer, PartialReply(::testing::_)).WillOnce([this, &i](std::string partialResponse) { + i++; + ASSERT_EQ(partialResponse, "{\"error\": \"Mediapipe execution failed. 
MP status - CANCELLED: CalculatorGraph::Run() failed in Run: \nCalculator::Process() for node \"llmNode1\" failed: \"}"); + }); // no results + EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + + ASSERT_EQ( + handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, &writer), + ovms::StatusCode::PARTIAL_END); + ASSERT_EQ(i, 1); + ASSERT_EQ(response, ""); +} + +// /v3/completions endpoint +// streaming +// Correct payload, however disconnection immediately +TEST_F(LLMFlowHttpTest, inferCompletionsStreamClientDisconnectedImmediately) { + std::string requestBody = R"( + { + "model": "llmDummyKFS", + "stream": true, + "seed" : 1, + "max_tokens": 5, + "prompt": "What is OpenVINO?" + } + )"; + + EXPECT_CALL(writer, IsDisconnected()) + .WillOnce(::testing::Return(true)); + + std::atomic i = 0; + EXPECT_CALL(writer, PartialReplyEnd()).Times(1); + EXPECT_CALL(writer, PartialReply(::testing::_)).WillOnce([this, &i](std::string partialResponse) { + i++; + ASSERT_EQ(partialResponse, "{\"error\": \"Mediapipe execution failed. MP status - CANCELLED: CalculatorGraph::Run() failed in Run: \nCalculator::Process() for node \"llmNode1\" failed: \"}"); + }); // no results + EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + + ASSERT_EQ( + handler->dispatchToProcessor(endpointCompletions, requestBody, &response, comp, responseComponents, &writer), + ovms::StatusCode::PARTIAL_END); + ASSERT_EQ(i, 1); + ASSERT_EQ(response, ""); +} + const std::string validRequestBodyWithParameter(const std::string& parameter, const std::string& value) { std::string requestBody = R"( { diff --git a/src/test/llmtemplate_test.cpp b/src/test/llmtemplate_test.cpp index c16d1cfa4c..9842269e6e 100644 --- a/src/test/llmtemplate_test.cpp +++ b/src/test/llmtemplate_test.cpp @@ -626,6 +626,9 @@ std::string fullResponse; // } class LLMJinjaChatTemplateHttpTest : public LLMChatTemplateHttpTest { +protected: + std::unique_ptr cleanupGuard; + void SetUp() { fullResponse = ""; TestWithTempDir::SetUp(); @@ -633,11 +636,12 @@ class LLMJinjaChatTemplateHttpTest : public LLMChatTemplateHttpTest { std::string jinjaTemplate = R"({{"What is OpenVINO" + messages[0]['content']}})"; ASSERT_EQ(CreateConfig(jinjaTemplate, jinjaConfigFilePath), true); LLMChatTemplateHttpTest::SetUp(); + + cleanupGuard = std::make_unique(directoryPath); } }; TEST_F(LLMJinjaChatTemplateHttpTest, inferChatCompletionsUnary) { - std::unique_ptr cleanupGuard = std::make_unique(directoryPath); std::string requestBody = R"( { "model": "llmDummyKFS", @@ -665,7 +669,6 @@ TEST_F(LLMJinjaChatTemplateHttpTest, inferChatCompletionsUnary) { } TEST_F(LLMJinjaChatTemplateHttpTest, inferCompletionsUnary) { - std::unique_ptr cleanupGuard = std::make_unique(directoryPath); std::string requestBody = R"( { "model": "llmDummyKFS", @@ -688,7 +691,6 @@ TEST_F(LLMJinjaChatTemplateHttpTest, inferCompletionsUnary) { } TEST_F(LLMJinjaChatTemplateHttpTest, inferChatCompletionsStream) { - std::unique_ptr cleanupGuard = std::make_unique(directoryPath); std::string requestBody = R"( { "model": "llmDummyKFS", @@ -714,8 +716,10 @@ TEST_F(LLMJinjaChatTemplateHttpTest, inferChatCompletionsStream) { auto modelOutput = choices.GetObject()["text"].GetString(); ConcatenateResponse(modelOutput); } - }); */ + }); + */ // TODO: New output EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + // TODO: New output EXPECT_CALL(writer, IsDisconnected()).Times(7); ASSERT_EQ( handler->dispatchToProcessor(endpointCompletions, 
requestBody, &response, comp, responseComponents, &writer), @@ -727,7 +731,6 @@ TEST_F(LLMJinjaChatTemplateHttpTest, inferChatCompletionsStream) { } TEST_F(LLMJinjaChatTemplateHttpTest, inferCompletionsStream) { - std::unique_ptr cleanupGuard = std::make_unique(directoryPath); std::string requestBody = R"( { "model": "llmDummyKFS", @@ -753,8 +756,10 @@ TEST_F(LLMJinjaChatTemplateHttpTest, inferCompletionsStream) { auto modelOutput = choices.GetObject()["text"].GetString(); ConcatenateResponse(modelOutput); } - }); */ + }); + */ // TODO: New output EXPECT_CALL(writer, WriteResponseString(::testing::_)).Times(0); + // TODO: New output EXPECT_CALL(writer, IsDisconnected()).Times(7); ASSERT_EQ( handler->dispatchToProcessor(endpointCompletions, requestBody, &response, comp, responseComponents, &writer), diff --git a/src/test/mediapipe/calculators/openai_chat_completions_mock_calculator.cpp b/src/test/mediapipe/calculators/openai_chat_completions_mock_calculator.cpp index 0d4d9a1028..70568d2d91 100644 --- a/src/test/mediapipe/calculators/openai_chat_completions_mock_calculator.cpp +++ b/src/test/mediapipe/calculators/openai_chat_completions_mock_calculator.cpp @@ -42,6 +42,7 @@ class OpenAIChatCompletionsMockCalculator : public CalculatorBase { mediapipe::Timestamp timestamp{0}; std::string body; + std::shared_ptr client; public: static absl::Status GetContract(CalculatorContract* cc) { @@ -74,6 +75,7 @@ class OpenAIChatCompletionsMockCalculator : public CalculatorBase { this->body += header.second; } this->body += data.body; + this->client = data.client; if (data.parsedJson != NULL) { rapidjson::StringBuffer buffer; buffer.Clear(); @@ -83,6 +85,10 @@ class OpenAIChatCompletionsMockCalculator : public CalculatorBase { } } + if (client->isDisconnected()) { + return absl::OkStatus(); + } + this->body += std::to_string(timestamp.Value()); // Fake workload diff --git a/src/test/test_utils.hpp b/src/test/test_utils.hpp index 6c65654fd9..18b6781a15 100644 --- a/src/test/test_utils.hpp +++ b/src/test/test_utils.hpp @@ -775,6 +775,8 @@ class MockedServerRequestInterface final : public tensorflow::serving::net_http: MOCK_METHOD(void, Reply, (), (override)); MOCK_METHOD(void, Abort, (), (override)); MOCK_METHOD(void, PartialReplyEnd, (), (override)); + MOCK_METHOD(bool, IsDisconnected, (), (const override)); + MOCK_METHOD(void, RegisterDisconnectionCallback, (std::function), (override)); }; /** diff --git a/third_party/llm_engine/llm_engine.bzl b/third_party/llm_engine/llm_engine.bzl index 74ce247ee7..3096b0905e 100644 --- a/third_party/llm_engine/llm_engine.bzl +++ b/third_party/llm_engine/llm_engine.bzl @@ -20,7 +20,7 @@ def llm_engine(): new_git_repository( name = "llm_engine", remote = "https://github.com/openvinotoolkit/openvino.genai", - commit = "50182b479c958718f1b9e907df07c4fb7a310462", # master + commit = "e5053526b2482df132a53d1a0e4304fc948ac741", # master build_file = "@_llm_engine//:BUILD", init_submodules = True, recursive_init_submodules = True,