Skip to content

Commit

Permalink
Cancel workload when OpenAI request client disconnects (both unary an…
Browse files Browse the repository at this point in the history
…d streaming) (#2610)

* Patch tensorflow net_http to allow for installing client disconnection callbacks
* Use new genai, add tests, fix building without mediapipe, disconnect unary as well
* Tests

CVS-148134

Modifications to GenAI:
openvinotoolkit/openvino.genai#732
  • Loading branch information
dkalinowski authored Aug 9, 2024
1 parent 03579fe commit ee7dda2
Show file tree
Hide file tree
Showing 15 changed files with 362 additions and 31 deletions.
2 changes: 1 addition & 1 deletion ci/cppclean.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ echo "Number of warnings in tests about not direct includes:" ${NO_WARNINGS_TEST
echo "Number of warnings in tests about not used: " ${NO_WARNINGS_TEST_NOTUSED}

errors=""
if [ ${NO_WARNINGS_FORWARD} -gt 7 ]; then
if [ ${NO_WARNINGS_FORWARD} -gt 8 ]; then
errors+="Failed due to not using forward declarations where possible: ${NO_WARNINGS_FORWARD}"$'\n'
fi
if [ ${NO_WARNINGS_DIRECT} -gt 15 ]; then
Expand Down
139 changes: 117 additions & 22 deletions external/partial.patch
Original file line number Diff line number Diff line change
@@ -1,32 +1,99 @@
diff --git a/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc b/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc
index c8d0501b..f24a21d5 100644
index c8d0501b..3c0a3df6 100644
--- a/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc
+++ b/tensorflow_serving/util/net_http/server/internal/evhttp_request.cc
@@ -342,8 +342,17 @@ void EvHTTPRequest::PartialReplyWithStatus(HTTPStatusCode status) {
@@ -17,6 +17,8 @@ limitations under the License.

#include "tensorflow_serving/util/net_http/server/internal/evhttp_request.h"

+#include <mutex>
+
#include <zlib.h>

#include <cassert>
@@ -26,6 +28,8 @@ limitations under the License.
#include <string>
#include <vector>

+#include <event.h>
+
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
@@ -118,18 +122,47 @@ bool ParsedEvRequest::decode() {
return true;
}

+static void connection_close_callback(struct evhttp_connection *conn, void *arg) {
+ EvHTTPRequest* req = (EvHTTPRequest*)arg;
+ req->ExecuteDisconnectionCallback();
+}
+
EvHTTPRequest::EvHTTPRequest(std::unique_ptr<ParsedEvRequest> request,
ServerSupport* server)
: server_(server),
parsed_request_(std::move(request)),
- output_buf(nullptr) {}
+ output_buf(nullptr) {
+
+ struct evhttp_connection *conn = evhttp_request_get_connection(parsed_request_->request);
+ evhttp_connection_set_closecb(conn, connection_close_callback, (void*)this);
+
+ struct bufferevent *bev = evhttp_connection_get_bufferevent(conn);
+ bufferevent_enable(bev, EV_READ);
+}

EvHTTPRequest::~EvHTTPRequest() {
+ struct evhttp_connection *conn = evhttp_request_get_connection(parsed_request_->request);
+ if (conn != NULL) {
+ evhttp_connection_set_closecb(conn, NULL, NULL);
+ }
+
if (output_buf != nullptr) {
evbuffer_free(output_buf);
}
}

+void EvHTTPRequest::RegisterDisconnectionCallback(std::function<void()> callback) {
+ std::unique_lock<std::mutex> lk(this->disconnection_mx_);
+ this->disconnected_callback_ = std::move(callback);
+}
+
+void EvHTTPRequest::ExecuteDisconnectionCallback() {
+ std::unique_lock<std::mutex> lk(this->disconnection_mx_);
+ this->is_disconnected_ = true;
+ if (this->disconnected_callback_)
+ this->disconnected_callback_();
+}
+
absl::string_view EvHTTPRequest::uri_path() const {
return parsed_request_->path_and_query;
}
@@ -342,8 +375,14 @@ void EvHTTPRequest::PartialReplyWithStatus(HTTPStatusCode status) {
NET_LOG(FATAL, "PartialReplyWithStatus not implemented.");
}

-void EvHTTPRequest::PartialReply() {
- NET_LOG(FATAL, "PartialReplyWithStatus not implemented.");
+void EvHTTPRequest::PartialReply(std::string data) {
+ // TODO: Possibly avoid copy of data
+ bool result =
+ server_->EventLoopSchedule([this, data]() { EvPartialSendReply(data); });
+
+ server_->EventLoopSchedule([this, data = std::move(data)]() mutable { EvPartialSendReply(std::move(data)); });
+
+ if (!result) {
+ NET_LOG(ERROR, "Failed to EventLoopSchedule PartialReply()");
+ Abort();
+ // TODO(wenboz): should have a forced abort that doesn't write back anything
+ // to the event-loop
+ }
}

ServerRequestInterface::CallbackStatus
@@ -371,6 +380,25 @@ void EvHTTPRequest::EvSendReply(HTTPStatusCode status) {
@@ -371,6 +410,33 @@ void EvHTTPRequest::EvSendReply(HTTPStatusCode status) {
delete this;
}

+void EvHTTPRequest::EvPartialSendReply(std::string data) {
+ std::unique_lock<std::mutex> lk(this->disconnection_mx_);
+ if (this->is_disconnected_)
+ return;
+ if (!this->is_reply_started_) {
+ evhttp_send_reply_start(parsed_request_->request, HTTP_OK, "reply start");
+ this->is_reply_started_ = true;
Expand All @@ -36,19 +103,24 @@ index c8d0501b..f24a21d5 100644
+}
+
+void EvHTTPRequest::EvPartialReplyEnd() {
+ if (!this->is_reply_started_) {
+ // Start before we end can end the reply
+ evhttp_send_reply_start(parsed_request_->request, HTTP_OK, "no messages");
+ std::unique_lock<std::mutex> lk(this->disconnection_mx_);
+ if (!this->is_disconnected_) {
+ if (!this->is_reply_started_) {
+ // The reply has to be started before it can be ended
+ evhttp_send_reply_start(parsed_request_->request, HTTP_OK, "no messages");
+ }
+
+ evhttp_send_reply_end(parsed_request_->request);
+ }
+ evhttp_send_reply_end(parsed_request_->request);
+
+ server_->DecOps();
+ delete this;
+}
+
void EvHTTPRequest::Reply() { ReplyWithStatus(HTTPStatusCode::OK); }

// Treats this as 500 for now and let libevent decide what to do
@@ -381,6 +409,18 @@ void EvHTTPRequest::Abort() {
@@ -381,6 +447,16 @@ void EvHTTPRequest::Abort() {
delete this;
}

Expand All @@ -59,19 +131,25 @@ index c8d0501b..f24a21d5 100644
+ if (!result) {
+ NET_LOG(ERROR, "Failed to EventLoopSchedule PartialReplyEnd()");
+ Abort();
+ // TODO(wenboz): should have a forced abort that doesn't write back anything
+ // to the event-loop
+ }
+}
+
} // namespace net_http
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow_serving/util/net_http/server/internal/evhttp_request.h b/tensorflow_serving/util/net_http/server/internal/evhttp_request.h
index 2f8e601d..ff51c570 100644
index 2f8e601d..562b38b8 100644
--- a/tensorflow_serving/util/net_http/server/internal/evhttp_request.h
+++ b/tensorflow_serving/util/net_http/server/internal/evhttp_request.h
@@ -94,7 +94,7 @@ class EvHTTPRequest final : public ServerRequestInterface {
@@ -21,6 +21,7 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <string>
+#include <mutex>

#include "tensorflow_serving/util/net_http/server/internal/server_support.h"
#include "tensorflow_serving/util/net_http/server/public/httpserver_interface.h"
@@ -94,7 +95,7 @@ class EvHTTPRequest final : public ServerRequestInterface {
absl::string_view value) override;

void PartialReplyWithStatus(HTTPStatusCode status) override;
Expand All @@ -80,7 +158,7 @@ index 2f8e601d..ff51c570 100644

CallbackStatus PartialReplyWithFlushCallback(
std::function<void()> callback) override;
@@ -104,6 +104,8 @@ class EvHTTPRequest final : public ServerRequestInterface {
@@ -104,16 +105,26 @@ class EvHTTPRequest final : public ServerRequestInterface {

void Abort() override;

Expand All @@ -89,26 +167,39 @@ index 2f8e601d..ff51c570 100644
// Initializes the resource and returns false if any error.
bool Initialize();

@@ -114,6 +116,8 @@ class EvHTTPRequest final : public ServerRequestInterface {
+ bool IsDisconnected() const override { return this->is_disconnected_; }
+ void RegisterDisconnectionCallback(std::function<void()> callback) override;
+ void ExecuteDisconnectionCallback();
+
// Keeps a reference to the registered RequestHandlerOptions
void SetHandlerOptions(const RequestHandlerOptions& handler_options) {
this->handler_options_ = &handler_options;
}

+ int64_t Id() { return (int64_t)((int*)this); }
+
private:
void EvSendReply(HTTPStatusCode status);
+ void EvPartialSendReply(std::string data);
+ void EvPartialReplyEnd();

// Returns true if the data needs be uncompressed
bool NeedUncompressGzipContent();
@@ -133,6 +137,8 @@ class EvHTTPRequest final : public ServerRequestInterface {
@@ -133,6 +144,12 @@ class EvHTTPRequest final : public ServerRequestInterface {
std::unique_ptr<ParsedEvRequest> parsed_request_;

evbuffer* output_buf; // owned by this
+
+ bool is_reply_started_{false};
+ bool is_disconnected_{false};
+
+ std::function<void()> disconnected_callback_;
+ std::mutex disconnection_mx_;
};

} // namespace net_http
diff --git a/tensorflow_serving/util/net_http/server/public/server_request_interface.h b/tensorflow_serving/util/net_http/server/public/server_request_interface.h
index e5f4b05f..7077a6c1 100644
index e5f4b05f..3205f0ab 100644
--- a/tensorflow_serving/util/net_http/server/public/server_request_interface.h
+++ b/tensorflow_serving/util/net_http/server/public/server_request_interface.h
@@ -144,7 +144,7 @@ class ServerRequestInterface {
Expand All @@ -120,11 +211,15 @@ index e5f4b05f..7077a6c1 100644

// Similar to PartialReply() but with an on_flush callback which will be
// invoked when the response data has been completely flushed by the
@@ -182,6 +182,8 @@ class ServerRequestInterface {
@@ -182,6 +182,12 @@ class ServerRequestInterface {
// by the server runtime.
virtual void Abort() = 0;

+ virtual void PartialReplyEnd() = 0;
+
+ // Helpers for handling disconnection states
+ virtual bool IsDisconnected() const = 0;
+ virtual void RegisterDisconnectionCallback(std::function<void()> callback) = 0;
+
protected:
ServerRequestInterface() = default;
Expand Down
14 changes: 14 additions & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ cc_library(
"@mediapipe//mediapipe/calculators/geti/utils:emptylabel_calculators",
"@mediapipe//mediapipe/calculators/geti/serialization:calculators",
"//src/llm:httppayload",
"//src:libhttpclientconnection",
],
"//:disable_mediapipe" : [],
}),
Expand Down Expand Up @@ -445,6 +446,19 @@ cc_library(
alwayslink = 1,
)

# Header-only target (no srcs) exposing the abstract ClientConnection
# interface (client_connection.hpp) together with its HTTP implementation
# (http_frontend/http_client_connection.hpp). The HTTP implementation wraps
# the net_http ServerRequestInterface, hence the dependency on the
# tensorflow_serving http_server target.
cc_library(
name = "libhttpclientconnection",
hdrs = ["client_connection.hpp", "http_frontend/http_client_connection.hpp"],
deps = [
"@tensorflow_serving//tensorflow_serving/util/net_http/server/public:http_server",
],
visibility = ["//visibility:public",],
local_defines = COMMON_LOCAL_DEFINES,
copts = COPTS_ADJUSTED,
linkopts = LINKOPTS_ADJUSTED,
alwayslink = 1,
)

cc_library(
name = "libovmsprecision",
hdrs = ["precision.hpp"],
Expand Down
27 changes: 27 additions & 0 deletions src/client_connection.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//*****************************************************************************
// Copyright 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <functional>

namespace ovms {

// Abstract view of the client side of a request. It lets request-processing
// code find out whether the client has gone away and register a hook to run
// when that happens, without depending on a particular frontend.
class ClientConnection {
public:
    // Implementations are handled through ClientConnection pointers, so the
    // destructor must be virtual: destroying a derived object through a base
    // pointer without it is undefined behaviour.
    virtual ~ClientConnection() = default;

    // Returns true when the implementation considers the client disconnected.
    virtual bool isDisconnected() const = 0;

    // Hands over a callback for the implementation to run on client
    // disconnection. When and on which thread it is invoked is defined by the
    // concrete implementation.
    virtual void registerDisconnectionCallback(std::function<void()> fn) = 0;
};

} // namespace ovms
46 changes: 46 additions & 0 deletions src/http_frontend/http_client_connection.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//*****************************************************************************
// Copyright 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <utility>

#include "../client_connection.hpp"

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wall"
#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h"
#pragma GCC diagnostic pop

namespace ovms {

class HttpClientConnection : public ClientConnection {
tensorflow::serving::net_http::ServerRequestInterface* serverReaderWriter;

public:
HttpClientConnection(tensorflow::serving::net_http::ServerRequestInterface* serverReaderWriter) :
serverReaderWriter(serverReaderWriter) {}

bool isDisconnected() const override {
return this->serverReaderWriter->IsDisconnected();
}

void registerDisconnectionCallback(std::function<void()> fn) override {
serverReaderWriter->RegisterDisconnectionCallback(std::move(fn));
}
};

} // namespace ovms
2 changes: 2 additions & 0 deletions src/http_rest_api_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include "timer.hpp"

#if (MEDIAPIPE_DISABLE == 0)
#include "http_frontend/http_client_connection.hpp"
#include "http_frontend/http_graph_executor_impl.hpp"
#include "mediapipe_internal/mediapipegraphexecutor.hpp"
#endif
Expand Down Expand Up @@ -486,6 +487,7 @@ Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpReque
request.body = request_body;
request.parsedJson = &doc;
request.uri = std::string(uri);
request.client = std::make_shared<HttpClientConnection>(serverReaderWriter);
}
if (streamFieldVal == false) {
ServableMetricReporter* smr = nullptr; // Unused
Expand Down
1 change: 1 addition & 0 deletions src/llm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ cc_library(
hdrs = ["http_payload.hpp"],
deps = [
"@com_github_tencent_rapidjson//:rapidjson",
"//src:libhttpclientconnection"
],
visibility = ["//visibility:public"],
copts = COPTS_ADJUSTED,
Expand Down
Loading

0 comments on commit ee7dda2

Please sign in to comment.