diff --git a/src/cpp/include/openvino/genai/generation_handle.hpp b/src/cpp/include/openvino/genai/generation_handle.hpp index 46fc97c746..94be82fe9a 100644 --- a/src/cpp/include/openvino/genai/generation_handle.hpp +++ b/src/cpp/include/openvino/genai/generation_handle.hpp @@ -58,6 +58,8 @@ class GenerationStream; class OPENVINO_GENAI_EXPORTS GenerationHandleImpl { std::shared_ptr m_generation_stream; ov::genai::GenerationConfig m_sampling_params; + + bool is_dropped(); public: GenerationHandleImpl(std::shared_ptr generation_stream, const ov::genai::GenerationConfig& sampling_params) : @@ -74,6 +76,8 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl { bool can_read(); + void drop(); + GenerationOutputs back(); // Reads result of a generation for single iteration GenerationOutputs read(); @@ -81,5 +85,5 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl { std::vector read_all(); }; -using GenerationHandle = std::unique_ptr; +using GenerationHandle = std::shared_ptr; } diff --git a/src/cpp/src/continuous_batching_pipeline.cpp b/src/cpp/src/continuous_batching_pipeline.cpp index a2f639bf8c..13f84d9bee 100644 --- a/src/cpp/src/continuous_batching_pipeline.cpp +++ b/src/cpp/src/continuous_batching_pipeline.cpp @@ -60,6 +60,15 @@ class ContinuousBatchingPipeline::Impl { ChatHistory m_history; + void _notify_requests_dropped_by_handle() { + // Notify the last time by pushing empty output + // This causes read() to unblock by adding anything to the queue + for (SequenceGroup::Ptr& request : m_requests) { + if (request->handle_dropped()) + request->push_empty_outputs(); + } + } + void _free_non_running_requests() { std::vector::iterator requests_iterator = m_requests.begin(); while (requests_iterator != m_requests.end()) { @@ -136,7 +145,7 @@ class ContinuousBatchingPipeline::Impl { std::lock_guard lock{m_awaiting_requests_mutex}; m_awaiting_requests.push_back(sequence_group); } - return std::make_unique(sequence_group->get_generation_stream(), sampling_params); + return std::make_shared(sequence_group->get_generation_stream(), sampling_params); } GenerationHandle add_request(uint64_t request_id, const std::string& prompt, ov::genai::GenerationConfig sampling_params) { @@ -227,6 +236,15 @@ class ContinuousBatchingPipeline::Impl { timer.end(); } + // notify requests dropped by handle + + { + static ManualTimer timer("notify requests dropped by handle"); + timer.start(); + _notify_requests_dropped_by_handle(); + timer.end(); + } + // free non running requests for current step { diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index 26cc12604f..0bd96cc56a 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -9,7 +9,7 @@ using namespace ov::genai; GenerationHandleImpl::~GenerationHandleImpl() { - m_generation_stream->drop(); + drop(); } GenerationStatus GenerationHandleImpl::get_status() { @@ -17,14 +17,24 @@ GenerationStatus GenerationHandleImpl::get_status() { } bool GenerationHandleImpl::can_read() { - return m_generation_stream->can_read(); + return !is_dropped() && m_generation_stream->can_read(); +} + +bool GenerationHandleImpl::is_dropped() { + return get_status() == GenerationStatus::DROPPED_BY_HANDLE; +} + +void GenerationHandleImpl::drop() { + m_generation_stream->drop(); } std::unordered_map GenerationHandleImpl::back() { + OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped."); return m_generation_stream->back(); } std::unordered_map GenerationHandleImpl::read() { + OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped."); return m_generation_stream->read(); } @@ -41,6 +51,7 @@ void add_partial_result(std::unordered_map& partial_ } std::vector GenerationHandleImpl::read_all() { + OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped."); std::vector results; std::unordered_map partial_results; // We iterate until generation is running or there are tokens we haven't read yet diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 0a91a8d4a6..eba251528e 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -446,6 +446,10 @@ class SequenceGroup { return m_generation_stream->get_status() == GenerationStatus::DROPPED_BY_HANDLE; } + void push_empty_outputs() { + m_generation_stream->push({}); + } + void push_outputs() { GenerationOutputs outputs; for (auto& sequence: m_sequences) {