diff --git a/Cargo.toml b/Cargo.toml index 39ff48424..1a0c4abf6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Sean McArthur "] readme = "README.md" license = "MIT OR Apache-2.0" edition = "2021" -rust-version = "1.63.0" +rust-version = "1.64.0" autotests = true [package.metadata.docs.rs] @@ -105,6 +105,7 @@ url = "2.4" bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7.1" +tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] } tower-service = "0.3" futures-core = { version = "0.3.28", default-features = false } futures-util = { version = "0.3.28", default-features = false } @@ -169,7 +170,6 @@ quinn = { version = "0.11.1", default-features = false, features = ["rustls", "r slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn futures-channel = { version = "0.3", optional = true } - [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } @@ -222,6 +222,11 @@ features = [ wasm-bindgen = { version = "0.2.89", features = ["serde-serialize"] } wasm-bindgen-test = "0.3" +[dev-dependencies] +tower = { version = "0.5.2", default-features = false, features = ["limit"] } +num_cpus = "1.0" +libc = "0" + [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] } @@ -253,6 +258,10 @@ path = "examples/form.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "connect_via_lower_priority_tokio_runtime" +path = "examples/connect_via_lower_priority_tokio_runtime.rs" + [[test]] name = "blocking" path = "tests/blocking.rs" diff --git a/examples/connect_via_lower_priority_tokio_runtime.rs b/examples/connect_via_lower_priority_tokio_runtime.rs new file mode 100644 index 000000000..33151d4a1 --- /dev/null +++ b/examples/connect_via_lower_priority_tokio_runtime.rs @@ -0,0 +1,264 @@ +#![deny(warnings)] +// This example demonstrates how to delegate the connect calls, which contain TLS handshakes, +// to a secondary tokio runtime of lower OS thread priority using a custom tower layer. +// This helps to ensure that long-running futures during handshake crypto operations don't block other I/O futures. +// +// This does introduce overhead of additional threads, channels, extra vtables, etc, +// so it is best suited to services with large numbers of incoming connections or that +// are otherwise very sensitive to any blocking futures. Or, you might want fewer threads +// and/or to use the current_thread runtime. +// +// This is using the `tokio` runtime and certain other dependencies: +// +// `tokio = { version = "1", features = ["full"] }` +// `num_cpus = "1.0"` +// `libc = "0"` +// `pin-project-lite = "0.2"` +// `tower = { version = "0.5", default-features = false}` + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { + background_threadpool::init_background_runtime(); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + let client = reqwest::Client::builder() + .connector_layer(background_threadpool::BackgroundProcessorLayer::new()) + .build() + .expect("should be able to build reqwest client"); + + let url = if let Some(url) = std::env::args().nth(1) { + url + } else { + println!("No CLI URL provided, using default."); + "https://hyper.rs".into() + }; + + eprintln!("Fetching {url:?}..."); + + let res = client.get(url).send().await?; + + eprintln!("Response: {:?} {}", res.version(), res.status()); + eprintln!("Headers: {:#?}\n", res.headers()); + + let body = res.text().await?; + + println!("{body}"); + + Ok(()) +} + +// separating out for convenience to avoid a million #[cfg(not(target_arch = "wasm32"))] +#[cfg(not(target_arch = "wasm32"))] +mod background_threadpool { + use std::{ + future::Future, + pin::Pin, + sync::OnceLock, + task::{Context, Poll}, + }; + + use futures_util::TryFutureExt; + use pin_project_lite::pin_project; + use tokio::{runtime::Handle, select, sync::mpsc::error::TrySendError}; + use tower::{BoxError, Layer, Service}; + + static CPU_HEAVY_THREAD_POOL: OnceLock< + tokio::sync::mpsc::Sender + Send + 'static>>>, + > = OnceLock::new(); + + pub(crate) fn init_background_runtime() { + std::thread::Builder::new() + .name("cpu-heavy-background-threadpool".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("cpu-heavy-background-pool-thread") + .worker_threads(num_cpus::get() as usize) + // ref: https://github.com/tokio-rs/tokio/issues/4941 + // consider uncommenting if seeing heavy task contention + // .disable_lifo_slot() + .on_thread_start(move || { + #[cfg(target_os = "linux")] + unsafe { + // Increase thread pool thread niceness, so they are lower priority + // than the foreground executor and don't interfere with I/O tasks + { + *libc::__errno_location() = 0; + if libc::nice(10) == -1 && *libc::__errno_location() != 0 { + let error = std::io::Error::last_os_error(); + log::error!("failed to set threadpool niceness: {}", error); + } + } + } + }) + .enable_all() + .build() + .unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e)); + rt.block_on(async { + log::debug!("starting background cpu-heavy work"); + process_cpu_work().await; + }); + }) + .unwrap_or_else(|e| panic!("cpu heavy thread failed_to_initialize: {}", e)); + } + + #[cfg(not(target_arch = "wasm32"))] + async fn process_cpu_work() { + // we only use this channel for routing work, it should move pretty quick, it can be small + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + // share the handle to the background channel globally + CPU_HEAVY_THREAD_POOL.set(tx).unwrap(); + + while let Some(work) = rx.recv().await { + tokio::task::spawn(work); + } + } + + // retrieve the sender to the background channel, and send the future over to it for execution + fn send_to_background_runtime(future: impl Future + Send + 'static) { + let tx = CPU_HEAVY_THREAD_POOL.get().expect( + "start up the secondary tokio runtime before sending to `CPU_HEAVY_THREAD_POOL`", + ); + + match tx.try_send(Box::pin(future)) { + Ok(_) => (), + Err(TrySendError::Closed(_)) => { + panic!("background cpu heavy runtime channel is closed") + } + Err(TrySendError::Full(msg)) => { + log::warn!( + "background cpu heavy runtime channel is full, task spawning loop delayed" + ); + let tx = tx.clone(); + Handle::current().spawn(async move { + tx.send(msg) + .await + .expect("background cpu heavy runtime channel is closed") + }); + } + } + } + + // This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing. + // We don't use the Buffer service because that is intended to process sequentially on a single task, whereas we want to + // spawn a new task per call. + #[derive(Copy, Clone)] + pub struct BackgroundProcessorLayer {} + impl BackgroundProcessorLayer { + pub fn new() -> Self { + Self {} + } + } + impl Layer for BackgroundProcessorLayer { + type Service = BackgroundProcessor; + fn layer(&self, service: S) -> Self::Service { + BackgroundProcessor::new(service) + } + } + + impl std::fmt::Debug for BackgroundProcessorLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("BackgroundProcessorLayer").finish() + } + } + + // This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing. + #[derive(Debug, Clone)] + pub struct BackgroundProcessor { + inner: S, + } + + impl BackgroundProcessor { + pub fn new(inner: S) -> Self { + BackgroundProcessor { inner } + } + } + + impl Service for BackgroundProcessor + where + S: Service, + S::Response: Send + 'static, + S::Error: Into + Send, + S::Future: Send + 'static, + { + type Response = S::Response; + + type Error = BoxError; + + type Future = BackgroundResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + + // wrap our inner service's future with a future that writes to this oneshot channel + let (mut tx, rx) = tokio::sync::oneshot::channel(); + let future = async move { + select!( + _ = tx.closed() => { + // receiver already dropped, don't need to do anything + } + result = response.map_err(|err| Into::::into(err)) => { + // if this fails, the receiver already dropped, so we don't need to do anything + let _ = tx.send(result); + } + ) + }; + // send the wrapped future to the background + send_to_background_runtime(future); + + BackgroundResponseFuture::new(rx) + } + } + + // `BackgroundProcessor` response future + pin_project! { + #[derive(Debug)] + pub struct BackgroundResponseFuture { + #[pin] + rx: tokio::sync::oneshot::Receiver>, + } + } + + impl BackgroundResponseFuture { + pub(crate) fn new(rx: tokio::sync::oneshot::Receiver>) -> Self { + BackgroundResponseFuture { rx } + } + } + + impl Future for BackgroundResponseFuture + where + S: Send + 'static, + { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // now poll on the receiver end of the oneshot to get the result + match this.rx.poll(cx) { + Poll::Ready(v) => match v { + Ok(v) => Poll::Ready(v.map_err(Into::into)), + Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)), + }, + Poll::Pending => Poll::Pending, + } + } + } +} + +// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function +// for wasm32 target, because tokio isn't compatible with wasm32. +// If you aren't building for wasm32, you don't need that line. +// The two lines below avoid the "'main' function not found" error when building for wasm32 target. +#[cfg(any(target_arch = "wasm32"))] +fn main() {} diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 579050041..354a23205 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1,27 +1,14 @@ #[cfg(any(feature = "native-tls", feature = "__rustls",))] use std::any::Any; +use std::future::Future; use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use std::{collections::HashMap, convert::TryInto, net::SocketAddr}; use std::{fmt, str}; -use bytes::Bytes; -use http::header::{ - Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, - CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, -}; -use http::uri::Scheme; -use http::Uri; -use hyper_util::client::legacy::connect::HttpConnector; -#[cfg(feature = "default-tls")] -use native_tls_crate::TlsConnector; -use pin_project_lite::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::time::Sleep; - use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; use super::response::Response; @@ -30,13 +17,16 @@ use super::Body; use crate::async_impl::h3_client::connect::H3Connector; #[cfg(feature = "http3")] use crate::async_impl::h3_client::{H3Client, H3ResponseFuture}; -use crate::connect::Connector; +use crate::connect::{ + sealed::{Conn, Unnameable}, + BoxedConnectorLayer, BoxedConnectorService, Connector, ConnectorBuilder, +}; #[cfg(feature = "cookies")] use crate::cookie; #[cfg(feature = "hickory-dns")] use crate::dns::hickory::HickoryDnsResolver; use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve}; -use crate::error; +use crate::error::{self, BoxError}; use crate::into_url::try_uri; use crate::redirect::{self, remove_sensitive_headers}; #[cfg(feature = "__rustls")] @@ -48,11 +38,25 @@ use crate::Certificate; #[cfg(any(feature = "native-tls", feature = "__rustls"))] use crate::Identity; use crate::{IntoUrl, Method, Proxy, StatusCode, Url}; +use bytes::Bytes; +use http::header::{ + Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, + CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, +}; +use http::uri::Scheme; +use http::Uri; +use hyper_util::client::legacy::connect::HttpConnector; use log::debug; +#[cfg(feature = "default-tls")] +use native_tls_crate::TlsConnector; +use pin_project_lite::pin_project; #[cfg(feature = "http3")] use quinn::TransportConfig; #[cfg(feature = "http3")] use quinn::VarInt; +use tokio::time::Sleep; +use tower::util::BoxCloneSyncServiceLayer; +use tower::{Layer, Service}; type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; @@ -130,6 +134,7 @@ struct Config { tls_info: bool, #[cfg(feature = "__tls")] tls: TlsBackend, + connector_layers: Vec, http_version_pref: HttpVersionPref, http09_responses: bool, http1_title_case_headers: bool, @@ -185,7 +190,7 @@ impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { let mut headers: HeaderMap = HeaderMap::with_capacity(2); headers.insert(ACCEPT, HeaderValue::from_static("*/*")); @@ -233,6 +238,7 @@ impl ClientBuilder { tls_info: false, #[cfg(feature = "__tls")] tls: TlsBackend::default(), + connector_layers: Vec::new(), http_version_pref: HttpVersionPref::All, http09_responses: false, http1_title_case_headers: false, @@ -278,7 +284,9 @@ impl ClientBuilder { }, } } +} +impl ClientBuilder { /// Returns a `Client` that uses this `ClientBuilder` configuration. /// /// # Errors @@ -302,7 +310,7 @@ impl ClientBuilder { #[cfg(feature = "http3")] let mut h3_connector = None; - let mut connector = { + let mut connector_builder = { #[cfg(feature = "__tls")] fn user_agent(headers: &HeaderMap) -> Option { headers.get(USER_AGENT).cloned() @@ -445,7 +453,7 @@ impl ClientBuilder { tls.max_protocol_version(Some(protocol)); } - Connector::new_default_tls( + ConnectorBuilder::new_default_tls( http, tls, proxies.clone(), @@ -462,7 +470,7 @@ impl ClientBuilder { )? } #[cfg(feature = "native-tls")] - TlsBackend::BuiltNativeTls(conn) => Connector::from_built_default_tls( + TlsBackend::BuiltNativeTls(conn) => ConnectorBuilder::from_built_default_tls( http, conn, proxies.clone(), @@ -489,7 +497,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, conn, proxies.clone(), @@ -684,7 +692,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, tls, proxies.clone(), @@ -709,7 +717,7 @@ impl ClientBuilder { } #[cfg(not(feature = "__tls"))] - Connector::new( + ConnectorBuilder::new( http, proxies.clone(), config.local_address, @@ -719,8 +727,9 @@ impl ClientBuilder { ) }; - connector.set_timeout(config.connect_timeout); - connector.set_verbose(config.connection_verbose); + connector_builder.set_timeout(config.connect_timeout); + connector_builder.set_verbose(config.connection_verbose); + connector_builder.set_keepalive(config.tcp_keepalive); let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); @@ -763,7 +772,6 @@ impl ClientBuilder { builder.pool_timer(hyper_util::rt::TokioTimer::new()); builder.pool_idle_timeout(config.pool_idle_timeout); builder.pool_max_idle_per_host(config.pool_max_idle_per_host); - connector.set_keepalive(config.tcp_keepalive); if config.http09_responses { builder.http09_responses(true); @@ -801,7 +809,7 @@ impl ClientBuilder { } None => None, }, - hyper: builder.build(connector), + hyper: builder.build(connector_builder.build(config.connector_layers)), headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, @@ -1953,6 +1961,43 @@ impl ClientBuilder { self.config.quic_send_window = Some(value); self } + + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment.a + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// If configured, the `connect_timeout` will be the outermost layer. + /// + /// Example usage: + /// ``` + /// use std::time::Duration; + /// + /// # #[cfg(not(feature = "rustls-tls-no-provider"))] + /// let client = reqwest::Client::builder() + /// // resolved to outermost layer, meaning while we are waiting on concurrency limit + /// .connect_timeout(Duration::from_millis(200)) + /// // underneath the concurrency check, so only after concurrency limit lets us through + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + /// + pub fn connector_layer(mut self, layer: L) -> ClientBuilder + where + L: Layer + Clone + Send + Sync + 'static, + L::Service: + Service + Clone + Send + Sync + 'static, + >::Future: Send + 'static, + { + let layer = BoxCloneSyncServiceLayer::new(layer); + + self.config.connector_layers.push(layer); + + self + } } type HyperClient = hyper_util::client::legacy::Client; diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 73f25208f..700ce57a9 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -12,11 +12,16 @@ use std::time::Duration; use http::header::HeaderValue; use log::{error, trace}; use tokio::sync::{mpsc, oneshot}; +use tower::Layer; +use tower::Service; use super::request::{Request, RequestBuilder}; use super::response::Response; use super::wait; +use crate::connect::sealed::{Conn, Unnameable}; +use crate::connect::BoxedConnectorService; use crate::dns::Resolve; +use crate::error::BoxError; #[cfg(feature = "__tls")] use crate::tls; #[cfg(feature = "__rustls")] @@ -84,13 +89,15 @@ impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { ClientBuilder { inner: async_impl::ClientBuilder::new(), timeout: Timeout::default(), } } +} +impl ClientBuilder { /// Returns a `Client` that uses this `ClientBuilder` configuration. /// /// # Errors @@ -968,6 +975,35 @@ impl ClientBuilder { self.with_inner(|inner| inner.dns_resolver(resolver)) } + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// Example usage: + /// ``` + /// use std::time::Duration; + /// + /// let client = reqwest::blocking::Client::builder() + /// // resolved to outermost layer, meaning while we are waiting on concurrency limit + /// .connect_timeout(Duration::from_millis(200)) + /// // underneath the concurrency check, so only after concurrency limit lets us through + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + pub fn connector_layer(self, layer: L) -> ClientBuilder + where + L: Layer + Clone + Send + Sync + 'static, + L::Service: + Service + Clone + Send + Sync + 'static, + >::Future: Send + 'static, + { + self.with_inner(|inner| inner.connector_layer(layer)) + } + // private fn with_inner(mut self, func: F) -> ClientBuilder diff --git a/src/connect.rs b/src/connect.rs index ff86ba3c9..c366473cc 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -8,9 +8,11 @@ use hyper_util::client::legacy::connect::{Connected, Connection}; use hyper_util::rt::TokioIo; #[cfg(feature = "default-tls")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; +use pin_project_lite::pin_project; +use tower::util::{BoxCloneSyncServiceLayer, MapRequestLayer}; +use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, ServiceBuilder}; use tower_service::Service; -use pin_project_lite::pin_project; use std::future::Future; use std::io::{self, IoSlice}; use std::net::IpAddr; @@ -24,13 +26,47 @@ use self::native_tls_conn::NativeTlsConn; #[cfg(feature = "__rustls")] use self::rustls_tls_conn::RustlsTlsConn; use crate::dns::DynResolver; -use crate::error::BoxError; +use crate::error::{cast_to_internal_error, BoxError}; use crate::proxy::{Proxy, ProxyScheme}; +use sealed::{Conn, Unnameable}; pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector; #[derive(Clone)] -pub(crate) struct Connector { +pub(crate) enum Connector { + // base service, with or without an embedded timeout + Simple(ConnectorService), + // at least one custom layer along with maybe an outer timeout layer + // from `builder.connect_timeout()` + WithLayers(BoxCloneSyncService), +} + +impl Service for Connector { + type Response = Conn; + type Error = BoxError; + type Future = Connecting; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + Connector::Simple(service) => service.poll_ready(cx), + Connector::WithLayers(service) => service.poll_ready(cx), + } + } + + fn call(&mut self, dst: Uri) -> Self::Future { + match self { + Connector::Simple(service) => service.call(dst), + Connector::WithLayers(service) => service.call(Unnameable(dst)), + } + } +} + +pub(crate) type BoxedConnectorService = BoxCloneSyncService; + +pub(crate) type BoxedConnectorLayer = + BoxCloneSyncServiceLayer; + +pub(crate) struct ConnectorBuilder { inner: Inner, proxies: Arc>, verbose: verbose::Wrapper, @@ -43,21 +79,70 @@ pub(crate) struct Connector { user_agent: Option, } -#[derive(Clone)] -enum Inner { - #[cfg(not(feature = "__tls"))] - Http(HttpConnector), - #[cfg(feature = "default-tls")] - DefaultTls(HttpConnector, TlsConnector), - #[cfg(feature = "__rustls")] - RustlsTls { - http: HttpConnector, - tls: Arc, - tls_proxy: Arc, - }, -} +impl ConnectorBuilder { + pub(crate) fn build(self, layers: Vec) -> Connector +where { + // construct the inner tower service + let mut base_service = ConnectorService { + inner: self.inner, + proxies: self.proxies, + verbose: self.verbose, + #[cfg(feature = "__tls")] + nodelay: self.nodelay, + #[cfg(feature = "__tls")] + tls_info: self.tls_info, + #[cfg(feature = "__tls")] + user_agent: self.user_agent, + simple_timeout: None, + }; + + if layers.is_empty() { + // we have no user-provided layers, only use concrete types + base_service.simple_timeout = self.timeout; + return Connector::Simple(base_service); + } + + // otherwise we have user provided layers + // so we need type erasure all the way through + // as well as mapping the unnameable type of the layers back to Uri for the inner service + let unnameable_service = ServiceBuilder::new() + .layer(MapRequestLayer::new(|request: Unnameable| request.0)) + .service(base_service); + let mut service = BoxCloneSyncService::new(unnameable_service); + + for layer in layers { + service = ServiceBuilder::new().layer(layer).service(service); + } + + // now we handle the concrete stuff - any `connect_timeout`, + // plus a final map_err layer we can use to cast default tower layer + // errors to internal errors + match self.timeout { + Some(timeout) => { + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(timeout)) + .service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + + Connector::WithLayers(service) + } + None => { + // no timeout, but still map err + // no named timeout layer but we still map errors since + // we might have user-provided timeout layer + let service = ServiceBuilder::new().service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + Connector::WithLayers(service) + } + } + } -impl Connector { #[cfg(not(feature = "__tls"))] pub(crate) fn new( mut http: HttpConnector, @@ -66,7 +151,7 @@ impl Connector { #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] interface: Option<&str>, nodelay: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -77,10 +162,10 @@ impl Connector { } http.set_nodelay(nodelay); - Connector { + ConnectorBuilder { inner: Inner::Http(http), - verbose: verbose::OFF, proxies, + verbose: verbose::OFF, timeout: None, } } @@ -96,7 +181,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> crate::Result + ) -> crate::Result where T: Into>, { @@ -125,7 +210,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -137,14 +222,14 @@ impl Connector { http.set_nodelay(nodelay); http.enforce_http(false); - Connector { + ConnectorBuilder { inner: Inner::DefaultTls(http, tls), proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -159,7 +244,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -180,7 +265,7 @@ impl Connector { (Arc::new(tls), Arc::new(tls_proxy)) }; - Connector { + ConnectorBuilder { inner: Inner::RustlsTls { http, tls, @@ -188,10 +273,10 @@ impl Connector { }, proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -203,6 +288,52 @@ impl Connector { self.verbose.0 = enabled; } + pub(crate) fn set_keepalive(&mut self, dur: Option) { + match &mut self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), + #[cfg(feature = "__rustls")] + Inner::RustlsTls { http, .. } => http.set_keepalive(dur), + #[cfg(not(feature = "__tls"))] + Inner::Http(http) => http.set_keepalive(dur), + } + } +} + +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub(crate) struct ConnectorService { + inner: Inner, + proxies: Arc>, + verbose: verbose::Wrapper, + /// When there is a single timeout layer and no other layers, + /// we embed it directly inside our base Service::call(). + /// This lets us avoid an extra `Box::pin` indirection layer + /// since `tokio::time::Timeout` is `Unpin` + simple_timeout: Option, + #[cfg(feature = "__tls")] + nodelay: bool, + #[cfg(feature = "__tls")] + tls_info: bool, + #[cfg(feature = "__tls")] + user_agent: Option, +} + +#[derive(Clone)] +enum Inner { + #[cfg(not(feature = "__tls"))] + Http(HttpConnector), + #[cfg(feature = "default-tls")] + DefaultTls(HttpConnector, TlsConnector), + #[cfg(feature = "__rustls")] + RustlsTls { + http: HttpConnector, + tls: Arc, + tls_proxy: Arc, + }, +} + +impl ConnectorService { #[cfg(feature = "socks")] async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result { let dns = match proxy { @@ -449,17 +580,6 @@ impl Connector { self.connect_with_maybe_proxy(proxy_dst, true).await } - - pub fn set_keepalive(&mut self, dur: Option) { - match &mut self.inner { - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), - #[cfg(feature = "__rustls")] - Inner::RustlsTls { http, .. } => http.set_keepalive(dur), - #[cfg(not(feature = "__tls"))] - Inner::Http(http) => http.set_keepalive(dur), - } - } } fn into_uri(scheme: Scheme, host: Authority) -> Uri { @@ -487,7 +607,7 @@ where } } -impl Service for Connector { +impl Service for ConnectorService { type Response = Conn; type Error = BoxError; type Future = Connecting; @@ -498,7 +618,7 @@ impl Service for Connector { fn call(&mut self, dst: Uri) -> Self::Future { log::debug!("starting new connection: {dst:?}"); - let timeout = self.timeout; + let timeout = self.simple_timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { return Box::pin(with_timeout( @@ -633,80 +753,87 @@ impl AsyncConnWithInfo for T {} type BoxConn = Box; -pin_project! { - /// Note: the `is_proxy` member means *is plain text HTTP proxy*. - /// This tells hyper whether the URI should be written in - /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or - /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. - pub(crate) struct Conn { - #[pin] - inner: BoxConn, - is_proxy: bool, - // Only needed for __tls, but #[cfg()] on fields breaks pin_project! - tls_info: bool, +pub(crate) mod sealed { + use super::*; + #[derive(Debug, Clone)] + pub struct Unnameable(pub(super) Uri); + + pin_project! { + /// Note: the `is_proxy` member means *is plain text HTTP proxy*. + /// This tells hyper whether the URI should be written in + /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or + /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. + #[allow(missing_debug_implementations)] + pub struct Conn { + #[pin] + pub(super)inner: BoxConn, + pub(super) is_proxy: bool, + // Only needed for __tls, but #[cfg()] on fields breaks pin_project! + pub(super) tls_info: bool, + } } -} -impl Connection for Conn { - fn connected(&self) -> Connected { - let connected = self.inner.connected().proxy(self.is_proxy); - #[cfg(feature = "__tls")] - if self.tls_info { - if let Some(tls_info) = self.inner.tls_info() { - connected.extra(tls_info) + impl Connection for Conn { + fn connected(&self) -> Connected { + let connected = self.inner.connected().proxy(self.is_proxy); + #[cfg(feature = "__tls")] + if self.tls_info { + if let Some(tls_info) = self.inner.tls_info() { + connected.extra(tls_info) + } else { + connected + } } else { connected } - } else { + #[cfg(not(feature = "__tls"))] connected } - #[cfg(not(feature = "__tls"))] - connected } -} -impl Read for Conn { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: ReadBufCursor<'_>, - ) -> Poll> { - let this = self.project(); - Read::poll_read(this.inner, cx, buf) + impl Read for Conn { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: ReadBufCursor<'_>, + ) -> Poll> { + let this = self.project(); + Read::poll_read(this.inner, cx, buf) + } } -} -impl Write for Conn { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - let this = self.project(); - Write::poll_write(this.inner, cx, buf) - } + impl Write for Conn { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + Write::poll_write(this.inner, cx, buf) + } - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - let this = self.project(); - Write::poll_write_vectored(this.inner, cx, bufs) - } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let this = self.project(); + Write::poll_write_vectored(this.inner, cx, bufs) + } - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.project(); - Write::poll_flush(this.inner, cx) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + Write::poll_flush(this.inner, cx) + } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.project(); - Write::poll_shutdown(this.inner, cx) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + Write::poll_shutdown(this.inner, cx) + } } } diff --git a/src/error.rs b/src/error.rs index ca7413fd6..6a9f07e51 100644 --- a/src/error.rs +++ b/src/error.rs @@ -165,6 +165,18 @@ impl Error { } } +/// Converts from external types to reqwest's +/// internal equivalents. +/// +/// Currently only is used for `tower::timeout::error::Elapsed`. +pub(crate) fn cast_to_internal_error(error: BoxError) -> BoxError { + if error.is::() { + Box::new(crate::error::TimedOut) as BoxError + } else { + error + } +} + impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct("reqwest::Error"); diff --git a/tests/connector_layers.rs b/tests/connector_layers.rs new file mode 100644 index 000000000..1be18aeb8 --- /dev/null +++ b/tests/connector_layers.rs @@ -0,0 +1,374 @@ +#![cfg(not(target_arch = "wasm32"))] +#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))] +mod support; + +use std::time::Duration; + +use futures_util::future::join_all; +use tower::layer::util::Identity; +use tower::limit::ConcurrencyLimitLayer; +use tower::timeout::TimeoutLayer; + +use support::{delay_layer::DelayLayer, server}; + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer_with_timeout() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_never_returning() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_slow() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_under_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(300))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(500))) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_over_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connect_timeout(Duration::from_millis(50)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + let timed_out = all_res + .into_iter() + .any(|res| res.is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + for res in all_res.into_iter() { + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(feature = "blocking")] +#[test] +fn non_op_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(Identity::new()) + .build() + .unwrap(); + + let res = client.get(url).send(); + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn timeout_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send(); + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + let timed_out = join_handles + .into_iter() + .any(|handle| handle.join().unwrap().is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + for handle in join_handles { + let res = handle.join().unwrap(); + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn no_generic_bounds_required_for_client_new() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::new(); + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn no_generic_bounds_required_for_client_new_blocking() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::new(); + let res = client.get(url).send(); + + assert!(res.is_ok()); +} diff --git a/tests/support/delay_layer.rs b/tests/support/delay_layer.rs new file mode 100644 index 000000000..b8eec42a1 --- /dev/null +++ b/tests/support/delay_layer.rs @@ -0,0 +1,119 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use pin_project_lite::pin_project; +use tokio::time::Sleep; +use tower::{BoxError, Layer, Service}; + +/// This tower layer injects an arbitrary delay before calling downstream layers. +#[derive(Clone)] +pub struct DelayLayer { + delay: Duration, +} + +impl DelayLayer { + pub const fn new(delay: Duration) -> Self { + DelayLayer { delay } + } +} + +impl Layer for DelayLayer { + type Service = Delay; + fn layer(&self, service: S) -> Self::Service { + Delay::new(service, self.delay) + } +} + +impl std::fmt::Debug for DelayLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("DelayLayer") + .field("delay", &self.delay) + .finish() + } +} + +/// This tower service injects an arbitrary delay before calling downstream layers. +#[derive(Debug, Clone)] +pub struct Delay { + inner: S, + delay: Duration, +} +impl Delay { + pub fn new(inner: S, delay: Duration) -> Self { + Delay { inner, delay } + } +} + +impl Service for Delay +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + + type Error = BoxError; + + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + let sleep = tokio::time::sleep(self.delay); + + ResponseFuture::new(response, sleep) + } +} + +// `Delay` response future +pin_project! { + #[derive(Debug)] + pub struct ResponseFuture { + #[pin] + response: S, + #[pin] + sleep: Sleep, + } +} + +impl ResponseFuture { + pub(crate) fn new(response: S, sleep: Sleep) -> Self { + ResponseFuture { response, sleep } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // First poll the sleep until complete + match this.sleep.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(_) => {} + } + + // Then poll the inner future + match this.response.poll(cx) { + Poll::Ready(v) => Poll::Ready(v.map_err(Into::into)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index c796956d8..9d4ce7b9b 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,3 +1,4 @@ +pub mod delay_layer; pub mod delay_server; pub mod server; diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 79a6fbb4d..71dc0ce66 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -337,6 +337,24 @@ fn timeout_blocking_request() { assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } +#[cfg(feature = "blocking")] +#[test] +fn connect_timeout_blocking_request() { + let _ = env_logger::try_init(); + + let client = reqwest::blocking::Client::builder() + .connect_timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let err = client.get(url).send().unwrap_err(); + + assert!(err.is_timeout()); +} + #[cfg(feature = "blocking")] #[cfg(feature = "stream")] #[test]