Skip to content

Commit

Permalink
fix(futures-bounded): register replaced Streams/Futures as ready
Browse files Browse the repository at this point in the history
Currently, when a `Stream` or `Future` is replaced with a new one, it might happen that we miss a task wake-up and thus the task polling `FuturesMap` or `StreamMap` is never called again. This can be fixed by first removing the old `Stream`/`Future` and properly adding a new one via `.push`. The inner `SelectAll` calls a waker in that case which allows the outer task to continue.

Pull-Request: #4865.
  • Loading branch information
thomaseizinger authored Nov 14, 2023
1 parent b6eb2bf commit e3db5a5
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 62 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ rust-version = "1.73.0"

[workspace.dependencies]
asynchronous-codec = { version = "0.7.0" }
futures-bounded = { version = "0.2.1", path = "misc/futures-bounded" }
futures-bounded = { version = "0.2.2", path = "misc/futures-bounded" }
libp2p = { version = "0.53.0", path = "libp2p" }
libp2p-allow-block-list = { version = "0.3.0", path = "misc/allow-block-list" }
libp2p-autonat = { version = "0.12.0", path = "protocols/autonat" }
Expand Down
5 changes: 5 additions & 0 deletions misc/futures-bounded/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.2.2

- Fix an issue where `{Futures,Stream}Map` returns `Poll::Pending` despite being ready after an item has been replaced as part of `try_push`.
See [PR 4865](https://github.com/libp2p/rust-lib2pp/pulls/4865).

## 0.2.1

- Add `.len()` getter to `FuturesMap`, `FuturesSet`, `StreamMap` and `StreamSet`.
Expand Down
5 changes: 3 additions & 2 deletions misc/futures-bounded/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "futures-bounded"
version = "0.2.1"
version = "0.2.2"
edition = "2021"
rust-version.workspace = true
license = "MIT"
Expand All @@ -17,7 +17,8 @@ futures-util = { version = "0.3.29" }
futures-timer = "3.0.2"

[dev-dependencies]
tokio = { version = "1.34.0", features = ["macros", "rt"] }
tokio = { version = "1.34.0", features = ["macros", "rt", "sync"] }
futures = "0.3.28"

[lints]
workspace = true
122 changes: 89 additions & 33 deletions misc/futures-bounded/src/futures_map.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::future::Future;
use std::hash::Hash;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use std::{future, mem};

use futures_timer::Delay;
use futures_util::future::BoxFuture;
Expand Down Expand Up @@ -38,6 +38,7 @@ impl<ID, O> FuturesMap<ID, O> {
impl<ID, O> FuturesMap<ID, O>
where
ID: Clone + Hash + Eq + Send + Unpin + 'static,
O: 'static,
{
/// Push a future into the map.
///
Expand All @@ -58,32 +59,30 @@ where
waker.wake();
}

match self.inner.iter_mut().find(|tagged| tagged.tag == future_id) {
None => {
self.inner.push(TaggedFuture {
tag: future_id,
inner: TimeoutFuture {
inner: future.boxed(),
timeout: Delay::new(self.timeout),
},
});

Ok(())
}
Some(existing) => {
let old_future = mem::replace(
&mut existing.inner,
TimeoutFuture {
inner: future.boxed(),
timeout: Delay::new(self.timeout),
},
);

Err(PushError::Replaced(old_future.inner))
}
let old = self.remove(future_id.clone());
self.inner.push(TaggedFuture {
tag: future_id,
inner: TimeoutFuture {
inner: future.boxed(),
timeout: Delay::new(self.timeout),
cancelled: false,
},
});
match old {
None => Ok(()),
Some(old) => Err(PushError::Replaced(old)),
}
}

pub fn remove(&mut self, id: ID) -> Option<BoxFuture<'static, O>> {
let tagged = self.inner.iter_mut().find(|s| s.tag == id)?;

let inner = mem::replace(&mut tagged.inner.inner, future::pending().boxed());
tagged.inner.cancelled = true;

Some(inner)
}

pub fn len(&self) -> usize {
self.inner.len()
}
Expand All @@ -104,39 +103,55 @@ where
}

pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
loop {
let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));

match maybe_result {
None => {
self.empty_waker = Some(cx.waker().clone());
Poll::Pending
match maybe_result {
None => {
self.empty_waker = Some(cx.waker().clone());
return Poll::Pending;
}
Some((id, Ok(output))) => return Poll::Ready((id, Ok(output))),
Some((id, Err(TimeoutError::Timeout))) => {
return Poll::Ready((id, Err(Timeout::new(self.timeout))))
}
Some((_, Err(TimeoutError::Cancelled))) => continue,
}
Some((id, Ok(output))) => Poll::Ready((id, Ok(output))),
Some((id, Err(_timeout))) => Poll::Ready((id, Err(Timeout::new(self.timeout)))),
}
}
}

struct TimeoutFuture<F> {
inner: F,
timeout: Delay,

cancelled: bool,
}

impl<F> Future for TimeoutFuture<F>
where
F: Future + Unpin,
{
type Output = Result<F::Output, ()>;
type Output = Result<F::Output, TimeoutError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.cancelled {
return Poll::Ready(Err(TimeoutError::Cancelled));
}

if self.timeout.poll_unpin(cx).is_ready() {
return Poll::Ready(Err(()));
return Poll::Ready(Err(TimeoutError::Timeout));
}

self.inner.poll_unpin(cx).map(Ok)
}
}

enum TimeoutError {
Timeout,
Cancelled,
}

struct TaggedFuture<T, F> {
tag: T,
inner: F,
Expand All @@ -158,6 +173,8 @@ where

#[cfg(test)]
mod tests {
use futures::channel::oneshot;
use futures_util::task::noop_waker_ref;
use std::future::{pending, poll_fn, ready};
use std::pin::Pin;
use std::time::Instant;
Expand Down Expand Up @@ -197,6 +214,45 @@ mod tests {
assert!(result.is_err())
}

#[test]
fn resources_of_removed_future_are_cleaned_up() {
let mut futures = FuturesMap::new(Duration::from_millis(100), 1);

let _ = futures.try_push("ID", pending::<()>());
futures.remove("ID");

let poll = futures.poll_unpin(&mut Context::from_waker(noop_waker_ref()));
assert!(poll.is_pending());

assert_eq!(futures.len(), 0);
}

#[tokio::test]
async fn replaced_pending_future_is_polled() {
let mut streams = FuturesMap::new(Duration::from_millis(100), 3);

let (_tx1, rx1) = oneshot::channel();
let (tx2, rx2) = oneshot::channel();

let _ = streams.try_push("ID1", rx1);
let _ = streams.try_push("ID2", rx2);

let _ = tx2.send(2);
let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;
assert_eq!(id, "ID2");
assert_eq!(res.unwrap().unwrap(), 2);

let (new_tx1, new_rx1) = oneshot::channel();
let replaced = streams.try_push("ID1", new_rx1);
assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

let _ = new_tx1.send(4);
let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;

assert_eq!(id, "ID1");
assert_eq!(res.unwrap().unwrap(), 4);
}

// Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
// We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES.
#[tokio::test]
Expand Down
5 changes: 4 additions & 1 deletion misc/futures-bounded/src/futures_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ impl<O> FuturesSet<O> {
}
}

impl<O> FuturesSet<O> {
impl<O> FuturesSet<O>
where
O: 'static,
{
/// Push a future into the list.
///
/// This method adds the given future to the list.
Expand Down
73 changes: 49 additions & 24 deletions misc/futures-bounded/src/stream_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,22 @@ where
waker.wake();
}

match self.inner.iter_mut().find(|tagged| tagged.key == id) {
None => {
self.inner.push(TaggedStream::new(
id,
TimeoutStream {
inner: stream.boxed(),
timeout: Delay::new(self.timeout),
},
));

Ok(())
}
Some(existing) => {
let old = mem::replace(
&mut existing.inner,
TimeoutStream {
inner: stream.boxed(),
timeout: Delay::new(self.timeout),
},
);

Err(PushError::Replaced(old.inner))
}
let old = self.remove(id.clone());
self.inner.push(TaggedStream::new(
id,
TimeoutStream {
inner: stream.boxed(),
timeout: Delay::new(self.timeout),
},
));

match old {
None => Ok(()),
Some(old) => Err(PushError::Replaced(old)),
}
}

pub fn remove(&mut self, id: ID) -> Option<BoxStream<O>> {
pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
let tagged = self.inner.iter_mut().find(|s| s.key == id)?;

let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
Expand Down Expand Up @@ -189,7 +178,9 @@ where

#[cfg(test)]
mod tests {
use futures::channel::mpsc;
use futures_util::stream::{once, pending};
use futures_util::SinkExt;
use std::future::{poll_fn, ready, Future};
use std::pin::Pin;
use std::time::Instant;
Expand Down Expand Up @@ -266,6 +257,40 @@ mod tests {
);
}

#[tokio::test]
async fn replaced_stream_is_still_registered() {
let mut streams = StreamMap::new(Duration::from_millis(100), 3);

let (mut tx1, rx1) = mpsc::channel(5);
let (mut tx2, rx2) = mpsc::channel(5);

let _ = streams.try_push("ID1", rx1);
let _ = streams.try_push("ID2", rx2);

let _ = tx2.send(2).await;
let _ = tx1.send(1).await;
let _ = tx2.send(3).await;
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
assert_eq!(id, "ID1");
assert_eq!(res.unwrap().unwrap(), 1);
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
assert_eq!(id, "ID2");
assert_eq!(res.unwrap().unwrap(), 2);
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
assert_eq!(id, "ID2");
assert_eq!(res.unwrap().unwrap(), 3);

let (mut new_tx1, new_rx1) = mpsc::channel(5);
let replaced = streams.try_push("ID1", new_rx1);
assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

let _ = new_tx1.send(4).await;
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

assert_eq!(id, "ID1");
assert_eq!(res.unwrap().unwrap(), 4);
}

// Each stream emits 1 item with delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
// We stop after NUM_STREAMS tasks, meaning the overall execution must at least take DELAY * NUM_STREAMS.
#[tokio::test]
Expand Down

0 comments on commit e3db5a5

Please sign in to comment.