Skip to content

Commit

Permalink
Add an option to drop the request (#732)
Browse files Browse the repository at this point in the history
This makes it possible to drop a user request when the client has disconnected
(when used in OVMS).

OVMS commit using this:
openvinotoolkit/model_server#2610
  • Loading branch information
dkalinowski authored Aug 6, 2024
1 parent eb248db commit 0be2620
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 4 deletions.
6 changes: 5 additions & 1 deletion src/cpp/include/openvino/genai/generation_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class GenerationStream;
class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
std::shared_ptr<GenerationStream> m_generation_stream;
ov::genai::GenerationConfig m_sampling_params;

bool is_dropped();

public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
Expand All @@ -81,12 +83,14 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {

bool can_read();

void drop();

GenerationOutputs back();
// Reads result of a generation for single iteration
GenerationOutputs read();
// Reads all generated tokens for all sequences
std::vector<GenerationOutput> read_all();
};

// Handle type given to callers of the pipeline's add_request().
// It is a shared pointer (add_request() creates it with std::make_shared), so several
// owners can hold the same handle; only one alias declaration may exist, otherwise the
// two conflicting definitions of GenerationHandle fail to compile.
using GenerationHandle = std::shared_ptr<GenerationHandleImpl>;
}
20 changes: 19 additions & 1 deletion src/cpp/src/continuous_batching_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ class ContinuousBatchingPipeline::Impl {
ChatHistory m_history;


void _notify_requests_dropped_by_handle() {
// Notify the last time by pushing empty output
// This causes read() to unblock by adding anything to the queue
for (SequenceGroup::Ptr& request : m_requests) {
if (request->handle_dropped())
request->push_empty_outputs();
}
}

void _free_non_running_requests() {
std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
while (requests_iterator != m_requests.end()) {
Expand Down Expand Up @@ -136,7 +145,7 @@ class ContinuousBatchingPipeline::Impl {
std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
m_awaiting_requests.push_back(sequence_group);
}
return std::make_unique<GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
return std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
}

GenerationHandle add_request(uint64_t request_id, const std::string& prompt, ov::genai::GenerationConfig sampling_params) {
Expand Down Expand Up @@ -227,6 +236,15 @@ class ContinuousBatchingPipeline::Impl {
timer.end();
}

// notify requests dropped by handle

{
static ManualTimer timer("notify requests dropped by handle");
timer.start();
_notify_requests_dropped_by_handle();
timer.end();
}

// free non running requests for current step

{
Expand Down
15 changes: 13 additions & 2 deletions src/cpp/src/generation_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,32 @@
using namespace ov::genai;

// Destroying the handle drops the generation it refers to.
// drop() already forwards to the generation stream, so the stream must not be
// dropped a second time directly from here.
GenerationHandleImpl::~GenerationHandleImpl() {
    drop();
}

// Returns the generation status as reported by the underlying generation stream.
GenerationStatus GenerationHandleImpl::get_status() {
    const auto current_status = m_generation_stream->get_status();
    return current_status;
}

// Tells whether the handle can be read from.
// A dropped handle is never readable; otherwise the answer comes from the
// underlying generation stream. The dropped check is evaluated first, so the
// stream is not queried at all once the handle has been dropped.
bool GenerationHandleImpl::can_read() {
    return !is_dropped() && m_generation_stream->can_read();
}

// Returns true when the current generation status is DROPPED_BY_HANDLE.
bool GenerationHandleImpl::is_dropped() {
    const bool dropped = (GenerationStatus::DROPPED_BY_HANDLE == get_status());
    return dropped;
}

// Drops the request by forwarding the call to the underlying generation stream.
void GenerationHandleImpl::drop() {
    auto& stream = *m_generation_stream;
    stream.drop();
}

// Returns whatever the generation stream's back() yields, keyed by uint64_t id.
// Must not be called on a dropped handle: the assertion below fails in that case.
std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::back() {
    OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
    auto latest_outputs = m_generation_stream->back();
    return latest_outputs;
}

// Reads the result of a generation for a single iteration from the generation stream.
// Must not be called on a dropped handle: the assertion below fails in that case.
std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::read() {
    OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
    auto iteration_outputs = m_generation_stream->read();
    return iteration_outputs;
}

Expand All @@ -42,6 +52,7 @@ void add_partial_result(std::unordered_map<uint64_t, GenerationOutput>& partial_
}

std::vector<GenerationOutput> GenerationHandleImpl::read_all() {
OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
std::vector<GenerationOutput> results;
std::unordered_map<uint64_t, GenerationOutput> partial_results;
// We iterate until generation is running or there are tokens we haven't read yet
Expand Down
4 changes: 4 additions & 0 deletions src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,10 @@ class SequenceGroup {
return m_generation_stream->get_status() == GenerationStatus::DROPPED_BY_HANDLE;
}

// Pushes a default-constructed (empty) argument to the generation stream.
// The pipeline calls this for requests whose handle was dropped, so that something
// lands in the stream and a pending read() can unblock.
// NOTE(review): the exact type built from `{}` is the parameter type of
// GenerationStream::push, which is not visible here — presumably GenerationOutputs.
void push_empty_outputs() {
m_generation_stream->push({});
}

void push_outputs() {
GenerationOutputs outputs;
for (auto& sequence: m_sequences) {
Expand Down

0 comments on commit 0be2620

Please sign in to comment.