From e3db5a57813f6ba049ebd4eba26a3aaa2ff3568d Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 15 Nov 2023 09:57:43 +1100 Subject: [PATCH] fix(futures-bounded): register replaced `Stream`s/`Future`s as ready Currently, when a `Stream` or `Future` is replaced with a new one, it might happen that we miss a task wake-up and thus the task polling `FuturesMap` or `StreamMap` is never called again. This can be fixed by first removing the old `Stream`/`Future` and properly adding a new one via `.push`. The inner `SelectAll` calls a waker in that case which allows the outer task to continue. Pull-Request: #4865. --- Cargo.lock | 3 +- Cargo.toml | 2 +- misc/futures-bounded/CHANGELOG.md | 5 + misc/futures-bounded/Cargo.toml | 5 +- misc/futures-bounded/src/futures_map.rs | 122 +++++++++++++++++------- misc/futures-bounded/src/futures_set.rs | 5 +- misc/futures-bounded/src/stream_map.rs | 73 +++++++++----- 7 files changed, 153 insertions(+), 62 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8c252a1d8b8..0baa7d79bcc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1572,8 +1572,9 @@ dependencies = [ [[package]] name = "futures-bounded" -version = "0.2.1" +version = "0.2.2" dependencies = [ + "futures", "futures-timer", "futures-util", "tokio", diff --git a/Cargo.toml b/Cargo.toml index a79c55bbf91..793c46b0454 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ rust-version = "1.73.0" [workspace.dependencies] asynchronous-codec = { version = "0.7.0" } -futures-bounded = { version = "0.2.1", path = "misc/futures-bounded" } +futures-bounded = { version = "0.2.2", path = "misc/futures-bounded" } libp2p = { version = "0.53.0", path = "libp2p" } libp2p-allow-block-list = { version = "0.3.0", path = "misc/allow-block-list" } libp2p-autonat = { version = "0.12.0", path = "protocols/autonat" } diff --git a/misc/futures-bounded/CHANGELOG.md b/misc/futures-bounded/CHANGELOG.md index 6e3b720fe4c..3a26f6436ba 100644 --- a/misc/futures-bounded/CHANGELOG.md +++ b/misc/futures-bounded/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.2.2 + +- Fix an issue where `{Futures,Stream}Map` returns `Poll::Pending` despite being ready after an item has been replaced as part of `try_push`. + See [PR 4865](https://github.com/libp2p/rust-lib2pp/pulls/4865). + ## 0.2.1 - Add `.len()` getter to `FuturesMap`, `FuturesSet`, `StreamMap` and `StreamSet`. diff --git a/misc/futures-bounded/Cargo.toml b/misc/futures-bounded/Cargo.toml index b7c4086c87d..42743a8ac85 100644 --- a/misc/futures-bounded/Cargo.toml +++ b/misc/futures-bounded/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "futures-bounded" -version = "0.2.1" +version = "0.2.2" edition = "2021" rust-version.workspace = true license = "MIT" @@ -17,7 +17,8 @@ futures-util = { version = "0.3.29" } futures-timer = "3.0.2" [dev-dependencies] -tokio = { version = "1.34.0", features = ["macros", "rt"] } +tokio = { version = "1.34.0", features = ["macros", "rt", "sync"] } +futures = "0.3.28" [lints] workspace = true diff --git a/misc/futures-bounded/src/futures_map.rs b/misc/futures-bounded/src/futures_map.rs index 8e8802254bc..fba3543f67b 100644 --- a/misc/futures-bounded/src/futures_map.rs +++ b/misc/futures-bounded/src/futures_map.rs @@ -1,9 +1,9 @@ use std::future::Future; use std::hash::Hash; -use std::mem; use std::pin::Pin; use std::task::{Context, Poll, Waker}; use std::time::Duration; +use std::{future, mem}; use futures_timer::Delay; use futures_util::future::BoxFuture; @@ -38,6 +38,7 @@ impl FuturesMap { impl FuturesMap where ID: Clone + Hash + Eq + Send + Unpin + 'static, + O: 'static, { /// Push a future into the map. /// @@ -58,32 +59,30 @@ where waker.wake(); } - match self.inner.iter_mut().find(|tagged| tagged.tag == future_id) { - None => { - self.inner.push(TaggedFuture { - tag: future_id, - inner: TimeoutFuture { - inner: future.boxed(), - timeout: Delay::new(self.timeout), - }, - }); - - Ok(()) - } - Some(existing) => { - let old_future = mem::replace( - &mut existing.inner, - TimeoutFuture { - inner: future.boxed(), - timeout: Delay::new(self.timeout), - }, - ); - - Err(PushError::Replaced(old_future.inner)) - } + let old = self.remove(future_id.clone()); + self.inner.push(TaggedFuture { + tag: future_id, + inner: TimeoutFuture { + inner: future.boxed(), + timeout: Delay::new(self.timeout), + cancelled: false, + }, + }); + match old { + None => Ok(()), + Some(old) => Err(PushError::Replaced(old)), } } + pub fn remove(&mut self, id: ID) -> Option> { + let tagged = self.inner.iter_mut().find(|s| s.tag == id)?; + + let inner = mem::replace(&mut tagged.inner.inner, future::pending().boxed()); + tagged.inner.cancelled = true; + + Some(inner) + } + pub fn len(&self) -> usize { self.inner.len() } @@ -104,15 +103,20 @@ where } pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result)> { - let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx)); + loop { + let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx)); - match maybe_result { - None => { - self.empty_waker = Some(cx.waker().clone()); - Poll::Pending + match maybe_result { + None => { + self.empty_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + Some((id, Ok(output))) => return Poll::Ready((id, Ok(output))), + Some((id, Err(TimeoutError::Timeout))) => { + return Poll::Ready((id, Err(Timeout::new(self.timeout)))) + } + Some((_, Err(TimeoutError::Cancelled))) => continue, } - Some((id, Ok(output))) => Poll::Ready((id, Ok(output))), - Some((id, Err(_timeout))) => Poll::Ready((id, Err(Timeout::new(self.timeout)))), } } } @@ -120,23 +124,34 @@ where struct TimeoutFuture { inner: F, timeout: Delay, + + cancelled: bool, } impl Future for TimeoutFuture where F: Future + Unpin, { - type Output = Result; + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.cancelled { + return Poll::Ready(Err(TimeoutError::Cancelled)); + } + if self.timeout.poll_unpin(cx).is_ready() { - return Poll::Ready(Err(())); + return Poll::Ready(Err(TimeoutError::Timeout)); } self.inner.poll_unpin(cx).map(Ok) } } +enum TimeoutError { + Timeout, + Cancelled, +} + struct TaggedFuture { tag: T, inner: F, @@ -158,6 +173,8 @@ where #[cfg(test)] mod tests { + use futures::channel::oneshot; + use futures_util::task::noop_waker_ref; use std::future::{pending, poll_fn, ready}; use std::pin::Pin; use std::time::Instant; @@ -197,6 +214,45 @@ mod tests { assert!(result.is_err()) } + #[test] + fn resources_of_removed_future_are_cleaned_up() { + let mut futures = FuturesMap::new(Duration::from_millis(100), 1); + + let _ = futures.try_push("ID", pending::<()>()); + futures.remove("ID"); + + let poll = futures.poll_unpin(&mut Context::from_waker(noop_waker_ref())); + assert!(poll.is_pending()); + + assert_eq!(futures.len(), 0); + } + + #[tokio::test] + async fn replaced_pending_future_is_polled() { + let mut streams = FuturesMap::new(Duration::from_millis(100), 3); + + let (_tx1, rx1) = oneshot::channel(); + let (tx2, rx2) = oneshot::channel(); + + let _ = streams.try_push("ID1", rx1); + let _ = streams.try_push("ID2", rx2); + + let _ = tx2.send(2); + let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await; + assert_eq!(id, "ID2"); + assert_eq!(res.unwrap().unwrap(), 2); + + let (new_tx1, new_rx1) = oneshot::channel(); + let replaced = streams.try_push("ID1", new_rx1); + assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_))); + + let _ = new_tx1.send(4); + let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await; + + assert_eq!(id, "ID1"); + assert_eq!(res.unwrap().unwrap(), 4); + } + // Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence. // We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES. #[tokio::test] diff --git a/misc/futures-bounded/src/futures_set.rs b/misc/futures-bounded/src/futures_set.rs index ea8f700991d..af7cedfcc85 100644 --- a/misc/futures-bounded/src/futures_set.rs +++ b/misc/futures-bounded/src/futures_set.rs @@ -23,7 +23,10 @@ impl FuturesSet { } } -impl FuturesSet { +impl FuturesSet +where + O: 'static, +{ /// Push a future into the list. /// /// This method adds the given future to the list. diff --git a/misc/futures-bounded/src/stream_map.rs b/misc/futures-bounded/src/stream_map.rs index 40294ce0fba..8464f432d02 100644 --- a/misc/futures-bounded/src/stream_map.rs +++ b/misc/futures-bounded/src/stream_map.rs @@ -53,33 +53,22 @@ where waker.wake(); } - match self.inner.iter_mut().find(|tagged| tagged.key == id) { - None => { - self.inner.push(TaggedStream::new( - id, - TimeoutStream { - inner: stream.boxed(), - timeout: Delay::new(self.timeout), - }, - )); - - Ok(()) - } - Some(existing) => { - let old = mem::replace( - &mut existing.inner, - TimeoutStream { - inner: stream.boxed(), - timeout: Delay::new(self.timeout), - }, - ); - - Err(PushError::Replaced(old.inner)) - } + let old = self.remove(id.clone()); + self.inner.push(TaggedStream::new( + id, + TimeoutStream { + inner: stream.boxed(), + timeout: Delay::new(self.timeout), + }, + )); + + match old { + None => Ok(()), + Some(old) => Err(PushError::Replaced(old)), } } - pub fn remove(&mut self, id: ID) -> Option> { + pub fn remove(&mut self, id: ID) -> Option> { let tagged = self.inner.iter_mut().find(|s| s.key == id)?; let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed()); @@ -189,7 +178,9 @@ where #[cfg(test)] mod tests { + use futures::channel::mpsc; use futures_util::stream::{once, pending}; + use futures_util::SinkExt; use std::future::{poll_fn, ready, Future}; use std::pin::Pin; use std::time::Instant; @@ -266,6 +257,40 @@ mod tests { ); } + #[tokio::test] + async fn replaced_stream_is_still_registered() { + let mut streams = StreamMap::new(Duration::from_millis(100), 3); + + let (mut tx1, rx1) = mpsc::channel(5); + let (mut tx2, rx2) = mpsc::channel(5); + + let _ = streams.try_push("ID1", rx1); + let _ = streams.try_push("ID2", rx2); + + let _ = tx2.send(2).await; + let _ = tx1.send(1).await; + let _ = tx2.send(3).await; + let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await; + assert_eq!(id, "ID1"); + assert_eq!(res.unwrap().unwrap(), 1); + let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await; + assert_eq!(id, "ID2"); + assert_eq!(res.unwrap().unwrap(), 2); + let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await; + assert_eq!(id, "ID2"); + assert_eq!(res.unwrap().unwrap(), 3); + + let (mut new_tx1, new_rx1) = mpsc::channel(5); + let replaced = streams.try_push("ID1", new_rx1); + assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_))); + + let _ = new_tx1.send(4).await; + let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await; + + assert_eq!(id, "ID1"); + assert_eq!(res.unwrap().unwrap(), 4); + } + // Each stream emits 1 item with delay, `Task` only has a capacity of 1, meaning they must be processed in sequence. // We stop after NUM_STREAMS tasks, meaning the overall execution must at least take DELAY * NUM_STREAMS. #[tokio::test]