Skip to content

Commit

Permalink
Merge pull request #16 from DataDog/jan/HIS-5143-enqueue-wakers
Browse files Browse the repository at this point in the history
[HIS-5143] Enqueue wakers
  • Loading branch information
gesundkrank authored Nov 19, 2024
2 parents cdb50e0 + a0a1478 commit bd70a6e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
5 changes: 2 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ log = "0.4.8"
serde = { version = "1.0.0", features = ["derive"] }
serde_derive = "1.0.0"
serde_json = "1.0.0"
slab = "0.4"
tokio = { version = "1.18", features = ["rt", "time"], optional = true }
tracing = { version = "0.1.30", optional = true }

Expand Down
91 changes: 52 additions & 39 deletions src/consumer/stream_consumer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
//! High-level consumers with a [`Stream`](futures_util::Stream) interface.
use std::collections::VecDeque;
use std::marker::PhantomData;
use std::os::raw::c_void;
use std::pin::Pin;
use std::ptr;
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
use std::time::Duration;

Expand All @@ -13,11 +15,9 @@ use futures_channel::oneshot;
use futures_util::future::{self, Either, FutureExt};
use futures_util::pin_mut;
use futures_util::stream::{Stream, StreamExt};
use slab::Slab;

use rdkafka_sys as rdsys;
use rdkafka_sys::types::*;

use crate::client::{Client, EventPollResult, NativeQueue};
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
use crate::consumer::base_consumer::{BaseConsumer, PartitionQueue};
Expand All @@ -33,11 +33,11 @@ use crate::topic_partition_list::{Offset, TopicPartitionList};
use crate::util::{AsyncRuntime, DefaultRuntime, Timeout};

unsafe extern "C" fn native_message_queue_nonempty_cb(_: *mut RDKafka, opaque_ptr: *mut c_void) {
let wakers = &*(opaque_ptr as *const WakerSlab);
let wakers = &*(opaque_ptr as *const WakerQueue);
wakers.wake_all();
}

unsafe fn enable_nonempty_callback(queue: &NativeQueue, wakers: &Arc<WakerSlab>) {
unsafe fn enable_nonempty_callback(queue: &NativeQueue, wakers: &Arc<WakerQueue>) {
rdsys::rd_kafka_queue_cb_event_enable(
queue.ptr(),
Some(native_message_queue_nonempty_cb),
Expand All @@ -49,76 +49,89 @@ unsafe fn disable_nonempty_callback(queue: &NativeQueue) {
rdsys::rd_kafka_queue_cb_event_enable(queue.ptr(), None, ptr::null_mut())
}

struct WakerSlab {
wakers: Mutex<Slab<Option<Waker>>>,
/// Data structure to store waker instances from idle streams.
/// This queue wakes up the wakers in FIFO order.
/// On registration streams receive an id they can reuse for setting and removing wakers.
///
/// Implementation is optimized for adding new wakers and waking all of them.
/// Removing a single waker is less efficient and has a complexity of O(len(waker_queue)).
struct WakerQueue {
/// Queue of the currently set wakers, each paired with the id of its owning stream
waker_queue: Mutex<VecDeque<(usize, Waker)>>,
/// Counts the number of registered clients. This is used to generate ids on registration
num_registered: AtomicUsize,
}

impl WakerSlab {
fn new() -> WakerSlab {
WakerSlab {
wakers: Mutex::new(Slab::new()),
impl WakerQueue {
fn new() -> WakerQueue {
WakerQueue {
waker_queue: Mutex::new(VecDeque::new()),
num_registered: AtomicUsize::new(0),
}
}

/// Wakes all set wakers in order of their arrival, draining the queue
fn wake_all(&self) {
let mut wakers = self.wakers.lock().unwrap();
for (_, waker) in wakers.iter_mut() {
if let Some(waker) = waker.take() {
waker.wake();
}
let mut waker_queue = self.waker_queue.lock().expect("lock poisoned");
while let Some((_, waker)) = waker_queue.pop_front() {
waker.wake();
}
}

/// Returns a unique id that can be reused by the caller to set and unset its waker
fn register(&self) -> usize {
let mut wakers = self.wakers.lock().expect("lock poisoned");
wakers.insert(None)
self.num_registered.fetch_add(1, Ordering::Relaxed)
}

fn unregister(&self, slot: usize) {
let mut wakers = self.wakers.lock().expect("lock poisoned");
wakers.remove(slot);
/// Removes the waker with the given id from the queue
fn unset_waker(&self, id: usize) {
let mut waker_queue = self.waker_queue.lock().expect("lock poisoned");
if let Some(index) = waker_queue.iter().position(|(i, _)| *i == id) {
waker_queue.remove(index);
}
}

fn set_waker(&self, slot: usize, waker: Waker) {
let mut wakers = self.wakers.lock().expect("lock poisoned");
wakers[slot] = Some(waker);
/// Adds the given waker, tagged with the given id, to the end of the queue
fn set_waker(&self, id: usize, waker: Waker) {
let mut wakers = self.waker_queue.lock().expect("lock poisoned");
wakers.push_back((id, waker));
}
}

/// A stream of messages from a [`StreamConsumer`].
///
/// See the documentation of [`StreamConsumer::stream`] for details.
pub struct MessageStream<'a, C: ConsumerContext> {
wakers: &'a WakerSlab,
wakers: &'a WakerQueue,
consumer: &'a BaseConsumer<C>,
partition_queue: Option<&'a NativeQueue>,
slot: usize,
waker_id: usize,
}

impl<'a, C: ConsumerContext> MessageStream<'a, C> {
fn new(wakers: &'a WakerSlab, consumer: &'a BaseConsumer<C>) -> MessageStream<'a, C> {
fn new(wakers: &'a WakerQueue, consumer: &'a BaseConsumer<C>) -> MessageStream<'a, C> {
Self::new_with_optional_partition_queue(wakers, consumer, None)
}

fn new_with_partition_queue(
wakers: &'a WakerSlab,
wakers: &'a WakerQueue,
consumer: &'a BaseConsumer<C>,
partition_queue: &'a NativeQueue,
) -> MessageStream<'a, C> {
Self::new_with_optional_partition_queue(wakers, consumer, Some(partition_queue))
}

fn new_with_optional_partition_queue(
wakers: &'a WakerSlab,
wakers: &'a WakerQueue,
consumer: &'a BaseConsumer<C>,
partition_queue: Option<&'a NativeQueue>,
) -> MessageStream<'a, C> {
let slot = wakers.register();
let waker_id = wakers.register();
MessageStream {
wakers,
consumer,
partition_queue,
slot,
waker_id,
}
}

Expand Down Expand Up @@ -150,9 +163,9 @@ impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> {
EventPollResult::None => {
// Otherwise, we need to wait for a message to become available. Store
// the waker so that we are woken up if the queue flips from empty
// to non-empty. We have to store the waker repatedly in case this future
// to non-empty. We have to store the waker repeatedly in case this future
// migrates between tasks.
self.wakers.set_waker(self.slot, cx.waker().clone());
self.wakers.set_waker(self.waker_id, cx.waker().clone());

// Check whether a new message became available after we installed the
// waker. This avoids a race where `poll` returns None to indicate that
Expand All @@ -174,7 +187,7 @@ impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> {

impl<'a, C: ConsumerContext> Drop for MessageStream<'a, C> {
fn drop(&mut self) {
self.wakers.unregister(self.slot);
self.wakers.unset_waker(self.waker_id);
}
}

Expand All @@ -198,7 +211,7 @@ where
C: ConsumerContext,
{
base: Arc<BaseConsumer<C>>,
wakers: Arc<WakerSlab>,
wakers: Arc<WakerQueue>,
_shutdown_trigger: oneshot::Sender<()>,
_runtime: PhantomData<R>,
}
Expand Down Expand Up @@ -231,7 +244,7 @@ where
let base = Arc::new(BaseConsumer::new(config, native_config, context)?);
let native_ptr = base.client().native_ptr() as usize;

let wakers = Arc::new(WakerSlab::new());
let wakers = Arc::new(WakerQueue::new());
unsafe { enable_nonempty_callback(base.get_queue(), &wakers) }

// We need to make sure we poll the consumer at least once every max
Expand Down Expand Up @@ -359,11 +372,11 @@ where
self.base
.split_partition_queue(topic, partition)
.map(|queue| {
let wakers = Arc::new(WakerSlab::new());
let wakers = Arc::new(WakerQueue::new());
unsafe { enable_nonempty_callback(&queue.queue, &wakers) };
StreamPartitionQueue {
queue,
wakers,
waker_queue: wakers,
_consumer: self.clone(),
}
})
Expand Down Expand Up @@ -562,7 +575,7 @@ where
C: ConsumerContext,
{
queue: PartitionQueue<C>,
wakers: Arc<WakerSlab>,
waker_queue: Arc<WakerQueue>,
_consumer: Arc<StreamConsumer<C, R>>,
}

Expand All @@ -584,7 +597,7 @@ where
/// multiple consumers, not multiple partition streams.
pub fn stream(&self) -> MessageStream<'_, C> {
MessageStream::new_with_partition_queue(
&self.wakers,
&self.waker_queue,
&self._consumer.base,
&self.queue.queue,
)
Expand Down

0 comments on commit bd70a6e

Please sign in to comment.