diff --git a/bindings/kotlin/ldk-node-android/lib/build.gradle.kts b/bindings/kotlin/ldk-node-android/lib/build.gradle.kts index 6c8d29327..5e6775cdc 100644 --- a/bindings/kotlin/ldk-node-android/lib/build.gradle.kts +++ b/bindings/kotlin/ldk-node-android/lib/build.gradle.kts @@ -43,6 +43,7 @@ android { dependencies { implementation("net.java.dev.jna:jna:5.12.0@aar") implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk7") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4") implementation("androidx.appcompat:appcompat:1.4.0") implementation("androidx.core:core-ktx:1.7.0") api("org.slf4j:slf4j-api:1.7.30") diff --git a/bindings/kotlin/ldk-node-jvm/lib/build.gradle.kts b/bindings/kotlin/ldk-node-jvm/lib/build.gradle.kts index 33953ef7b..5c9e6c47c 100644 --- a/bindings/kotlin/ldk-node-jvm/lib/build.gradle.kts +++ b/bindings/kotlin/ldk-node-jvm/lib/build.gradle.kts @@ -46,6 +46,7 @@ dependencies { // Use the Kotlin JDK 8 standard library. implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4") implementation("net.java.dev.jna:jna:5.12.0") } diff --git a/bindings/ldk_node.udl b/bindings/ldk_node.udl index c6c7f4f38..b5c459d67 100644 --- a/bindings/ldk_node.udl +++ b/bindings/ldk_node.udl @@ -43,6 +43,8 @@ interface LDKNode { void stop(); Event? next_event(); Event wait_next_event(); + [Async] + Event next_event_async(); void event_handled(); PublicKey node_id(); sequence? listening_addresses(); diff --git a/src/event.rs b/src/event.rs index 70b49610f..a22607f7f 100644 --- a/src/event.rs +++ b/src/event.rs @@ -26,6 +26,8 @@ use lightning::util::ser::{Readable, ReadableArgs, Writeable, Writer}; use bitcoin::blockdata::locktime::absolute::LockTime; use bitcoin::secp256k1::PublicKey; use bitcoin::OutPoint; +use core::future::Future; +use core::task::{Poll, Waker}; use rand::{thread_rng, Rng}; use std::collections::VecDeque; use std::ops::Deref; @@ -125,7 +127,8 @@ pub struct EventQueue where L::Target: Logger, { - queue: Mutex>, + queue: Arc>>, + waker: Arc>>, notifier: Condvar, kv_store: Arc, logger: L, @@ -136,9 +139,10 @@ where L::Target: Logger, { pub(crate) fn new(kv_store: Arc, logger: L) -> Self { - let queue: Mutex> = Mutex::new(VecDeque::new()); + let queue = Arc::new(Mutex::new(VecDeque::new())); + let waker = Arc::new(Mutex::new(None)); let notifier = Condvar::new(); - Self { queue, notifier, kv_store, logger } + Self { queue, waker, notifier, kv_store, logger } } pub(crate) fn add_event(&self, event: Event) -> Result<(), Error> { @@ -149,6 +153,10 @@ where } self.notifier.notify_one(); + + if let Some(waker) = self.waker.lock().unwrap().take() { + waker.wake(); + } Ok(()) } @@ -157,6 +165,10 @@ where locked_queue.front().map(|e| e.clone()) } + pub(crate) async fn next_event_async(&self) -> Event { + EventFuture { event_queue: Arc::clone(&self.queue), waker: Arc::clone(&self.waker) }.await + } + pub(crate) fn wait_next_event(&self) -> Event { let locked_queue = self.notifier.wait_while(self.queue.lock().unwrap(), |queue| queue.is_empty()).unwrap(); @@ -170,6 +182,10 @@ where self.persist_queue(&locked_queue)?; } self.notifier.notify_one(); + + if let Some(waker) = self.waker.lock().unwrap().take() { + waker.wake(); + } Ok(()) } @@ -207,9 +223,10 @@ where ) -> Result { let (kv_store, logger) = args; let read_queue: EventQueueDeserWrapper = Readable::read(reader)?; - let queue: Mutex> = Mutex::new(read_queue.0); + let queue = Arc::new(Mutex::new(read_queue.0)); + let waker = Arc::new(Mutex::new(None)); let notifier = Condvar::new(); - Ok(Self { queue, notifier, kv_store, logger }) + Ok(Self { queue, waker, notifier, kv_store, logger }) } } @@ -240,6 +257,26 @@ impl Writeable for EventQueueSerWrapper<'_> { } } +struct EventFuture { + event_queue: Arc>>, + waker: Arc>>, +} + +impl Future for EventFuture { + type Output = Event; + + fn poll( + self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>, + ) -> core::task::Poll { + if let Some(event) = self.event_queue.lock().unwrap().front() { + Poll::Ready(event.clone()) + } else { + *self.waker.lock().unwrap() = Some(cx.waker().clone()); + Poll::Pending + } + } +} + pub(crate) struct EventHandler where L::Target: Logger, @@ -796,12 +833,14 @@ where mod tests { use super::*; use lightning::util::test_utils::{TestLogger, TestStore}; + use std::sync::atomic::{AtomicU16, Ordering}; + use std::time::Duration; - #[test] - fn event_queue_persistence() { + #[tokio::test] + async fn event_queue_persistence() { let store = Arc::new(TestStore::new(false)); let logger = Arc::new(TestLogger::new()); - let event_queue = EventQueue::new(Arc::clone(&store), Arc::clone(&logger)); + let event_queue = Arc::new(EventQueue::new(Arc::clone(&store), Arc::clone(&logger))); assert_eq!(event_queue.next_event(), None); let expected_event = Event::ChannelReady { @@ -814,6 +853,7 @@ mod tests { // Check we get the expected event and that it is returned until we mark it handled. for _ in 0..5 { assert_eq!(event_queue.wait_next_event(), expected_event); + assert_eq!(event_queue.next_event_async().await, expected_event); assert_eq!(event_queue.next_event(), Some(expected_event.clone())); } @@ -832,4 +872,96 @@ mod tests { event_queue.event_handled().unwrap(); assert_eq!(event_queue.next_event(), None); } + + #[tokio::test] + async fn event_queue_concurrency() { + let store = Arc::new(TestStore::new(false)); + let logger = Arc::new(TestLogger::new()); + let event_queue = Arc::new(EventQueue::new(Arc::clone(&store), Arc::clone(&logger))); + assert_eq!(event_queue.next_event(), None); + + let expected_event = Event::ChannelReady { + channel_id: ChannelId([23u8; 32]), + user_channel_id: UserChannelId(2323), + counterparty_node_id: None, + }; + + // Check `next_event_async` won't return if the queue is empty and always rather timeout. + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(1)) => { + // Timeout + } + _ = event_queue.next_event_async() => { + panic!(); + } + } + + assert_eq!(event_queue.next_event(), None); + // Check we get the expected number of events when polling/enqueuing concurrently. + let enqueued_events = AtomicU16::new(0); + let received_events = AtomicU16::new(0); + let mut delayed_enqueue = false; + + for _ in 0..25 { + event_queue.add_event(expected_event.clone()).unwrap(); + enqueued_events.fetch_add(1, Ordering::SeqCst); + } + + loop { + tokio::select! { + _ = tokio::time::sleep(Duration::from_millis(10)), if !delayed_enqueue => { + event_queue.add_event(expected_event.clone()).unwrap(); + enqueued_events.fetch_add(1, Ordering::SeqCst); + delayed_enqueue = true; + } + e = event_queue.next_event_async() => { + assert_eq!(e, expected_event); + event_queue.event_handled().unwrap(); + received_events.fetch_add(1, Ordering::SeqCst); + + event_queue.add_event(expected_event.clone()).unwrap(); + enqueued_events.fetch_add(1, Ordering::SeqCst); + } + e = event_queue.next_event_async() => { + assert_eq!(e, expected_event); + event_queue.event_handled().unwrap(); + received_events.fetch_add(1, Ordering::SeqCst); + } + } + + if delayed_enqueue + && received_events.load(Ordering::SeqCst) == enqueued_events.load(Ordering::SeqCst) + { + break; + } + } + assert_eq!(event_queue.next_event(), None); + + // Check we operate correctly, even when mixing and matching blocking and async API calls. + let (tx, mut rx) = tokio::sync::watch::channel(()); + let thread_queue = Arc::clone(&event_queue); + let thread_event = expected_event.clone(); + std::thread::spawn(move || { + let e = thread_queue.wait_next_event(); + assert_eq!(e, thread_event); + thread_queue.event_handled().unwrap(); + tx.send(()).unwrap(); + }); + + let thread_queue = Arc::clone(&event_queue); + let thread_event = expected_event.clone(); + std::thread::spawn(move || { + // Sleep a bit before we enqueue the events everybody is waiting for. + std::thread::sleep(Duration::from_millis(20)); + thread_queue.add_event(thread_event.clone()).unwrap(); + thread_queue.add_event(thread_event.clone()).unwrap(); + }); + + let e = event_queue.next_event_async().await; + assert_eq!(e, expected_event.clone()); + event_queue.event_handled().unwrap(); + + rx.changed().await.unwrap(); + assert_eq!(event_queue.next_event(), None); + } } diff --git a/src/lib.rs b/src/lib.rs index 0e64e9fe3..6b5b66f69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -811,6 +811,15 @@ impl Node { self.event_queue.next_event() } + /// Returns the next event in the event queue. + /// + /// Will asynchronously poll the event queue until the next event is ready. + /// + /// **Note:** this will always return the same event until handling is confirmed via [`Node::event_handled`]. + pub async fn next_event_async(&self) -> Event { + self.event_queue.next_event_async().await + } + /// Returns the next event in the event queue. /// /// Will block the current thread until the next event is available.