From 8b3d4e407d2b38ce5ea47a2192c75c9c95013fed Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 25 Oct 2023 11:06:05 +1100 Subject: [PATCH] feat(futures-bounded): add support for streams This is the next logical extension to the `futures-bounded` crate by adding support for streams. For the moment, this isn't used in `rust-libp2p` but given that it is a general-purpose crate, putting this code here makes sense. Related: https://github.com/firezone/firezone/pull/2279. Pull-Request: #4616. --- Cargo.lock | 2 +- Cargo.toml | 2 +- misc/futures-bounded/CHANGELOG.md | 5 + misc/futures-bounded/Cargo.toml | 2 +- .../src/{map.rs => futures_map.rs} | 17 +- .../src/{set.rs => futures_set.rs} | 2 +- misc/futures-bounded/src/lib.rs | 24 +- misc/futures-bounded/src/stream_map.rs | 333 ++++++++++++++++++ misc/futures-bounded/src/stream_set.rs | 60 ++++ protocols/relay/src/priv_client/handler.rs | 2 +- 10 files changed, 427 insertions(+), 22 deletions(-) rename misc/futures-bounded/src/{map.rs => futures_map.rs} (94%) rename misc/futures-bounded/src/{set.rs => futures_set.rs} (94%) create mode 100644 misc/futures-bounded/src/stream_map.rs create mode 100644 misc/futures-bounded/src/stream_set.rs diff --git a/Cargo.lock b/Cargo.lock index 9ef1c930dc5..d409a1edf2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1584,7 +1584,7 @@ dependencies = [ [[package]] name = "futures-bounded" -version = "0.1.0" +version = "0.2.0" dependencies = [ "futures-timer", "futures-util", diff --git a/Cargo.toml b/Cargo.toml index 347ccb74c39..2c6f214c5d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ resolver = "2" rust-version = "1.73.0" [workspace.dependencies] -futures-bounded = { version = "0.1.0", path = "misc/futures-bounded" } +futures-bounded = { version = "0.2.0", path = "misc/futures-bounded" } libp2p = { version = "0.53.0", path = "libp2p" } libp2p-allow-block-list = { version = "0.3.0", path = "misc/allow-block-list" } libp2p-autonat = { version = "0.12.0", path = "protocols/autonat" } diff --git a/misc/futures-bounded/CHANGELOG.md b/misc/futures-bounded/CHANGELOG.md index bd05a0f8261..90bd47f2f61 100644 --- a/misc/futures-bounded/CHANGELOG.md +++ b/misc/futures-bounded/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.2.0 + +- Add `StreamMap` type and remove `Future`-suffix from `PushError::ReplacedFuture` to reuse it for `StreamMap`. + See [PR 4616](https://github.com/libp2p/rust-lib2pp/pulls/4616). + ## 0.1.0 Initial release. diff --git a/misc/futures-bounded/Cargo.toml b/misc/futures-bounded/Cargo.toml index 52af0d228f2..9332667e476 100644 --- a/misc/futures-bounded/Cargo.toml +++ b/misc/futures-bounded/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "futures-bounded" -version = "0.1.0" +version = "0.2.0" edition = "2021" rust-version.workspace = true license = "MIT" diff --git a/misc/futures-bounded/src/map.rs b/misc/futures-bounded/src/futures_map.rs similarity index 94% rename from misc/futures-bounded/src/map.rs rename to misc/futures-bounded/src/futures_map.rs index cecf6070efe..5fd06037608 100644 --- a/misc/futures-bounded/src/map.rs +++ b/misc/futures-bounded/src/futures_map.rs @@ -10,7 +10,7 @@ use futures_util::future::BoxFuture; use futures_util::stream::FuturesUnordered; use futures_util::{FutureExt, StreamExt}; -use crate::Timeout; +use crate::{PushError, Timeout}; /// Represents a map of [`Future`]s. /// @@ -23,15 +23,6 @@ pub struct FuturesMap { full_waker: Option, } -/// Error of a future pushing -#[derive(PartialEq, Debug)] -pub enum PushError { - /// The length of the set is equal to the capacity - BeyondCapacity(F), - /// The set already contains the given future's ID - ReplacedFuture(F), -} - impl FuturesMap { pub fn new(timeout: Duration, capacity: usize) -> Self { Self { @@ -54,7 +45,7 @@ where /// If the length of the map is equal to the capacity, this method returns [PushError::BeyondCapacity], /// that contains the passed future. In that case, the future is not inserted to the map. /// If a future with the given `future_id` already exists, then the old future will be replaced by a new one. - /// In that case, the returned error [PushError::ReplacedFuture] contains the old future. + /// In that case, the returned error [PushError::Replaced] contains the old future. pub fn try_push(&mut self, future_id: ID, future: F) -> Result<(), PushError>> where F: Future + Send + 'static, @@ -88,7 +79,7 @@ where }, ); - Err(PushError::ReplacedFuture(old_future.inner)) + Err(PushError::Replaced(old_future.inner)) } } } @@ -187,7 +178,7 @@ mod tests { assert!(futures.try_push("ID", ready(())).is_ok()); matches!( futures.try_push("ID", ready(())), - Err(PushError::ReplacedFuture(_)) + Err(PushError::Replaced(_)) ); } diff --git a/misc/futures-bounded/src/set.rs b/misc/futures-bounded/src/futures_set.rs similarity index 94% rename from misc/futures-bounded/src/set.rs rename to misc/futures-bounded/src/futures_set.rs index 96140d82f9a..79a82fde110 100644 --- a/misc/futures-bounded/src/set.rs +++ b/misc/futures-bounded/src/futures_set.rs @@ -38,7 +38,7 @@ impl FuturesSet { match self.inner.try_push(self.id, future) { Ok(()) => Ok(()), Err(PushError::BeyondCapacity(w)) => Err(w), - Err(PushError::ReplacedFuture(_)) => unreachable!("we never reuse IDs"), + Err(PushError::Replaced(_)) => unreachable!("we never reuse IDs"), } } diff --git a/misc/futures-bounded/src/lib.rs b/misc/futures-bounded/src/lib.rs index e7b461dc822..6882a96f5e9 100644 --- a/misc/futures-bounded/src/lib.rs +++ b/misc/futures-bounded/src/lib.rs @@ -1,8 +1,13 @@ -mod map; -mod set; +mod futures_map; +mod futures_set; +mod stream_map; +mod stream_set; + +pub use futures_map::FuturesMap; +pub use futures_set::FuturesSet; +pub use stream_map::StreamMap; +pub use stream_set::StreamSet; -pub use map::{FuturesMap, PushError}; -pub use set::FuturesSet; use std::fmt; use std::fmt::Formatter; use std::time::Duration; @@ -25,4 +30,15 @@ impl fmt::Display for Timeout { } } +/// Error of a future pushing +#[derive(PartialEq, Debug)] +pub enum PushError { + /// The length of the set is equal to the capacity + BeyondCapacity(T), + /// The map already contained an item with this key. + /// + /// The old item is returned. + Replaced(T), +} + impl std::error::Error for Timeout {} diff --git a/misc/futures-bounded/src/stream_map.rs b/misc/futures-bounded/src/stream_map.rs new file mode 100644 index 00000000000..7fcdd15e132 --- /dev/null +++ b/misc/futures-bounded/src/stream_map.rs @@ -0,0 +1,333 @@ +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; +use std::time::Duration; + +use futures_timer::Delay; +use futures_util::stream::{BoxStream, SelectAll}; +use futures_util::{stream, FutureExt, Stream, StreamExt}; + +use crate::{PushError, Timeout}; + +/// Represents a map of [`Stream`]s. +/// +/// Each stream must finish within the specified time and the map never outgrows its capacity. +pub struct StreamMap { + timeout: Duration, + capacity: usize, + inner: SelectAll>>>, + empty_waker: Option, + full_waker: Option, +} + +impl StreamMap +where + ID: Clone + Unpin, +{ + pub fn new(timeout: Duration, capacity: usize) -> Self { + Self { + timeout, + capacity, + inner: Default::default(), + empty_waker: None, + full_waker: None, + } + } +} + +impl StreamMap +where + ID: Clone + PartialEq + Send + Unpin + 'static, + O: Send + 'static, +{ + /// Push a stream into the map. + pub fn try_push(&mut self, id: ID, stream: F) -> Result<(), PushError>> + where + F: Stream + Send + 'static, + { + if self.inner.len() >= self.capacity { + return Err(PushError::BeyondCapacity(stream.boxed())); + } + + if let Some(waker) = self.empty_waker.take() { + waker.wake(); + } + + match self.inner.iter_mut().find(|tagged| tagged.key == id) { + None => { + self.inner.push(TaggedStream::new( + id, + TimeoutStream { + inner: stream.boxed(), + timeout: Delay::new(self.timeout), + }, + )); + + Ok(()) + } + Some(existing) => { + let old = mem::replace( + &mut existing.inner, + TimeoutStream { + inner: stream.boxed(), + timeout: Delay::new(self.timeout), + }, + ); + + Err(PushError::Replaced(old.inner)) + } + } + } + + pub fn remove(&mut self, id: ID) -> Option> { + let tagged = self.inner.iter_mut().find(|s| s.key == id)?; + + let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed()); + tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources. + + Some(inner) + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic. + pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if self.inner.len() < self.capacity { + return Poll::Ready(()); + } + + self.full_waker = Some(cx.waker().clone()); + + Poll::Pending + } + + pub fn poll_next_unpin( + &mut self, + cx: &mut Context<'_>, + ) -> Poll<(ID, Option>)> { + match futures_util::ready!(self.inner.poll_next_unpin(cx)) { + None => { + self.empty_waker = Some(cx.waker().clone()); + Poll::Pending + } + Some((id, Some(Ok(output)))) => Poll::Ready((id, Some(Ok(output)))), + Some((id, Some(Err(())))) => { + self.remove(id.clone()); // Remove stream, otherwise we keep reporting the timeout. + + Poll::Ready((id, Some(Err(Timeout::new(self.timeout))))) + } + Some((id, None)) => Poll::Ready((id, None)), + } + } +} + +struct TimeoutStream { + inner: S, + timeout: Delay, +} + +impl Stream for TimeoutStream +where + F: Stream + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.timeout.poll_unpin(cx).is_ready() { + return Poll::Ready(Some(Err(()))); + } + + self.inner.poll_next_unpin(cx).map(|a| a.map(Ok)) + } +} + +struct TaggedStream { + key: K, + inner: S, + + exhausted: bool, +} + +impl TaggedStream { + fn new(key: K, inner: S) -> Self { + Self { + key, + inner, + exhausted: false, + } + } +} + +impl Stream for TaggedStream +where + K: Clone + Unpin, + S: Stream + Unpin, +{ + type Item = (K, Option); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.exhausted { + return Poll::Ready(None); + } + + match futures_util::ready!(self.inner.poll_next_unpin(cx)) { + Some(item) => Poll::Ready(Some((self.key.clone(), Some(item)))), + None => { + self.exhausted = true; + + Poll::Ready(Some((self.key.clone(), None))) + } + } + } +} + +#[cfg(test)] +mod tests { + use futures_util::stream::{once, pending}; + use std::future::{poll_fn, ready, Future}; + use std::pin::Pin; + use std::time::Instant; + + use super::*; + + #[test] + fn cannot_push_more_than_capacity_tasks() { + let mut streams = StreamMap::new(Duration::from_secs(10), 1); + + assert!(streams.try_push("ID_1", once(ready(()))).is_ok()); + matches!( + streams.try_push("ID_2", once(ready(()))), + Err(PushError::BeyondCapacity(_)) + ); + } + + #[test] + fn cannot_push_the_same_id_few_times() { + let mut streams = StreamMap::new(Duration::from_secs(10), 5); + + assert!(streams.try_push("ID", once(ready(()))).is_ok()); + matches!( + streams.try_push("ID", once(ready(()))), + Err(PushError::Replaced(_)) + ); + } + + #[tokio::test] + async fn streams_timeout() { + let mut streams = StreamMap::new(Duration::from_millis(100), 1); + + let _ = streams.try_push("ID", pending::<()>()); + Delay::new(Duration::from_millis(150)).await; + let (_, result) = poll_fn(|cx| streams.poll_next_unpin(cx)).await; + + assert!(result.unwrap().is_err()) + } + + #[tokio::test] + async fn timed_out_stream_gets_removed() { + let mut streams = StreamMap::new(Duration::from_millis(100), 1); + + let _ = streams.try_push("ID", pending::<()>()); + Delay::new(Duration::from_millis(150)).await; + poll_fn(|cx| streams.poll_next_unpin(cx)).await; + + let poll = streams.poll_next_unpin(&mut Context::from_waker( + futures_util::task::noop_waker_ref(), + )); + assert!(poll.is_pending()) + } + + #[test] + fn removing_stream() { + let mut streams = StreamMap::new(Duration::from_millis(100), 1); + + let _ = streams.try_push("ID", stream::once(ready(()))); + + { + let cancelled_stream = streams.remove("ID"); + assert!(cancelled_stream.is_some()); + } + + let poll = streams.poll_next_unpin(&mut Context::from_waker( + futures_util::task::noop_waker_ref(), + )); + + assert!(poll.is_pending()); + assert_eq!( + streams.inner.len(), + 0, + "resources of cancelled streams are cleaned up properly" + ); + } + + // Each stream emits 1 item with delay, `Task` only has a capacity of 1, meaning they must be processed in sequence. + // We stop after NUM_STREAMS tasks, meaning the overall execution must at least take DELAY * NUM_STREAMS. + #[tokio::test] + async fn backpressure() { + const DELAY: Duration = Duration::from_millis(100); + const NUM_STREAMS: u32 = 10; + + let start = Instant::now(); + Task::new(DELAY, NUM_STREAMS, 1).await; + let duration = start.elapsed(); + + assert!(duration >= DELAY * NUM_STREAMS); + } + + struct Task { + item_delay: Duration, + num_streams: usize, + num_processed: usize, + inner: StreamMap, + } + + impl Task { + fn new(item_delay: Duration, num_streams: u32, capacity: usize) -> Self { + Self { + item_delay, + num_streams: num_streams as usize, + num_processed: 0, + inner: StreamMap::new(Duration::from_secs(60), capacity), + } + } + } + + impl Future for Task { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + while this.num_processed < this.num_streams { + match this.inner.poll_next_unpin(cx) { + Poll::Ready((_, Some(result))) => { + if result.is_err() { + panic!("Timeout is great than item delay") + } + + this.num_processed += 1; + continue; + } + Poll::Ready((_, None)) => { + continue; + } + _ => {} + } + + if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) { + // We push the constant ID to prove that user can use the same ID if the stream was finished + let maybe_future = this.inner.try_push(1u8, once(Delay::new(this.item_delay))); + assert!(maybe_future.is_ok(), "we polled for readiness"); + + continue; + } + + return Poll::Pending; + } + + Poll::Ready(()) + } + } +} diff --git a/misc/futures-bounded/src/stream_set.rs b/misc/futures-bounded/src/stream_set.rs new file mode 100644 index 00000000000..4fcb649fd49 --- /dev/null +++ b/misc/futures-bounded/src/stream_set.rs @@ -0,0 +1,60 @@ +use futures_util::stream::BoxStream; +use futures_util::Stream; +use std::task::{ready, Context, Poll}; +use std::time::Duration; + +use crate::{PushError, StreamMap, Timeout}; + +/// Represents a set of [Stream]s. +/// +/// Each stream must finish within the specified time and the list never outgrows its capacity. +pub struct StreamSet { + id: u32, + inner: StreamMap, +} + +impl StreamSet { + pub fn new(timeout: Duration, capacity: usize) -> Self { + Self { + id: 0, + inner: StreamMap::new(timeout, capacity), + } + } +} + +impl StreamSet +where + O: Send + 'static, +{ + /// Push a stream into the list. + /// + /// This method adds the given stream to the list. + /// If the length of the list is equal to the capacity, this method returns a error that contains the passed stream. + /// In that case, the stream is not added to the set. + pub fn try_push(&mut self, stream: F) -> Result<(), BoxStream> + where + F: Stream + Send + 'static, + { + self.id = self.id.wrapping_add(1); + + match self.inner.try_push(self.id, stream) { + Ok(()) => Ok(()), + Err(PushError::BeyondCapacity(w)) => Err(w), + Err(PushError::Replaced(_)) => unreachable!("we never reuse IDs"), + } + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.inner.poll_ready_unpin(cx) + } + + pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll>> { + let (_, res) = ready!(self.inner.poll_next_unpin(cx)); + + Poll::Ready(res) + } +} diff --git a/protocols/relay/src/priv_client/handler.rs b/protocols/relay/src/priv_client/handler.rs index b2effdbde56..2c6db0008bd 100644 --- a/protocols/relay/src/priv_client/handler.rs +++ b/protocols/relay/src/priv_client/handler.rs @@ -269,7 +269,7 @@ impl Handler { Err(PushError::BeyondCapacity(_)) => log::warn!( "Dropping inbound circuit request to be denied from {src_peer_id} due to exceeding limit." ), - Err(PushError::ReplacedFuture(_)) => log::warn!( + Err(PushError::Replaced(_)) => log::warn!( "Dropping existing inbound circuit request to be denied from {src_peer_id} in favor of new one." ), Ok(()) => {}